11from itertools import groupby
22from operator import itemgetter
3- from typing import Dict , Generator , List , Tuple , Union
3+ from typing import Dict , Generator , List , Optional , Tuple , Union
44from collections import defaultdict
55import warnings
66
1717from .metric import NDScalarMetric , NDMetricAnnotation , NDConfusionMatrixMetric
1818from .classification import NDChecklistSubclass , NDClassification , NDClassificationType , NDRadioSubclass
1919from .objects import NDObject , NDObjectType , NDSegments
20+ from .base import DataRow
2021
2122
2223class NDLabel (BaseModel ):
@@ -27,7 +28,10 @@ class NDLabel(BaseModel):
2728 def to_common (self ) -> LabelGenerator :
2829 grouped_annotations = defaultdict (list )
2930 for annotation in self .annotations :
30- grouped_annotations [annotation .data_row .id ].append (annotation )
31+ grouped_annotations [annotation .data_row .id or
32+ annotation .data_row .global_key ].append (
33+ annotation )
34+ print (grouped_annotations )
3135 return LabelGenerator (
3236 data = self ._generate_annotations (grouped_annotations ))
3337
@@ -45,9 +49,11 @@ def _generate_annotations(
4549 NDConfusionMatrixMetric ,
4650 NDScalarMetric , NDSegments ]]]
4751 ) -> Generator [Label , None , None ]:
48- for data_row_id , annotations in grouped_annotations .items ():
52+ for _ , annotations in grouped_annotations .items ():
4953 annots = []
54+ data_row = annotations [0 ].data_row
5055 for annotation in annotations :
56+
5157 if isinstance (annotation , NDSegments ):
5258 annots .extend (
5359 NDSegments .to_common (annotation , annotation .name ,
@@ -62,22 +68,30 @@ def _generate_annotations(
6268 else :
6369 raise TypeError (
6470 f"Unsupported annotation. { type (annotation )} " )
65- data = self . _infer_media_type ( annots )( uid = data_row_id )
66- yield Label ( annotations = annots , data = data )
71+ yield Label ( annotations = annots ,
72+ data = self . _infer_media_type ( data_row , annots ) )
6773
6874 def _infer_media_type (
69- self , annotations : List [Union [TextEntity , VideoClassificationAnnotation ,
70- VideoObjectAnnotation , ObjectAnnotation ,
71- ClassificationAnnotation , ScalarMetric ,
72- ConfusionMatrixMetric ]]
75+ self , data_row : DataRow ,
76+ annotations : List [Union [TextEntity , VideoClassificationAnnotation ,
77+ VideoObjectAnnotation , ObjectAnnotation ,
78+ ClassificationAnnotation , ScalarMetric ,
79+ ConfusionMatrixMetric ]]
7380 ) -> Union [TextData , VideoData , ImageData ]:
81+ if len (annotations ) == 0 :
82+ raise ValueError ("Missing annotations while inferring media type" )
83+
7484 types = {type (annotation ) for annotation in annotations }
85+ data = ImageData
7586 if TextEntity in types :
76- return TextData
87+ data = TextData
7788 elif VideoClassificationAnnotation in types or VideoObjectAnnotation in types :
78- return VideoData
89+ data = VideoData
90+
91+ if data_row .id :
92+ return data (uid = data_row .id )
7993 else :
80- return ImageData
94+ return data ( global_key = data_row . global_key )
8195
8296 @staticmethod
8397 def _get_consecutive_frames (
0 commit comments