|
428 | 428 | ], |
429 | 429 | "source": [ |
430 | 430 | "import torch\n", |
431 | | - "import torchvision\n", |
432 | 431 | "\n", |
433 | 432 | "torch.hub._validate_not_a_forked_repo=lambda a,b,c: True\n", |
434 | 433 | "\n", |
|
558 | 557 | "from PIL import Image\n", |
559 | 558 | "from torchvision import transforms\n", |
560 | 559 | "import matplotlib.pyplot as plt\n", |
561 | | - "import json \n", |
| 560 | + "import json\n", |
562 | 561 | "\n", |
563 | 562 | "fig, axes = plt.subplots(nrows=2, ncols=2)\n", |
564 | 563 | "\n", |
|
571 | 570 | " transforms.ToTensor(),\n", |
572 | 571 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", |
573 | 572 | " ])\n", |
574 | | - " input_tensor = preprocess(img) \n", |
| 573 | + " input_tensor = preprocess(img)\n", |
575 | 574 | " plt.subplot(2,2,i+1)\n", |
576 | 575 | " plt.imshow(img)\n", |
577 | 576 | " plt.axis('off')\n", |
578 | 577 | "\n", |
579 | | - "# loading labels \n", |
580 | | - "with open(\"./data/imagenet_class_index.json\") as json_file: \n", |
| 578 | + "# loading labels\n", |
| 579 | + "with open(\"./data/imagenet_class_index.json\") as json_file:\n", |
581 | 580 | " d = json.load(json_file)" |
582 | 581 | ] |
583 | 582 | }, |
|
614 | 613 | " preprocess = rn50_preprocess()\n", |
615 | 614 | " input_tensor = preprocess(img)\n", |
616 | 615 | " input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model\n", |
617 | | - " \n", |
| 616 | + "\n", |
618 | 617 | " # move the input and model to GPU for speed if available\n", |
619 | 618 | " if torch.cuda.is_available():\n", |
620 | 619 | " input_batch = input_batch.to('cuda')\n", |
|
624 | 623 | " output = model(input_batch)\n", |
625 | 624 | " # Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes\n", |
626 | 625 | " sm_output = torch.nn.functional.softmax(output[0], dim=0)\n", |
627 | | - " \n", |
| 626 | + "\n", |
628 | 627 | " ind = torch.argmax(sm_output)\n", |
629 | 628 | " return d[str(ind.item())], sm_output[ind] #([predicted class, description], probability)\n", |
630 | 629 | "\n", |
|
633 | 632 | " input_data = input_data.to(\"cuda\")\n", |
634 | 633 | " if dtype=='fp16':\n", |
635 | 634 | " input_data = input_data.half()\n", |
636 | | - " \n", |
| 635 | + "\n", |
637 | 636 | " print(\"Warm up ...\")\n", |
638 | 637 | " with torch.no_grad():\n", |
639 | 638 | " for _ in range(nwarmup):\n", |
|
695 | 694 | "for i in range(4):\n", |
696 | 695 | " img_path = './data/img%d.JPG'%i\n", |
697 | 696 | " img = Image.open(img_path)\n", |
698 | | - " \n", |
| 697 | + "\n", |
699 | 698 | " pred, prob = predict(img_path, resnet50_model)\n", |
700 | 699 | " print('{} - Predicted: {}, Probablility: {}'.format(img_path, pred, prob))\n", |
701 | 700 | "\n", |
702 | 701 | " plt.subplot(2,2,i+1)\n", |
703 | | - " plt.imshow(img);\n", |
704 | | - " plt.axis('off');\n", |
| 702 | + " plt.imshow(img)\n", |
| 703 | + " plt.axis('off')\n", |
705 | 704 | " plt.title(pred[1])" |
706 | 705 | ] |
707 | 706 | }, |
|
0 commit comments