22# Licensed under the MIT License.
33
44import argparse
5- from collections import OrderedDict
5+ from collections import OrderedDict , defaultdict
66from copy import deepcopy
77import glob
88import requests
99import os
1010import os .path as osp
11+ import tempfile #KIP
1112from typing import Dict , List , Tuple
1213
14+
1315import torch
1416import torch .cuda as cuda
1517import torch .nn as nn
1618from torch .utils .data import DataLoader
1719
1820import cv2
21+ import numpy as np
1922import pandas as pd
2023import matplotlib .pyplot as plt
24+ import motmetrics as mm
2125
2226from .references .fairmot .datasets .dataset .jde import LoadImages , LoadVideo
2327from .references .fairmot .models .model import (
2630 save_model ,
2731)
2832from .references .fairmot .tracker .multitracker import JDETracker
33+ from .references .fairmot .tracking_utils .evaluation import Evaluator
2934from .references .fairmot .trains .train_factory import train_factory
3035
3136from .bbox import TrackingBbox
32- from .dataset import TrackingDataset
37+ from .dataset import TrackingDataset , boxes_to_mot
3338from .opts import opts
3439from .plot import draw_boxes , assign_colors
3540from ..common .gpu import torch_device
@@ -211,7 +216,10 @@ def fit(
211216 Trainer = train_factory [opt_fit .task ]
212217 trainer = Trainer (opt_fit .opt , self .model , self .optimizer )
213218 trainer .set_device (opt_fit .gpus , opt_fit .chunk_sizes , opt_fit .device )
214-
219+
220+ # initialize loss vars
221+ self .losses_dict = defaultdict (list )
222+
215223 # training loop
216224 for epoch in range (
217225 start_epoch + 1 , start_epoch + opt_fit .num_epochs + 1
@@ -229,10 +237,41 @@ def fit(
229237 lr = opt_fit .lr * (0.1 ** (opt_fit .lr_step .index (epoch ) + 1 ))
230238 for param_group in optimizer .param_groups :
231239 param_group ["lr" ] = lr
240+
241+ # store losses in each epoch
242+ for k , v in log_dict_train .items ():
243+ if k in ['loss' , 'hm_loss' , 'wh_loss' , 'off_loss' , 'id_loss' ]:
244+ self .losses_dict [k ].append (v )
232245
233246 # save after training because at inference-time FairMOT src reads model weights from disk
234247 self .save (self .model_path )
235248
249+ def plot_training_losses (self , figsize : Tuple [int , int ] = (10 , 5 ))-> None :
250+ '''
251+ Plots training loss from calling `fit`
252+
253+ Args:
254+ figsize (optional): width and height wanted for figure of training-loss plot
255+
256+ '''
257+ fig = plt .figure (figsize = figsize )
258+ ax1 = fig .add_subplot (1 , 1 , 1 )
259+
260+ ax1 .set_xlim ([0 , len (self .losses_dict ['loss' ]) - 1 ])
261+ ax1 .set_xticks (range (0 , len (self .losses_dict ['loss' ])))
262+ ax1 .set_xlabel ("epochs" )
263+ ax1 .set_ylabel ("losses" )
264+
265+ ax1 .plot (self .losses_dict ['loss' ], c = "r" , label = 'loss' )
266+ ax1 .plot (self .losses_dict ['hm_loss' ], c = "y" , label = 'hm_loss' )
267+ ax1 .plot (self .losses_dict ['wh_loss' ], c = "g" , label = 'wh_loss' )
268+ ax1 .plot (self .losses_dict ['off_loss' ], c = "b" , label = 'off_loss' )
269+ ax1 .plot (self .losses_dict ['id_loss' ], c = "m" , label = 'id_loss' )
270+
271+ plt .legend (loc = 'upper right' )
272+ fig .suptitle ("Training losses over epochs" )
273+
274+
236275 def save (self , path ) -> None :
237276 """
238277 Save the model to a specified path.
@@ -243,9 +282,46 @@ def save(self, path) -> None:
243282 save_model (path , self .epoch , self .model , self .optimizer )
244283 print (f"Model saved to { path } " )
245284
246- def evaluate (self , results , gt ) -> pd .DataFrame :
247- pass
248-
285+ def evaluate (self ,
286+ results : Dict [int , List [TrackingBbox ]],
287+ gt_root_path : str ) -> str :
288+
289+ """ eval code that calls on 'motmetrics' package in referenced FairMOT script, to produce MOT metrics on inference, given ground-truth.
290+ Args:
291+ results: prediction results from predict() function, i.e. Dict[int, List[TrackingBbox]]
292+ gt_root_path: path of dataset containing GT annotations in MOTchallenge format (xywh)
293+ Returns:
294+ strsummary: str output by method in 'motmetrics' package, containing metrics scores
295+ """
296+
297+ #Implementation inspired from code found here: https://github.com/ifzhang/FairMOT/blob/master/src/track.py
298+ evaluator = Evaluator (gt_root_path , "single_vid" , "mot" )
299+
300+ with tempfile .TemporaryDirectory () as tmpdir1 :
301+ os .makedirs (osp .join (tmpdir1 ,'results' ))
302+ result_filename = osp .join (tmpdir1 ,'results' , 'results.txt' )
303+
304+ # Save results im MOT format for evaluation
305+ bboxes_mot = boxes_to_mot (results )
306+ np .savetxt (result_filename , bboxes_mot , delimiter = "," , fmt = "%s" )
307+
308+ # Run evaluation using pymotmetrics package
309+ accs = [evaluator .eval_file (result_filename )]
310+
311+ # get summary
312+ metrics = mm .metrics .motchallenge_metrics
313+ mh = mm .metrics .create ()
314+
315+ summary = Evaluator .get_summary (accs , ("single_vid" ,), metrics )
316+ strsummary = mm .io .render_summary (
317+ summary ,
318+ formatters = mh .formatters ,
319+ namemap = mm .io .motchallenge_metric_names
320+ )
321+ print (strsummary )
322+
323+ return strsummary
324+
249325 def predict (
250326 self ,
251327 im_or_video_path : str ,
0 commit comments