@@ -574,7 +574,6 @@ def set_strides_grids(self):
574574 self .expanded_strides = np .concatenate (expanded_strides , 1 )
575575
576576
577-
578577class YoloV3ONNX (DetectionModel ):
579578 __model__ = "YOLOv3-ONNX"
580579
@@ -766,7 +765,9 @@ def non_max_suppression(
766765
767766 # Detections matrix nx6 (xyxy, conf, cls)
768767 box , cls , mask = x [:, :4 ], x [:, 4 : nc + 4 ], x [:, nc + 4 :]
769- box = xywh2xyxy (box ) # center_x, center_y, width, height) to (x1, y1, x2, y2) # TODO: first cut by conf_thres
768+ box = xywh2xyxy (
769+ box
770+ ) # center_x, center_y, width, height) to (x1, y1, x2, y2) # TODO: first cut by conf_thres
770771 if multi_label :
771772 i , j = (cls > conf_thres ).nonzero (as_tuple = False ).T
772773 x = torch .cat ((box [i ], x [i , 4 + j , None ], j [:, None ].float (), mask [i ]), 1 )
@@ -808,13 +809,9 @@ def __init__(self, inference_adapter, configuration, preload=False):
808809 self .raise_error ("the output must be of precision f32" )
809810 out_shape = output .shape
810811 if 3 != len (out_shape ):
811- self .raise_error (
812- "the output must be of rank 3"
813- )
812+ self .raise_error ("the output must be of rank 3" )
814813 if self .labels and len (self .labels ) + 4 != out_shape [1 ]:
815- self .raise_error (
816- "number of labes must be smaller than out_shape[1] by 4"
817- )
814+ self .raise_error ("number of labels must be smaller than out_shape[1] by 4" )
818815
819816 @classmethod
820817 def parameters (cls ):
@@ -841,7 +838,9 @@ def parameters(cls):
841838 def postprocess (self , outputs , meta ):
842839 if 1 != len (outputs ):
843840 raise RuntimeError ("YoloV8 wrapper expects 1 output" )
844- boxes = non_max_suppression (next (iter (outputs .values ())), self .confidence_threshold , self .iou_threshold )
841+ boxes = non_max_suppression (
842+ next (iter (outputs .values ())), self .confidence_threshold , self .iou_threshold
843+ )
845844
846845 inputImgWidth , inputImgHeight = (
847846 meta ["original_shape" ][1 ],
@@ -884,4 +883,5 @@ def postprocess(self, outputs, meta):
884883
885884class YOLOv8 (YOLOv5 ):
886885 """YOLOv5 and YOLOv8 are identical in terms of inference"""
886+
887887 __model__ = "YOLOv8"
0 commit comments