
Commit 68c9911

FairMOT-01 (#553)

Authored by Casey Hong
1 parent aad3637 commit 68c9911

File tree: 9 files changed (+449 / -147 lines)

scenarios/tracking/01_training_introduction.ipynb

Lines changed: 194 additions & 44 deletions
@@ -20,7 +20,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Initialization"
+"## 00 Initialization"
 ]
 },
 {
@@ -45,22 +45,18 @@
 "sys.path.append(\"../../\")\n",
 "\n",
 "import os\n",
+"import os.path as osp\n",
 "import time\n",
+"from ipywidgets import Video\n",
 "import matplotlib.pyplot as plt\n",
-"from typing import Iterator\n",
-"from pathlib import Path\n",
-"from PIL import Image\n",
-"from random import randrange\n",
-"from typing import Tuple\n",
 "import torch\n",
 "import torchvision\n",
-"from torchvision import transforms\n",
-"import scrapbook as sb\n",
 "\n",
-"from ipywidgets import Video\n",
+"from utils_cv.tracking.data import Urls\n",
 "from utils_cv.tracking.dataset import TrackingDataset\n",
-"from utils_cv.tracking.model import TrackingLearner\n",
+"from utils_cv.tracking.model import TrackingLearner, write_video\n",
 "\n",
+"from utils_cv.common.data import data_path, download, unzip_url\n",
 "from utils_cv.common.gpu import which_processor, is_windows\n",
 "\n",
 "# Change matplotlib backend so that plots are shown for windows\n",
@@ -115,10 +111,22 @@
 }
 ],
 "source": [
-"EPOCHS = 2\n",
+"EPOCHS = 1\n",
 "LEARNING_RATE = 0.0001\n",
 "BATCH_SIZE = 1\n",
+"\n",
 "SAVE_MODEL = True\n",
+"FRAME_RATE = 30\n",
+"\n",
+"CONF_THRES = 0.3\n",
+"TRACK_BUFFER = 300\n",
+"IM_SIZE = (1080, 1920)\n",
+"\n",
+"TRAIN_DATA_PATH = unzip_url(Urls.fridge_objects_path, exist_ok=True)\n",
+"EVAL_DATA_PATH = unzip_url(Urls.carcans_annotations_path, exist_ok=True)\n",
+"\n",
+"BASELINE_MODEL = \"./models/all_dla34_new.pth\"\n",
+"FT_MODEL = \"./models/model_30.pth\"\n",
 "\n",
 "# train on the GPU or on the CPU, if a GPU is not available\n",
 "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
@@ -129,43 +137,20 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Prepare Training Dataset"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 4,
-"metadata": {},
-"outputs": [
-{
-"data": {
-"text/plain": [
-"['labels_with_ids', '.ipynb_checkpoints', 'images']"
-]
-},
-"execution_count": 4,
-"metadata": {},
-"output_type": "execute_result"
-}
-],
-"source": [
-"DATA_PATH_TRAIN = \"./data/odFridgeObjects_FairMOTformat/\"\n",
-"os.listdir(DATA_PATH_TRAIN)"
+"## 01 Finetune a Pretrained Model"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Load Training Dataset"
+"Initialize the training dataset."
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 5,
-"metadata": {
-"scrolled": true
-},
+"execution_count": 4,
+"metadata": {},
 "outputs": [
 {
 "name": "stdout",
@@ -183,7 +168,7 @@
 ],
 "source": [
 "data_train = TrackingDataset(\n",
-"    DATA_PATH_TRAIN,\n",
+"    TRAIN_DATA_PATH,\n",
 "    batch_size=BATCH_SIZE\n",
 ")"
 ]
@@ -192,12 +177,12 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Finetune a Pretrained Model"
+"Initialize and load the model. We use the baseline FairMOT model, which can be downloaded [here](https://drive.google.com/file/d/1udpOPum8fJdoEQm6n0jsIgMMViOMFinu/view)."
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 5,
 "metadata": {},
 "outputs": [
 {
@@ -209,28 +194,193 @@
 }
 ],
 "source": [
-"tracker = TrackingLearner(data_train) \n",
+"tracker = TrackingLearner(data_train, \"./models/fairmot_ft.pth\")\n",
 "print(f\"Model: {type(tracker.model)}\")"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 6,
 "metadata": {
 "scrolled": true
 },
 "outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Loading /home/jihon/computervision-recipes/scenarios/tracking/models/all_dla34.pth\n",
+"loaded /home/jihon/computervision-recipes/scenarios/tracking/models/all_dla34.pth, epoch 10\n",
+"Resumed optimizer with start lr 0.0001\n",
+"===== Epoch: 11/11 =====\n"
+]
+},
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
 "/anaconda/envs/cv/lib/python3.7/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
 " warnings.warn(warning.format(ret))\n"
 ]
+},
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"loss: 1.1128346400433464\n",
+"hm_loss: 0.06353224289612051\n",
+"wh_loss: 1.57920023114543\n",
+"off_loss: 0.18636367223715702\n",
+"id_loss: 0.8860541224528692\n",
+"time: 44.016666666666666\n",
+"Model saved to ./models/fairmot_ft.pth\n"
+]
+}
+],
+"source": [
+"tracker.fit(num_epochs=EPOCHS, lr=LEARNING_RATE, resume=True)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 7,
+"metadata": {
+"scrolled": true
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Model saved to ./models/model_01.pth\n"
+]
 }
 ],
 "source": [
-"tracker.fit(num_epochs=EPOCHS, lr=LEARNING_RATE)"
+"if SAVE_MODEL:\n",
+"    tracker.save(f\"./models/model_{EPOCHS:02d}.pth\")"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## 02 Evaluate"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Note that `EVAL_DATA_PATH` follows the FairMOT input format."
+]
+},
+{
+"cell_type": "code",
+"execution_count": 10,
+"metadata": {
+"scrolled": true
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Creating model...\n",
+"loaded ./models/fairmot_ft.pth, epoch 11\n"
+]
+}
+],
+"source": [
+"eval_results = tracker.predict(\n",
+"    EVAL_DATA_PATH,\n",
+"    conf_thres=CONF_THRES,\n",
+"    track_buffer=TRACK_BUFFER,\n",
+"    im_size=IM_SIZE,\n",
+"    frame_rate=FRAME_RATE\n",
+")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 11,
+"metadata": {
+"scrolled": true
+},
+"outputs": [],
+"source": [
+"eval_metrics = tracker.evaluate(eval_results, EVAL_DATA_PATH) "
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## 03 Predict"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 12,
+"metadata": {},
+"outputs": [],
+"source": [
+"input_video = download(\n",
+"    Urls.carcans_video_path, osp.join(data_path(), \"carcans.mp4\")\n",
+")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 15,
+"metadata": {
+"scrolled": false
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Creating model...\n",
+"loaded ./models/fairmot_ft.pth, epoch 11\n",
+"Lenth of the video: 251 frames\n"
+]
+}
+],
+"source": [
+"test_results = tracker.predict(\n",
+"    input_video,\n",
+"    conf_thres=CONF_THRES,\n",
+"    track_buffer=TRACK_BUFFER,\n",
+"    im_size=IM_SIZE,\n",
+")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 16,
+"metadata": {},
+"outputs": [],
+"source": [
+"output_video = osp.join(data_path(), \"carcans_output.mp4\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"write_video(test_results, input_video, output_video)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"Video.from_file(output_video)"
 ]
 }
 ],
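For reference, the finetuning flow these notebook changes add, condensed into one sketch. This is not new API surface — every name, argument, and path below is taken from the diff above; the comment on the checkpoint path is an inference from the notebook output, not a documented contract.

```python
# Sketch of the notebook's new finetuning flow (names/values from this diff).
from utils_cv.common.data import unzip_url
from utils_cv.tracking.data import Urls
from utils_cv.tracking.dataset import TrackingDataset
from utils_cv.tracking.model import TrackingLearner

EPOCHS = 1
LEARNING_RATE = 0.0001
BATCH_SIZE = 1

# Fridge-objects training set, pre-converted to FairMOT format.
train_data_path = unzip_url(Urls.fridge_objects_path, exist_ok=True)
data_train = TrackingDataset(train_data_path, batch_size=BATCH_SIZE)

# The second argument appears to be the path fit() writes the finetuned
# weights to ("Model saved to ./models/fairmot_ft.pth" in the output above);
# resume=True continues from the baseline checkpoint's stored epoch (10).
tracker = TrackingLearner(data_train, "./models/fairmot_ft.pth")
tracker.fit(num_epochs=EPOCHS, lr=LEARNING_RATE, resume=True)
tracker.save(f"./models/model_{EPOCHS:02d}.pth")
```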

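Continuing the sketch (it reuses `tracker` from the block above), the evaluate and predict sections reduce to two `predict()` calls — one on an annotated FairMOT-format directory, one on a raw video — followed by `evaluate()` and `write_video()`. The `(height, width)` reading of `IM_SIZE` is an assumption; the diff does not state it.

```python
# Sketch of the notebook's evaluate/predict flow; continues the block above.
import os.path as osp
from utils_cv.common.data import data_path, download, unzip_url
from utils_cv.tracking.data import Urls
from utils_cv.tracking.model import write_video

CONF_THRES = 0.3        # minimum detection confidence
TRACK_BUFFER = 300      # frames a lost track is kept alive
IM_SIZE = (1080, 1920)  # presumably (height, width) of the input frames
FRAME_RATE = 30

# Evaluate on an annotated sequence (FairMOT input format).
eval_data_path = unzip_url(Urls.carcans_annotations_path, exist_ok=True)
eval_results = tracker.predict(
    eval_data_path,
    conf_thres=CONF_THRES,
    track_buffer=TRACK_BUFFER,
    im_size=IM_SIZE,
    frame_rate=FRAME_RATE,
)
eval_metrics = tracker.evaluate(eval_results, eval_data_path)

# Track a raw video and write an annotated copy of it.
input_video = download(Urls.carcans_video_path, osp.join(data_path(), "carcans.mp4"))
test_results = tracker.predict(
    input_video,
    conf_thres=CONF_THRES,
    track_buffer=TRACK_BUFFER,
    im_size=IM_SIZE,
)
output_video = osp.join(data_path(), "carcans_output.mp4")
write_video(test_results, input_video, output_video)
```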
utils_cv/classification/plot.py

Lines changed: 0 additions & 3 deletions
@@ -34,10 +34,7 @@ def plot_thresholds(
     metric_function: The metric function
     y_pred: predicted probabilities.
     y_true: True class indices.
-<<<<<<< HEAD
     samples: Number of threshold samples
-=======
->>>>>>> master
     figsize: Figure size (w, h)
     """
     metric_name = metric_function.__name__

utils_cv/tracking/data.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+from typing import List
+from urllib.parse import urljoin
+
+
+class Urls:
+    base = "https://cvbp.blob.core.windows.net/public/datasets/tracking/"
+
+    fridge_objects_path = urljoin(base, "odFridgeObjects_FairMOT-Format.zip")
+    carcans_annotations_path = urljoin(base, "carcans_vott-csv-export.zip")
+    carcans_video_path = urljoin(base, "car_cans_8s.mp4")
+
+    @classmethod
+    def all(cls) -> List[str]:
+        return [v for k, v in cls.__dict__.items() if k.endswith("_path")]
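A minimal usage sketch of the new `Urls` helper: `all()` reflects over the class dict and collects every attribute whose name ends in `_path`, so a newly added dataset constant is picked up automatically.

```python
from utils_cv.tracking.data import Urls

# Each *_path attribute is a full blob-storage URL built with urljoin.
print(Urls.fridge_objects_path)
# -> https://cvbp.blob.core.windows.net/public/datasets/tracking/odFridgeObjects_FairMOT-Format.zip

# Returns the three *_path URLs defined on the class.
print(Urls.all())
```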

utils_cv/tracking/dataset.py

Lines changed: 1 addition & 4 deletions
@@ -15,10 +15,7 @@ class TrackingDataset:
     """A multi-object tracking dataset."""
 
     def __init__(
-        self,
-        data_root: str,
-        name: str = "default",
-        batch_size: int = 12,
+        self, data_root: str, name: str = "default", batch_size: int = 12,
     ) -> None:
         """
         Args:
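This change is formatting only — the multi-line signature is collapsed to one line, and the parameters and their defaults are unchanged — so construction such as the following (a hypothetical data root) still works:

```python
# Hypothetical example; "./data/my_sequences" is an illustrative path.
data = TrackingDataset("./data/my_sequences", name="default", batch_size=12)
```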
