|
20 | 20 | "cell_type": "markdown", |
21 | 21 | "metadata": {}, |
22 | 22 | "source": [ |
23 | | - "## Initialization" |
| 23 | + "## 00 Initialization" |
24 | 24 | ] |
25 | 25 | }, |
26 | 26 | { |
|
45 | 45 | "sys.path.append(\"../../\")\n", |
46 | 46 | "\n", |
47 | 47 | "import os\n", |
| 48 | + "import os.path as osp\n", |
48 | 49 | "import time\n", |
| 50 | + "from ipywidgets import Video\n", |
49 | 51 | "import matplotlib.pyplot as plt\n", |
50 | | - "from typing import Iterator\n", |
51 | | - "from pathlib import Path\n", |
52 | | - "from PIL import Image\n", |
53 | | - "from random import randrange\n", |
54 | | - "from typing import Tuple\n", |
55 | 52 | "import torch\n", |
56 | 53 | "import torchvision\n", |
57 | | - "from torchvision import transforms\n", |
58 | | - "import scrapbook as sb\n", |
59 | 54 | "\n", |
60 | | - "from ipywidgets import Video\n", |
| 55 | + "from utils_cv.tracking.data import Urls\n", |
61 | 56 | "from utils_cv.tracking.dataset import TrackingDataset\n", |
62 | | - "from utils_cv.tracking.model import TrackingLearner\n", |
| 57 | + "from utils_cv.tracking.model import TrackingLearner, write_video\n", |
63 | 58 | "\n", |
| 59 | + "from utils_cv.common.data import data_path, download, unzip_url\n", |
64 | 60 | "from utils_cv.common.gpu import which_processor, is_windows\n", |
65 | 61 | "\n", |
66 | 62 | "# Change matplotlib backend so that plots are shown for windows\n", |
|
115 | 111 | } |
116 | 112 | ], |
117 | 113 | "source": [ |
118 | | - "EPOCHS = 2\n", |
| 114 | + "EPOCHS = 1\n", |
119 | 115 | "LEARNING_RATE = 0.0001\n", |
120 | 116 | "BATCH_SIZE = 1\n", |
| 117 | + "\n", |
121 | 118 | "SAVE_MODEL = True\n", |
| 119 | + "FRAME_RATE = 30\n", |
| 120 | + "\n", |
| 121 | + "CONF_THRES = 0.3\n", |
| 122 | + "TRACK_BUFFER = 300\n", |
| 123 | + "IM_SIZE = (1080, 1920)\n", |
| 124 | + "\n", |
| 125 | + "TRAIN_DATA_PATH = unzip_url(Urls.fridge_objects_path, exist_ok=True)\n", |
| 126 | + "EVAL_DATA_PATH = unzip_url(Urls.carcans_annotations_path, exist_ok=True)\n", |
| 127 | + "\n", |
| 128 | + "BASELINE_MODEL = \"./models/all_dla34_new.pth\"\n", |
| 129 | + "FT_MODEL = \"./models/model_30.pth\"\n", |
122 | 130 | "\n", |
123 | 131 | "# train on the GPU or on the CPU, if a GPU is not available\n", |
124 | 132 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", |
|
129 | 137 | "cell_type": "markdown", |
130 | 138 | "metadata": {}, |
131 | 139 | "source": [ |
132 | | - "## Prepare Training Dataset" |
133 | | - ] |
134 | | - }, |
135 | | - { |
136 | | - "cell_type": "code", |
137 | | - "execution_count": 4, |
138 | | - "metadata": {}, |
139 | | - "outputs": [ |
140 | | - { |
141 | | - "data": { |
142 | | - "text/plain": [ |
143 | | - "['labels_with_ids', '.ipynb_checkpoints', 'images']" |
144 | | - ] |
145 | | - }, |
146 | | - "execution_count": 4, |
147 | | - "metadata": {}, |
148 | | - "output_type": "execute_result" |
149 | | - } |
150 | | - ], |
151 | | - "source": [ |
152 | | - "DATA_PATH_TRAIN = \"./data/odFridgeObjects_FairMOTformat/\"\n", |
153 | | - "os.listdir(DATA_PATH_TRAIN)" |
| 140 | + "## 01 Finetune a Pretrained Model" |
154 | 141 | ] |
155 | 142 | }, |
156 | 143 | { |
157 | 144 | "cell_type": "markdown", |
158 | 145 | "metadata": {}, |
159 | 146 | "source": [ |
160 | | - "## Load Training Dataset" |
| 147 | + "Initialize the training dataset." |
161 | 148 | ] |
162 | 149 | }, |
163 | 150 | { |
164 | 151 | "cell_type": "code", |
165 | | - "execution_count": 5, |
166 | | - "metadata": { |
167 | | - "scrolled": true |
168 | | - }, |
| 152 | + "execution_count": 4, |
| 153 | + "metadata": {}, |
169 | 154 | "outputs": [ |
170 | 155 | { |
171 | 156 | "name": "stdout", |
|
183 | 168 | ], |
184 | 169 | "source": [ |
185 | 170 | "data_train = TrackingDataset(\n", |
186 | | - " DATA_PATH_TRAIN,\n", |
| 171 | + " TRAIN_DATA_PATH,\n", |
187 | 172 | " batch_size=BATCH_SIZE\n", |
188 | 173 | ")" |
189 | 174 | ] |
|
192 | 177 | "cell_type": "markdown", |
193 | 178 | "metadata": {}, |
194 | 179 | "source": [ |
195 | | - "## Finetune a Pretrained Model" |
| 180 | + "Initialize and load the model. We use the baseline FairMOT model, which can be downloaded [here](https://drive.google.com/file/d/1udpOPum8fJdoEQm6n0jsIgMMViOMFinu/view)." |
196 | 181 | ] |
197 | 182 | }, |
198 | 183 | { |
199 | 184 | "cell_type": "code", |
200 | | - "execution_count": 6, |
| 185 | + "execution_count": 5, |
201 | 186 | "metadata": {}, |
202 | 187 | "outputs": [ |
203 | 188 | { |
|
209 | 194 | } |
210 | 195 | ], |
211 | 196 | "source": [ |
212 | | - "tracker = TrackingLearner(data_train) \n", |
| 197 | + "tracker = TrackingLearner(data_train, \"./models/fairmot_ft.pth\")\n", |
213 | 198 | "print(f\"Model: {type(tracker.model)}\")" |
214 | 199 | ] |
215 | 200 | }, |
216 | 201 | { |
217 | 202 | "cell_type": "code", |
218 | | - "execution_count": 7, |
| 203 | + "execution_count": 6, |
219 | 204 | "metadata": { |
220 | 205 | "scrolled": true |
221 | 206 | }, |
222 | 207 | "outputs": [ |
| 208 | + { |
| 209 | + "name": "stdout", |
| 210 | + "output_type": "stream", |
| 211 | + "text": [ |
| 212 | + "Loading /home/jihon/computervision-recipes/scenarios/tracking/models/all_dla34.pth\n", |
| 213 | + "loaded /home/jihon/computervision-recipes/scenarios/tracking/models/all_dla34.pth, epoch 10\n", |
| 214 | + "Resumed optimizer with start lr 0.0001\n", |
| 215 | + "===== Epoch: 11/11 =====\n" |
| 216 | + ] |
| 217 | + }, |
223 | 218 | { |
224 | 219 | "name": "stderr", |
225 | 220 | "output_type": "stream", |
226 | 221 | "text": [ |
227 | 222 | "/anaconda/envs/cv/lib/python3.7/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n", |
228 | 223 | " warnings.warn(warning.format(ret))\n" |
229 | 224 | ] |
| 225 | + }, |
| 226 | + { |
| 227 | + "name": "stdout", |
| 228 | + "output_type": "stream", |
| 229 | + "text": [ |
| 230 | + "loss: 1.1128346400433464\n", |
| 231 | + "hm_loss: 0.06353224289612051\n", |
| 232 | + "wh_loss: 1.57920023114543\n", |
| 233 | + "off_loss: 0.18636367223715702\n", |
| 234 | + "id_loss: 0.8860541224528692\n", |
| 235 | + "time: 44.016666666666666\n", |
| 236 | + "Model saved to ./models/fairmot_ft.pth\n" |
| 237 | + ] |
| 238 | + } |
| 239 | + ], |
| 240 | + "source": [ |
| 241 | + "tracker.fit(num_epochs=EPOCHS, lr=LEARNING_RATE, resume=True)" |
| 242 | + ] |
| 243 | + }, |
| 244 | + { |
| 245 | + "cell_type": "code", |
| 246 | + "execution_count": 7, |
| 247 | + "metadata": { |
| 248 | + "scrolled": true |
| 249 | + }, |
| 250 | + "outputs": [ |
| 251 | + { |
| 252 | + "name": "stdout", |
| 253 | + "output_type": "stream", |
| 254 | + "text": [ |
| 255 | + "Model saved to ./models/model_01.pth\n" |
| 256 | + ] |
230 | 257 | } |
231 | 258 | ], |
232 | 259 | "source": [ |
233 | | - "tracker.fit(num_epochs=EPOCHS, lr=LEARNING_RATE)" |
| 260 | + "if SAVE_MODEL:\n", |
| 261 | + " tracker.save(f\"./models/model_{EPOCHS:02d}.pth\")" |
| 262 | + ] |
| 263 | + }, |
| 264 | + { |
| 265 | + "cell_type": "markdown", |
| 266 | + "metadata": {}, |
| 267 | + "source": [ |
| 268 | + "## 02 Evaluate" |
| 269 | + ] |
| 270 | + }, |
| 271 | + { |
| 272 | + "cell_type": "markdown", |
| 273 | + "metadata": {}, |
| 274 | + "source": [ |
| 275 | + "Note that `EVAL_DATA_PATH` follows the FairMOT input format." |
| 276 | + ] |
| 277 | + }, |
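| | + { |
| | + "cell_type": "code", |
| | + "execution_count": null, |
| | + "metadata": {}, |
| | + "outputs": [], |
| | + "source": [ |
| | + "# Sanity check: list the unzipped evaluation folder; it should contain\n", |
| | + "# the 'images' and 'labels_with_ids' directories\n", |
| | + "os.listdir(EVAL_DATA_PATH)" |
| | + ] |
| | + }, |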
| 278 | + { |
| 279 | + "cell_type": "code", |
| 280 | + "execution_count": 10, |
| 281 | + "metadata": { |
| 282 | + "scrolled": true |
| 283 | + }, |
| 284 | + "outputs": [ |
| 285 | + { |
| 286 | + "name": "stdout", |
| 287 | + "output_type": "stream", |
| 288 | + "text": [ |
| 289 | + "Creating model...\n", |
| 290 | + "loaded ./models/fairmot_ft.pth, epoch 11\n" |
| 291 | + ] |
| 292 | + } |
| 293 | + ], |
| 294 | + "source": [ |
| 295 | + "eval_results = tracker.predict(\n", |
| 296 | + " EVAL_DATA_PATH,\n", |
| 297 | + " conf_thres=CONF_THRES,\n", |
| 298 | + " track_buffer=TRACK_BUFFER,\n", |
| 299 | + " im_size=IM_SIZE,\n", |
| 300 | + " frame_rate=FRAME_RATE\n", |
| 301 | + ")" |
| 302 | + ] |
| 303 | + }, |
| 304 | + { |
| 305 | + "cell_type": "code", |
| 306 | + "execution_count": 11, |
| 307 | + "metadata": { |
| 308 | + "scrolled": true |
| 309 | + }, |
| 310 | + "outputs": [], |
| 311 | + "source": [ |
| 312 | + "eval_metrics = tracker.evaluate(eval_results, EVAL_DATA_PATH) " |
| 313 | + ] |
| 314 | + }, |
| 315 | + { |
| 316 | + "cell_type": "markdown", |
| 317 | + "metadata": {}, |
| 318 | + "source": [ |
| 319 | + "## 03 Predict" |
| 320 | + ] |
| 321 | + }, |
| 322 | + { |
| 323 | + "cell_type": "code", |
| 324 | + "execution_count": 12, |
| 325 | + "metadata": {}, |
| 326 | + "outputs": [], |
| 327 | + "source": [ |
| 328 | + "input_video = download(\n", |
| 329 | + " Urls.carcans_video_path, osp.join(data_path(), \"carcans.mp4\")\n", |
| 330 | + ")" |
| 331 | + ] |
| 332 | + }, |
| 333 | + { |
| 334 | + "cell_type": "code", |
| 335 | + "execution_count": 15, |
| 336 | + "metadata": { |
| 337 | + "scrolled": false |
| 338 | + }, |
| 339 | + "outputs": [ |
| 340 | + { |
| 341 | + "name": "stdout", |
| 342 | + "output_type": "stream", |
| 343 | + "text": [ |
| 344 | + "Creating model...\n", |
| 345 | + "loaded ./models/fairmot_ft.pth, epoch 11\n", |
| 346 | + "Lenth of the video: 251 frames\n" |
| 347 | + ] |
| 348 | + } |
| 349 | + ], |
| 350 | + "source": [ |
| 351 | + "test_results = tracker.predict(\n", |
| 352 | + " input_video,\n", |
| 353 | + " conf_thres=CONF_THRES,\n", |
| 354 | + " track_buffer=TRACK_BUFFER,\n", |
| 355 | + " im_size=IM_SIZE,\n", |
| 356 | + ")" |
| 357 | + ] |
| 358 | + }, |
| 359 | + { |
| 360 | + "cell_type": "code", |
| 361 | + "execution_count": 16, |
| 362 | + "metadata": {}, |
| 363 | + "outputs": [], |
| 364 | + "source": [ |
| 365 | + "output_video = osp.join(data_path(), \"carcans_output.mp4\")" |
| 366 | + ] |
| 367 | + }, |
| 368 | + { |
| 369 | + "cell_type": "code", |
| 370 | + "execution_count": null, |
| 371 | + "metadata": {}, |
| 372 | + "outputs": [], |
| 373 | + "source": [ |
| 374 | + "write_video(test_results, input_video, output_video)" |
| 375 | + ] |
| 376 | + }, |
| 377 | + { |
| 378 | + "cell_type": "code", |
| 379 | + "execution_count": null, |
| 380 | + "metadata": {}, |
| 381 | + "outputs": [], |
| 382 | + "source": [ |
| 383 | + "Video.from_file(output_video)" |
234 | 384 | ] |
235 | 385 | } |
236 | 386 | ], |
|