22from operator import itemgetter
33from typing import Dict , Generator , List , Tuple , Union
44from collections import defaultdict
5+ import warnings
56
67from pydantic import BaseModel
78
8- from ...annotation_types .annotation import ClassificationAnnotation , ObjectAnnotation , VideoClassificationAnnotation
9+ from ...annotation_types .annotation import ClassificationAnnotation , ObjectAnnotation , VideoClassificationAnnotation , VideoObjectAnnotation
910from ...annotation_types .collection import LabelCollection , LabelGenerator
1011from ...annotation_types .data import ImageData , TextData , VideoData
1112from ...annotation_types .label import Label
1516
1617from .metric import NDScalarMetric , NDMetricAnnotation , NDConfusionMatrixMetric
1718from .classification import NDChecklistSubclass , NDClassification , NDClassificationType , NDRadioSubclass
18- from .objects import NDObject , NDObjectType
19+ from .objects import NDObject , NDObjectType , NDSegments
1920
2021
2122class NDLabel (BaseModel ):
2223 annotations : List [Union [NDObjectType , NDClassificationType ,
23- NDConfusionMatrixMetric , NDScalarMetric ]]
24+ NDConfusionMatrixMetric , NDScalarMetric ,
25+ NDSegments ]]
2426
2527 def to_common (self ) -> LabelGenerator :
2628 grouped_annotations = defaultdict (list )
@@ -37,15 +39,20 @@ def from_common(cls,
3739 yield from cls ._create_video_annotations (label )
3840
3941 def _generate_annotations (
40- self , grouped_annotations : Dict [str , List [Union [NDObjectType ,
41- NDClassificationType ,
42- NDConfusionMatrixMetric ,
43- NDScalarMetric ]]]
42+ self ,
43+ grouped_annotations : Dict [str ,
44+ List [Union [NDObjectType , NDClassificationType ,
45+ NDConfusionMatrixMetric ,
46+ NDScalarMetric , NDSegments ]]]
4447 ) -> Generator [Label , None , None ]:
4548 for data_row_id , annotations in grouped_annotations .items ():
4649 annots = []
4750 for annotation in annotations :
48- if isinstance (annotation , NDObjectType .__args__ ):
51+ if isinstance (annotation , NDSegments ):
52+ annots .extend (
53+ NDSegments .to_common (annotation , annotation .schema_id ))
54+
55+ elif isinstance (annotation , NDObjectType .__args__ ):
4956 annots .append (NDObject .to_common (annotation ))
5057 elif isinstance (annotation , NDClassificationType .__args__ ):
5158 annots .extend (NDClassification .to_common (annotation ))
@@ -55,7 +62,6 @@ def _generate_annotations(
5562 else :
5663 raise TypeError (
5764 f"Unsupported annotation. { type (annotation )} " )
58-
5965 data = self ._infer_media_type (annotations )(uid = data_row_id )
6066 yield Label (annotations = annots , data = data )
6167
@@ -65,7 +71,7 @@ def _infer_media_type(
6571 types = {type (annotation ) for annotation in annotations }
6672 if TextEntity in types :
6773 return TextData
68- elif VideoClassificationAnnotation in types :
74+ elif VideoClassificationAnnotation in types or VideoObjectAnnotation in types :
6975 return VideoData
7076 else :
7177 return ImageData
@@ -83,26 +89,46 @@ def _get_consecutive_frames(
8389 def _create_video_annotations (
8490 cls , label : Label
8591 ) -> Generator [Union [NDChecklistSubclass , NDRadioSubclass ], None , None ]:
92+
8693 video_annotations = defaultdict (list )
8794 for annot in label .annotations :
88- if isinstance (annot , VideoClassificationAnnotation ):
95+ if isinstance (
96+ annot ,
97+ (VideoClassificationAnnotation , VideoObjectAnnotation )):
8998 video_annotations [annot .feature_schema_id ].append (annot )
9099
91100 for annotation_group in video_annotations .values ():
92101 consecutive_frames = cls ._get_consecutive_frames (
93102 sorted ([annotation .frame for annotation in annotation_group ]))
94- annotation = annotation_group [0 ]
95- frames_data = []
96- for frames in consecutive_frames :
97- frames_data .append ({'start' : frames [0 ], 'end' : frames [- 1 ]})
98- annotation .extra .update ({'frames' : frames_data })
99- yield NDClassification .from_common (annotation , label .data )
103+
104+ if isinstance (annotation_group [0 ], VideoClassificationAnnotation ):
105+ annotation = annotation_group [0 ]
106+ frames_data = []
107+ for frames in consecutive_frames :
108+ frames_data .append ({'start' : frames [0 ], 'end' : frames [- 1 ]})
109+ annotation .extra .update ({'frames' : frames_data })
110+ yield NDClassification .from_common (annotation , label .data )
111+
112+ elif isinstance (annotation_group [0 ], VideoObjectAnnotation ):
113+ warnings .warn (
114+ """Nested classifications are not currently supported
115+ for video object annotations
116+ and will not import alongside the object annotations.""" )
117+ segments = []
118+ for start_frame , end_frame in consecutive_frames :
119+ segment = []
120+ for annotation in annotation_group :
121+ if annotation .keyframe and start_frame <= annotation .frame <= end_frame :
122+ segment .append (annotation )
123+ segments .append (segment )
124+ yield NDObject .from_common (segments , label .data )
100125
101126 @classmethod
102127 def _create_non_video_annotations (cls , label : Label ):
103128 non_video_annotations = [
104129 annot for annot in label .annotations
105- if not isinstance (annot , VideoClassificationAnnotation )
130+ if not isinstance (annot , (VideoClassificationAnnotation ,
131+ VideoObjectAnnotation ))
106132 ]
107133 for annotation in non_video_annotations :
108134 if isinstance (annotation , ClassificationAnnotation ):
0 commit comments