Skip to content

Commit f3cf37f

Browse files
authored
Added evaluation to tracking 01 notebook
1 parent af63f87 commit f3cf37f

File tree

6 files changed

+261
-91
lines changed

6 files changed

+261
-91
lines changed

scenarios/tracking/01_training_introduction.ipynb

100644100755
Lines changed: 130 additions & 71 deletions
Large diffs are not rendered by default.

utils_cv/tracking/dataset.py

100644100755
Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT License.
33

4+
from collections import OrderedDict
5+
import numpy as np
46
import os
57
import os.path as osp
6-
from typing import Dict
8+
from typing import Dict, List
79
from torch.utils.data import DataLoader
810
from torchvision.transforms import transforms as T
11+
from .bbox import TrackingBbox
912
from .references.fairmot.datasets.dataset.jde import JointDataset
1013
from .opts import opts
1114
from ..common.gpu import db_num_workers
@@ -54,3 +57,31 @@ def _init_dataloaders(self) -> None:
5457
pin_memory=True,
5558
drop_last=True,
5659
)
60+
61+
def boxes_to_mot(results: Dict[int, List[TrackingBbox]]) -> None:
62+
"""
63+
Save the predicted tracks to csv file in MOT challenge format ["frame", "id", "left", "top", "width", "height",]
64+
65+
Args:
66+
results: dictionary mapping frame id to a list of predicted TrackingBboxes
67+
txt_path: path to which results are saved in csv file
68+
69+
"""
70+
# convert results to dataframe in MOT challenge format
71+
preds = OrderedDict(sorted(results.items()))
72+
bboxes = [
73+
[
74+
bb.frame_id,
75+
bb.track_id,
76+
bb.top,
77+
bb.left,
78+
bb.bottom - bb.top,
79+
bb.right - bb.left,
80+
1, -1, -1, -1,
81+
]
82+
for _, v in preds.items()
83+
for bb in v
84+
]
85+
bboxes_formatted = np.array(bboxes)
86+
87+
return bboxes_formatted

utils_cv/tracking/model.py

100644100755
Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,26 @@
22
# Licensed under the MIT License.
33

44
import argparse
5-
from collections import OrderedDict
5+
from collections import OrderedDict, defaultdict
66
from copy import deepcopy
77
import glob
88
import requests
99
import os
1010
import os.path as osp
11+
import tempfile #KIP
1112
from typing import Dict, List, Tuple
1213

14+
1315
import torch
1416
import torch.cuda as cuda
1517
import torch.nn as nn
1618
from torch.utils.data import DataLoader
1719

1820
import cv2
21+
import numpy as np
1922
import pandas as pd
2023
import matplotlib.pyplot as plt
24+
import motmetrics as mm
2125

2226
from .references.fairmot.datasets.dataset.jde import LoadImages, LoadVideo
2327
from .references.fairmot.models.model import (
@@ -26,10 +30,11 @@
2630
save_model,
2731
)
2832
from .references.fairmot.tracker.multitracker import JDETracker
33+
from .references.fairmot.tracking_utils.evaluation import Evaluator
2934
from .references.fairmot.trains.train_factory import train_factory
3035

3136
from .bbox import TrackingBbox
32-
from .dataset import TrackingDataset
37+
from .dataset import TrackingDataset, boxes_to_mot
3338
from .opts import opts
3439
from .plot import draw_boxes, assign_colors
3540
from ..common.gpu import torch_device
@@ -211,7 +216,10 @@ def fit(
211216
Trainer = train_factory[opt_fit.task]
212217
trainer = Trainer(opt_fit.opt, self.model, self.optimizer)
213218
trainer.set_device(opt_fit.gpus, opt_fit.chunk_sizes, opt_fit.device)
214-
219+
220+
# initialize loss vars
221+
self.losses_dict = defaultdict(list)
222+
215223
# training loop
216224
for epoch in range(
217225
start_epoch + 1, start_epoch + opt_fit.num_epochs + 1
@@ -229,10 +237,41 @@ def fit(
229237
lr = opt_fit.lr * (0.1 ** (opt_fit.lr_step.index(epoch) + 1))
230238
for param_group in optimizer.param_groups:
231239
param_group["lr"] = lr
240+
241+
# store losses in each epoch
242+
for k, v in log_dict_train.items():
243+
if k in ['loss', 'hm_loss', 'wh_loss', 'off_loss', 'id_loss']:
244+
self.losses_dict[k].append(v)
232245

233246
# save after training because at inference-time FairMOT src reads model weights from disk
234247
self.save(self.model_path)
235248

249+
def plot_training_losses(self, figsize: Tuple[int, int] = (10, 5))->None:
250+
'''
251+
Plots training loss from calling `fit`
252+
253+
Args:
254+
figsize (optional): width and height wanted for figure of training-loss plot
255+
256+
'''
257+
fig = plt.figure(figsize=figsize)
258+
ax1 = fig.add_subplot(1, 1, 1)
259+
260+
ax1.set_xlim([0, len(self.losses_dict['loss']) - 1])
261+
ax1.set_xticks(range(0, len(self.losses_dict['loss'])))
262+
ax1.set_xlabel("epochs")
263+
ax1.set_ylabel("losses")
264+
265+
ax1.plot(self.losses_dict['loss'], c="r", label='loss')
266+
ax1.plot(self.losses_dict['hm_loss'], c="y", label='hm_loss')
267+
ax1.plot(self.losses_dict['wh_loss'], c="g", label='wh_loss')
268+
ax1.plot(self.losses_dict['off_loss'], c="b", label='off_loss')
269+
ax1.plot(self.losses_dict['id_loss'], c="m", label='id_loss')
270+
271+
plt.legend(loc='upper right')
272+
fig.suptitle("Training losses over epochs")
273+
274+
236275
def save(self, path) -> None:
237276
"""
238277
Save the model to a specified path.
@@ -243,9 +282,46 @@ def save(self, path) -> None:
243282
save_model(path, self.epoch, self.model, self.optimizer)
244283
print(f"Model saved to {path}")
245284

246-
def evaluate(self, results, gt) -> pd.DataFrame:
247-
pass
248-
285+
def evaluate(self,
286+
results: Dict[int, List[TrackingBbox]],
287+
gt_root_path: str) -> str:
288+
289+
""" eval code that calls on 'motmetrics' package in referenced FairMOT script, to produce MOT metrics on inference, given ground-truth.
290+
Args:
291+
results: prediction results from predict() function, i.e. Dict[int, List[TrackingBbox]]
292+
gt_root_path: path of dataset containing GT annotations in MOTchallenge format (xywh)
293+
Returns:
294+
strsummary: str output by method in 'motmetrics' package, containing metrics scores
295+
"""
296+
297+
#Implementation inspired from code found here: https://github.com/ifzhang/FairMOT/blob/master/src/track.py
298+
evaluator = Evaluator(gt_root_path, "single_vid", "mot")
299+
300+
with tempfile.TemporaryDirectory() as tmpdir1:
301+
os.makedirs(osp.join(tmpdir1,'results'))
302+
result_filename = osp.join(tmpdir1,'results', 'results.txt')
303+
304+
# Save results im MOT format for evaluation
305+
bboxes_mot = boxes_to_mot(results)
306+
np.savetxt(result_filename, bboxes_mot, delimiter=",", fmt="%s")
307+
308+
# Run evaluation using pymotmetrics package
309+
accs=[evaluator.eval_file(result_filename)]
310+
311+
# get summary
312+
metrics = mm.metrics.motchallenge_metrics
313+
mh = mm.metrics.create()
314+
315+
summary = Evaluator.get_summary(accs, ("single_vid",), metrics)
316+
strsummary = mm.io.render_summary(
317+
summary,
318+
formatters=mh.formatters,
319+
namemap=mm.io.motchallenge_metric_names
320+
)
321+
print(strsummary)
322+
323+
return strsummary
324+
249325
def predict(
250326
self,
251327
im_or_video_path: str,

utils_cv/tracking/plot.py

100644100755
File mode changed.

utils_cv/tracking/references/fairmot/tracking_utils/evaluation.py

100644100755
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import motmetrics as mm
55
mm.lap.default_solver = 'lap'
66

7-
from tracking_utils.io import read_results, unzip_objs
7+
from .io import read_results, unzip_objs #EDITED
88

99

1010
class Evaluator(object):
@@ -76,10 +76,10 @@ def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False):
7676
return events
7777

7878
def eval_file(self, filename):
79-
self.reset_accumulator()
80-
79+
self.reset_accumulator()
8180
result_frame_dict = read_results(filename, self.data_type, is_gt=False)
8281
frames = sorted(list(set(self.gt_frame_dict.keys()) | set(result_frame_dict.keys())))
82+
8383
for frame_id in frames:
8484
trk_objs = result_frame_dict.get(frame_id, [])
8585
trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]

utils_cv/tracking/references/fairmot/tracking_utils/io.py

100644100755
Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Dict
33
import numpy as np
44

5-
from tracking_utils.log import logger
5+
from .log import logger #EDITED
66

77

88
def write_results(filename, results_dict: Dict, data_type: str):
@@ -61,16 +61,18 @@ def read_results(filename, data_type: str, is_gt=False, is_ignore=False):
6161

6262

6363
def read_mot_results(filename, is_gt, is_ignore):
64+
6465
valid_labels = {1}
6566
ignore_labels = {2, 7, 8, 12}
6667
results_dict = dict()
67-
if os.path.isfile(filename):
68-
with open(filename, 'r') as f:
69-
for line in f.readlines():
70-
linelist = line.split(',')
71-
if len(linelist) < 7:
68+
if os.path.isfile(filename):
69+
with open(filename, 'r') as f:
70+
for line in f.readlines():
71+
linelist = line.split(',')
72+
if len(linelist) < 7:
7273
continue
7374
fid = int(linelist[0])
75+
7476
if fid < 1:
7577
continue
7678
results_dict.setdefault(fid, list())
@@ -82,6 +84,7 @@ def read_mot_results(filename, is_gt, is_ignore):
8284
if mark == 0 or label not in valid_labels:
8385
continue
8486
score = 1
87+
8588
elif is_ignore:
8689
if 'MOT16-' in filename or 'MOT17-' in filename:
8790
label = int(float(linelist[7]))
@@ -91,14 +94,15 @@ def read_mot_results(filename, is_gt, is_ignore):
9194
else:
9295
continue
9396
score = 1
97+
9498
else:
9599
score = float(linelist[6])
96-
100+
101+
97102
tlwh = tuple(map(float, linelist[2:6]))
98-
target_id = int(linelist[1])
99-
103+
target_id = int(linelist[1])
100104
results_dict[fid].append((tlwh, target_id, score))
101-
105+
102106
return results_dict
103107

104108

0 commit comments

Comments
 (0)