Skip to content

Commit 99e3176

Browse files
authored
Update pytorch examples to use gpu if device is set (#849)
1 parent ece1bfb commit 99e3176

File tree

7 files changed

+22
-11
lines changed

7 files changed

+22
-11
lines changed

examples/pytorch/image-classifier/cortex.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
predictor:
55
type: python
66
path: predictor.py
7+
config:
8+
device: cuda # use "cpu" to run on CPUs
79
compute:
810
cpu: 1
911
gpu: 1

examples/pytorch/image-classifier/predictor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
class PythonPredictor:
1212
def __init__(self, config):
13-
model = torchvision.models.alexnet(pretrained=True)
13+
model = torchvision.models.alexnet(pretrained=True).to(config["device"])
1414
model.eval()
1515
# https://github.com/pytorch/examples/blob/447974f6337543d4de6b888e244a964d3c9b71f6/imagenet/main.py#L198-L199
1616
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
@@ -22,12 +22,14 @@ def __init__(self, config):
2222
"https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt"
2323
).text.split("\n")[1:]
2424
self.model = model
25+
self.device = config["device"]
2526

2627
def predict(self, payload):
2728
image = requests.get(payload["url"]).content
2829
img_pil = Image.open(BytesIO(image))
2930
img_tensor = self.preprocess(img_pil)
3031
img_tensor.unsqueeze_(0)
32+
img_tensor = img_tensor.to(self.device)
3133
with torch.no_grad():
3234
prediction = self.model(img_tensor)
3335
_, index = prediction[0].max(0)

examples/pytorch/language-identifier/cortex.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,3 @@
66
path: predictor.py
77
tracker:
88
model_type: classification
9-
compute:
10-
cpu: 1
11-
gpu: 1
12-
mem: 4G

examples/pytorch/object-detector/cortex.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
predictor:
55
type: python
66
path: predictor.py
7+
config:
8+
device: cuda # use "cpu" to run on CPUs
79
compute:
810
cpu: 1
911
gpu: 1

examples/pytorch/object-detector/predictor.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
class PythonPredictor:
1313
def __init__(self, config):
14-
model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
14+
self.device = config["device"]
15+
model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True).to(self.device)
1516
model.eval()
1617

1718
self.preprocess = transforms.Compose([transforms.ToTensor()])
@@ -25,17 +26,17 @@ def predict(self, payload):
2526
threshold = float(payload["threshold"])
2627
image = requests.get(payload["url"]).content
2728
img_pil = Image.open(BytesIO(image))
28-
img_tensor = self.preprocess(img_pil)
29+
img_tensor = self.preprocess(img_pil).to(self.device)
2930
img_tensor.unsqueeze_(0)
3031

3132
with torch.no_grad():
3233
pred = self.model(img_tensor)
3334

34-
predicted_class = [self.coco_labels[i] for i in list(pred[0]["labels"].numpy())]
35+
predicted_class = [self.coco_labels[i] for i in list(pred[0]["labels"].cpu().numpy())]
3536
predicted_boxes = [
36-
[(i[0], i[1]), (i[2], i[3])] for i in list(pred[0]["boxes"].detach().numpy())
37+
[(i[0], i[1]), (i[2], i[3])] for i in list(pred[0]["boxes"].detach().cpu().numpy())
3738
]
38-
predicted_score = list(pred[0]["scores"].detach().numpy())
39+
predicted_score = list(pred[0]["scores"].detach().cpu().numpy())
3940
predicted_t = [predicted_score.index(x) for x in predicted_score if x > threshold]
4041
if len(predicted_t) == 0:
4142
return [], []

examples/pytorch/reading-comprehender/cortex.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
predictor:
55
type: python
66
path: predictor.py
7+
config:
8+
device: cuda # use "cpu" to run on CPUs
79
compute:
810
cpu: 1
911
gpu: 1

examples/pytorch/reading-comprehender/predictor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,14 @@
55

66
class PythonPredictor:
77
def __init__(self, config):
8+
9+
cuda_device = -1
10+
if config["device"] == "cuda":
11+
cuda_device = 0
12+
813
self.predictor = AllenNLPPredictor.from_path(
9-
"https://storage.googleapis.com/allennlp-public-models/bidaf-elmo-model-2018.11.30-charpad.tar.gz"
14+
"https://storage.googleapis.com/allennlp-public-models/bidaf-elmo-model-2018.11.30-charpad.tar.gz",
15+
cuda_device=cuda_device,
1016
)
1117

1218
def predict(self, payload):

0 commit comments

Comments
 (0)