11from itertools import groupby
22from operator import itemgetter
3- from typing import Dict , Generator , List , Tuple , Union
3+ from typing import Dict , Generator , List , Optional , Tuple , Union
44from collections import defaultdict
55import warnings
66
1717from .metric import NDScalarMetric , NDMetricAnnotation , NDConfusionMatrixMetric
1818from .classification import NDChecklistSubclass , NDClassification , NDClassificationType , NDRadioSubclass
1919from .objects import NDObject , NDObjectType , NDSegments
20+ from .base import DataRow
2021
2122
2223class NDLabel (BaseModel ):
@@ -27,7 +28,9 @@ class NDLabel(BaseModel):
2728 def to_common (self ) -> LabelGenerator :
2829 grouped_annotations = defaultdict (list )
2930 for annotation in self .annotations :
30- grouped_annotations [annotation .data_row .id ].append (annotation )
31+ grouped_annotations [annotation .data_row .id or
32+ annotation .data_row .global_key ].append (
33+ annotation )
3134 return LabelGenerator (
3235 data = self ._generate_annotations (grouped_annotations ))
3336
@@ -45,9 +48,11 @@ def _generate_annotations(
4548 NDConfusionMatrixMetric ,
4649 NDScalarMetric , NDSegments ]]]
4750 ) -> Generator [Label , None , None ]:
48- for data_row_id , annotations in grouped_annotations .items ():
51+ for _ , annotations in grouped_annotations .items ():
4952 annots = []
53+ data_row = annotations [0 ].data_row
5054 for annotation in annotations :
55+
5156 if isinstance (annotation , NDSegments ):
5257 annots .extend (
5358 NDSegments .to_common (annotation , annotation .name ,
@@ -62,22 +67,30 @@ def _generate_annotations(
6267 else :
6368 raise TypeError (
6469 f"Unsupported annotation. { type (annotation )} " )
65- data = self . _infer_media_type ( annots )( uid = data_row_id )
66- yield Label ( annotations = annots , data = data )
70+ yield Label ( annotations = annots ,
71+ data = self . _infer_media_type ( data_row , annots ) )
6772
6873 def _infer_media_type (
69- self , annotations : List [Union [TextEntity , VideoClassificationAnnotation ,
70- VideoObjectAnnotation , ObjectAnnotation ,
71- ClassificationAnnotation , ScalarMetric ,
72- ConfusionMatrixMetric ]]
74+ self , data_row : DataRow ,
75+ annotations : List [Union [TextEntity , VideoClassificationAnnotation ,
76+ VideoObjectAnnotation , ObjectAnnotation ,
77+ ClassificationAnnotation , ScalarMetric ,
78+ ConfusionMatrixMetric ]]
7379 ) -> Union [TextData , VideoData , ImageData ]:
80+ if len (annotations ) == 0 :
81+ raise ValueError ("Missing annotations while inferring media type" )
82+
7483 types = {type (annotation ) for annotation in annotations }
84+ data = ImageData
7585 if TextEntity in types :
76- return TextData
86+ data = TextData
7787 elif VideoClassificationAnnotation in types or VideoObjectAnnotation in types :
78- return VideoData
88+ data = VideoData
89+
90+ if data_row .id :
91+ return data (uid = data_row .id )
7992 else :
80- return ImageData
93+ return data ( global_key = data_row . global_key )
8194
8295 @staticmethod
8396 def _get_consecutive_frames (
0 commit comments