Skip to content

Commit 4fe6034

Browse files
committed
Updated examples
1 parent caf04f5 commit 4fe6034

24 files changed

+1680
-507
lines changed

docs/examples/plot_object_detection_simple.py renamed to docs/examples/plot_object_detection_checkpoint.py

Lines changed: 92 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,54 @@
11
#!/usr/bin/env python
22
# coding: utf-8
33
"""
4-
Object Detection Test
5-
=====================
4+
Object Detection From TF2 Checkpoint
5+
====================================
66
"""
77

88
# %%
9-
# This demo will take you through the steps of running an "out-of-the-box" detection model on a
10-
# collection of images.
11-
12-
# %%
13-
# Create the data directory
14-
# ~~~~~~~~~~~~~~~~~~~~~~~~~
15-
# The snippet shown below will create the ``data`` directory where all our data will be stored. The
16-
# code will create a directory structure as shown bellow:
17-
#
18-
# .. code-block:: bash
19-
#
20-
# data
21-
# ├── images
22-
# └── models
23-
#
24-
# where the ``images`` folder will contain the downlaoded test images, while ``models`` will
25-
# contain the downloaded models.
26-
import os
27-
28-
DATA_DIR = os.path.join(os.getcwd(), 'data')
29-
IMAGES_DIR = os.path.join(DATA_DIR, 'images')
30-
MODELS_DIR = os.path.join(DATA_DIR, 'models')
31-
for dir in [DATA_DIR, IMAGES_DIR, MODELS_DIR]:
32-
if not os.path.exists(dir):
33-
os.mkdir(dir)
9+
# This demo will take you through the steps of running an "out-of-the-box" TensorFlow 2 compatible
10+
# detection model on a collection of images. More specifically, in this example we will be using
11+
# the `Checkpoint Format <https://www.tensorflow.org/guide/checkpoint>`__ to load the model.
3412

3513
# %%
3614
# Download the test images
3715
# ~~~~~~~~~~~~~~~~~~~~~~~~
3816
# First we will download the images that we will use throughout this tutorial. The code snippet
3917
# shown bellow will download the test images from the `TensorFlow Model Garden <https://github.com/tensorflow/models/tree/master/research/object_detection/test_images>`_
4018
# and save them inside the ``data/images`` folder.
41-
import urllib.request
19+
import os
20+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress TensorFlow logging (1)
21+
import pathlib
22+
import tensorflow as tf
4223

43-
IMAGE_FILENAMES = ['image1.jpg', 'image2.jpg']
44-
IMAGES_DOWNLOAD_BASE = \
45-
'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/test_images/'
24+
tf.get_logger().setLevel('ERROR') # Suppress TensorFlow logging (2)
4625

47-
for image_filename in IMAGE_FILENAMES:
26+
# Enable GPU dynamic memory allocation
27+
gpus = tf.config.experimental.list_physical_devices('GPU')
28+
for gpu in gpus:
29+
tf.config.experimental.set_memory_growth(gpu, True)
4830

49-
image_path = os.path.join(IMAGES_DIR, image_filename)
31+
def download_images():
32+
base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/test_images/'
33+
filenames = ['image1.jpg', 'image2.jpg']
34+
image_paths = []
35+
for filename in filenames:
36+
image_path = tf.keras.utils.get_file(fname=filename,
37+
origin=base_url + filename,
38+
untar=False)
39+
image_path = pathlib.Path(image_path)
40+
image_paths.append(str(image_path))
41+
return image_paths
5042

51-
# Download image
52-
if not os.path.exists(image_path):
53-
print('Downloading {}... '.format(image_filename), end='')
54-
urllib.request.urlretrieve(IMAGES_DOWNLOAD_BASE + image_filename, image_path)
55-
print('Done')
43+
IMAGE_PATHS = download_images()
5644

5745

5846
# %%
5947
# Download the model
6048
# ~~~~~~~~~~~~~~~~~~
61-
# The code snippet shown below is used to download the object detection model checkpoint file,
62-
# as well as the labels file (.pbtxt) which contains a list of strings used to add the correct
63-
# label to each detection (e.g. person). Once downloaded the files will be stored under the
64-
# ``data/models`` folder.
65-
#
66-
# The particular detection algorithm we will use is the `CenterNet HourGlass104 1024x1024`. More
67-
# models can be found in the `TensorFlow 2 Detection Model Zoo <https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md>`_.
49+
# The code snippet shown below is used to download the pre-trained object detection model we shall
50+
# use to perform inference. The particular detection algorithm we will use is the
51+
# `CenterNet HourGlass104 1024x1024`. More models can be found in the `TensorFlow 2 Detection Model Zoo <https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md>`_.
6852
# To use a different model you will need the URL name of the specific model. This can be done as
6953
# follows:
7054
#
@@ -76,62 +60,63 @@
7660
#
7761
# For example, the download link for the model used below is: ``download.tensorflow.org/models/object_detection/tf2/20200711/centernet_hg104_1024x1024_coco17_tpu-32.tar.gz``
7862

79-
import tarfile
80-
8163
# Download and extract model
64+
def download_model(model_name, model_date):
65+
base_url = 'http://download.tensorflow.org/models/object_detection/tf2/'
66+
model_file = model_name + '.tar.gz'
67+
model_dir = tf.keras.utils.get_file(fname=model_name,
68+
origin=base_url + model_date + '/' + model_file,
69+
untar=True)
70+
return str(model_dir)
71+
8272
MODEL_DATE = '20200711'
8373
MODEL_NAME = 'centernet_hg104_1024x1024_coco17_tpu-32'
84-
MODEL_TAR_FILENAME = MODEL_NAME + '.tar.gz'
85-
MODELS_DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/tf2/'
86-
MODEL_DOWNLOAD_LINK = MODELS_DOWNLOAD_BASE + MODEL_DATE + '/' + MODEL_TAR_FILENAME
87-
PATH_TO_MODEL_TAR = os.path.join(MODELS_DIR, MODEL_TAR_FILENAME)
88-
PATH_TO_CKPT = os.path.join(MODELS_DIR, os.path.join(MODEL_NAME, 'checkpoint/'))
89-
PATH_TO_CFG = os.path.join(MODELS_DIR, os.path.join(MODEL_NAME, 'pipeline.config'))
90-
if not os.path.exists(PATH_TO_CKPT):
91-
print('Downloading model. This may take a while... ', end='')
92-
urllib.request.urlretrieve(MODEL_DOWNLOAD_LINK, PATH_TO_MODEL_TAR)
93-
tar_file = tarfile.open(PATH_TO_MODEL_TAR)
94-
tar_file.extractall(MODELS_DIR)
95-
tar_file.close()
96-
os.remove(PATH_TO_MODEL_TAR)
97-
print('Done')
74+
PATH_TO_MODEL_DIR = download_model(MODEL_NAME, MODEL_DATE)
75+
76+
# %%
77+
# Download the labels
78+
# ~~~~~~~~~~~~~~~~~~~
79+
# The coode snippet shown below is used to download the labels file (.pbtxt) which contains a list
80+
# of strings used to add the correct label to each detection (e.g. person). Since the pre-trained
81+
# model we will use has been trained on the COCO dataset, we will need to download the labels file
82+
# corresponding to this dataset, named ``mscoco_label_map.pbtxt``. A full list of the labels files
83+
# included in the TensorFlow Models Garden can be found `here <https://github.com/tensorflow/models/tree/master/research/object_detection/data>`__.
9884

9985
# Download labels file
86+
def download_labels(filename):
87+
base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/'
88+
label_dir = tf.keras.utils.get_file(fname=filename,
89+
origin=base_url + filename,
90+
untar=False)
91+
label_dir = pathlib.Path(label_dir)
92+
return str(label_dir)
93+
10094
LABEL_FILENAME = 'mscoco_label_map.pbtxt'
101-
LABELS_DOWNLOAD_BASE = \
102-
'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/'
103-
PATH_TO_LABELS = os.path.join(MODELS_DIR, os.path.join(MODEL_NAME, LABEL_FILENAME))
104-
if not os.path.exists(PATH_TO_LABELS):
105-
print('Downloading label file... ', end='')
106-
urllib.request.urlretrieve(LABELS_DOWNLOAD_BASE + LABEL_FILENAME, PATH_TO_LABELS)
107-
print('Done')
95+
PATH_TO_LABELS = download_labels(LABEL_FILENAME)
10896

10997
# %%
11098
# Load the model
11199
# ~~~~~~~~~~~~~~
112100
# Next we load the downloaded model
113-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress TensorFlow logging (1)
114-
import tensorflow as tf
101+
import time
115102
from object_detection.utils import label_map_util
116103
from object_detection.utils import config_util
117104
from object_detection.utils import visualization_utils as viz_utils
118105
from object_detection.builders import model_builder
119106

120-
tf.get_logger().setLevel('ERROR') # Suppress TensorFlow logging (2)
107+
PATH_TO_CFG = PATH_TO_MODEL_DIR + "/pipeline.config"
108+
PATH_TO_CKPT = PATH_TO_MODEL_DIR + "/checkpoint"
121109

122-
# Enable GPU dynamic memory allocation
123-
gpus = tf.config.experimental.list_physical_devices('GPU')
124-
for gpu in gpus:
125-
tf.config.experimental.set_memory_growth(gpu, True)
110+
print('Loading model... ', end='')
111+
start_time = time.time()
126112

127113
# Load pipeline config and build a detection model
128114
configs = config_util.get_configs_from_pipeline_file(PATH_TO_CFG)
129115
model_config = configs['model']
130116
detection_model = model_builder.build(model_config=model_config, is_training=False)
131117

132118
# Restore checkpoint
133-
ckpt = tf.compat.v2.train.Checkpoint(
134-
model=detection_model)
119+
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
135120
ckpt.restore(os.path.join(PATH_TO_CKPT, 'ckpt-0')).expect_partial()
136121

137122
@tf.function
@@ -142,8 +127,11 @@ def detect_fn(image):
142127
prediction_dict = detection_model.predict(image, shapes)
143128
detections = detection_model.postprocess(prediction_dict, shapes)
144129

145-
return detections, prediction_dict, tf.reshape(shapes, [-1])
130+
return detections
146131

132+
end_time = time.time()
133+
elapsed_time = end_time - start_time
134+
print('Done! Took {} seconds'.format(elapsed_time))
147135

148136
# %%
149137
# Load label map data (for plotting)
@@ -172,7 +160,6 @@ def detect_fn(image):
172160
# * Print out `detections['detection_boxes']` and try to match the box locations to the boxes in the image. Notice that coordinates are given in normalized form (i.e., in the interval [0, 1]).
173161
# * Set ``min_score_thresh`` to other values (between 0 and 1) to allow more detections in or to filter out more detections.
174162
import numpy as np
175-
from six import BytesIO
176163
from PIL import Image
177164
import matplotlib.pyplot as plt
178165
import warnings
@@ -191,18 +178,13 @@ def load_image_into_numpy_array(path):
191178
Returns:
192179
uint8 numpy array with shape (img_height, img_width, 3)
193180
"""
194-
img_data = tf.io.gfile.GFile(path, 'rb').read()
195-
image = Image.open(BytesIO(img_data))
196-
(im_width, im_height) = image.size
197-
return np.array(image.getdata()).reshape(
198-
(im_height, im_width, 3)).astype(np.uint8)
181+
return np.array(Image.open(path))
199182

200183

201-
for image_filename in IMAGE_FILENAMES:
184+
for image_path in IMAGE_PATHS:
202185

203-
print('Running inference for {}... '.format(image_filename), end='')
186+
print('Running inference for {}... '.format(image_path), end='')
204187

205-
image_path = os.path.join(IMAGES_DIR, image_filename)
206188
image_np = load_image_into_numpy_array(image_path)
207189

208190
# Things to try:
@@ -213,23 +195,34 @@ def load_image_into_numpy_array(path):
213195
# image_np = np.tile(
214196
# np.mean(image_np, 2, keepdims=True), (1, 1, 3)).astype(np.uint8)
215197

216-
input_tensor = tf.convert_to_tensor(
217-
np.expand_dims(image_np, 0), dtype=tf.float32)
218-
detections, predictions_dict, shapes = detect_fn(input_tensor)
198+
input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
199+
200+
detections = detect_fn(input_tensor)
201+
202+
# All outputs are batches tensors.
203+
# Convert to numpy arrays, and take index [0] to remove the batch dimension.
204+
# We're only interested in the first num_detections.
205+
num_detections = int(detections.pop('num_detections'))
206+
detections = {key: value[0, :num_detections].numpy()
207+
for key, value in detections.items()}
208+
detections['num_detections'] = num_detections
209+
210+
# detection_classes should be ints.
211+
detections['detection_classes'] = detections['detection_classes'].astype(np.int64)
219212

220213
label_id_offset = 1
221214
image_np_with_detections = image_np.copy()
222215

223216
viz_utils.visualize_boxes_and_labels_on_image_array(
224-
image_np_with_detections,
225-
detections['detection_boxes'][0].numpy(),
226-
(detections['detection_classes'][0].numpy() + label_id_offset).astype(int),
227-
detections['detection_scores'][0].numpy(),
228-
category_index,
229-
use_normalized_coordinates=True,
230-
max_boxes_to_draw=200,
231-
min_score_thresh=.30,
232-
agnostic_mode=False)
217+
image_np_with_detections,
218+
detections['detection_boxes'],
219+
detections['detection_classes']+label_id_offset,
220+
detections['detection_scores'],
221+
category_index,
222+
use_normalized_coordinates=True,
223+
max_boxes_to_draw=200,
224+
min_score_thresh=.30,
225+
agnostic_mode=False)
233226

234227
plt.figure()
235228
plt.imshow(image_np_with_detections)

0 commit comments

Comments
 (0)