diff --git a/models.py b/models.py index 3ca0cd6..e2b434c 100755 --- a/models.py +++ b/models.py @@ -72,7 +72,7 @@ def build_model(self, load_pretrained=True): iou_threshold=self.config['iou_threshold'], score_threshold=self.config['score_threshold'])) - if load_pretrained and self.weight_path and self.weight_path.endswith('.weights'): + if load_pretrained and self.weight_path and (self.weight_path.endswith('.weights') or self.weight_path.endswith('.h5')): if self.weight_path.endswith('.weights'): load_weights(self.yolo_model, self.weight_path) print(f'load from {self.weight_path}')