Build a PyTorch Server with Celery and RabbitMQ

Tools intro

  1. Celery: Celery is an asynchronous task queue/job queue based on distributed message passing.
  2. RabbitMQ: RabbitMQ is the most widely deployed open-source message broker; Celery connects to it through a broker URL (see the sketch after this list).
  3. PyTorch: the deep learning framework used here.
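A minimal sketch of that connection (the host, port, and credentials below are RabbitMQ's defaults, not values from this article):

from celery import Celery

# 'amqp://' with no host is shorthand for RabbitMQ's defaults:
# guest:guest on localhost, port 5672, default vhost
app = Celery('tasks', broker='amqp://guest:guest@localhost:5672//')

@app.task
def ping():
    return 'pong'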

Server side

We are going to use a toy MNIST model here.

import torch.nn as nn
import torch.nn.functional as F


class ToyNet(nn.Module):
    def __init__(self):
        super(ToyNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
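The worker below loads its weights from ./model/toynet.pth.tar. The article does not show how that checkpoint is produced; a minimal training-and-save sketch could look like this (the dataset path, epoch count, and optimizer settings are assumptions, and mnist_model.py is assumed to hold the ToyNet class above):

import os

import torch as th
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from mnist_model import ToyNet

train_set = datasets.MNIST('./data', train=True, download=True,
                           transform=transforms.ToTensor())
loader = DataLoader(train_set, batch_size=64, shuffle=True)

model = ToyNet()
optimizer = th.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

model.train()
for epoch in range(3):  # a few epochs are plenty for a toy model
    for img, target in loader:
        optimizer.zero_grad()
        loss = F.nll_loss(model(img), target)  # NLL pairs with the log_softmax output
        loss.backward()
        optimizer.step()

os.makedirs('./model', exist_ok=True)
th.save(model.state_dict(), './model/toynet.pth.tar')

With that checkpoint in place, the worker module can load the weights once at import time, as shown next.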
import torch as th
from celery import Celery
from mnist_model import ToyNet
import json
import cv2
import sys


# note: the 'amqp' result backend is deprecated in newer Celery releases;
# 'rpc://' is the usual replacement
app = Celery('tasks', backend='amqp', broker='amqp://')
model = ToyNet()
state_dict = th.load('./model/toynet.pth.tar')
model.load_state_dict(state_dict)
model.eval()  # disable dropout for inference
print('hope this is done only once')  # the model is loaded once per worker process


@app.task
def inference(json_str):
    try:
        task_spec = json.loads(json_str)
        img_path = task_spec['img_path']
        is_cuda = task_spec['is_cuda']
        is_file = task_spec['is_file']
    except KeyError as err:
        print('Key not found in json string.')
        print(err)
        return json.dumps(dict(message='Read Image Error'))
    except ValueError:
        # json.loads raises ValueError on malformed input
        print('Json load error.')
        return json.dumps(dict(message='Read Image Error'))

    if is_file:
        try:
            # read as grayscale and add a channel axis: (H, W) -> (H, W, 1)
            img = cv2.imread(img_path, 0)[:, :, None]
        except Exception:
            print(sys.exc_info()[0])
            return json.dumps(dict(message='Read Image Error'))
    else:
        # only file-based input is handled here
        return json.dumps(dict(message='Read Image Error'))

    # (H, W, 1) -> (1, 1, H, W) float tensor
    img = th.Tensor(img[None, :, :, :]).permute(0, 3, 1, 2).float()
    if is_cuda:
        # the model's weights must live on the GPU as well
        model.cuda()
        img = img.cuda()

    with th.no_grad():
        res = model(img)

    if is_cuda:
        res = res.cpu()

    # scores have shape (1, 10); take the argmax over the class dimension
    pred = th.argmax(res, dim=1).item()

    res = dict(
        message='success',
        res=pred
    )
    return json.dumps(res)

Run

Save the worker code above as th_task.py (the test script below imports it under that name), then start a Celery worker:

celery -A th_task worker --loglevel=info
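By default Celery starts one worker process per CPU core, and each process loads its own copy of the model, so memory use grows with concurrency. If that is a concern, the pool size can be capped, for example:

celery -A th_task worker --loglevel=info --concurrency=1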

Test

import json
import th_task
import time


if __name__ == '__main__':
    task_t = time.time()
    spec = dict(
        img_path='./test.jpg',
        is_cuda=False,
        is_file=True
    )
    json_str = json.dumps(spec)

    # .delay() enqueues the task and returns an AsyncResult immediately
    res_list = []
    for i in range(10):
        res = th_task.inference.delay(json_str)
        res_list.append(res)
    print('-----------')
    print(time.time() - task_t)
    print('-' * 10)

    # .wait() blocks until the worker has published the result
    for k, t in enumerate(res_list):
        x = t.wait()
        print(time.time() - task_t)
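Each value returned by wait() is the JSON string produced by the worker, so the predicted class can be recovered on the client side. A small sketch (the field names match the task above):

result = json.loads(x)  # x is the string returned by t.wait()
if result['message'] == 'success':
    print('predicted digit:', result['res'])
else:
    print('task failed:', result['message'])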