|
| 1 | +from typing import List, Union, Optional |
| 2 | + |
| 3 | +from pydantic import BaseModel, validator, Field |
| 4 | + |
| 5 | +from ...annotation_types.annotation import AnnotationType, ClassificationAnnotation, VideoClassificationAnnotation |
| 6 | +from ...annotation_types.classification.classification import ClassificationAnswer, Dropdown, Text, Checklist, Radio |
| 7 | +from .base import NDAnnotation |
| 8 | + |
| 9 | + |
| 10 | +class NDFeature(BaseModel): |
| 11 | + schema_id: str = Field(..., alias='schemaId') |
| 12 | + |
| 13 | + @validator('schema_id', pre=True, always=True) |
| 14 | + def validate_id(cls, v): |
| 15 | + if v is None: |
| 16 | + raise ValueError( |
| 17 | + "Schema ids are not set. Use `LabelGenerator.assign_schema_ids`, `LabelCollection.assign_schema_ids`, or `Label.assign_schema_ids`." |
| 18 | + ) |
| 19 | + return v |
| 20 | + |
| 21 | + class Config: |
| 22 | + allow_population_by_field_name = True |
| 23 | + |
| 24 | + |
| 25 | +class FrameLocation(BaseModel): |
| 26 | + end: int |
| 27 | + start: int |
| 28 | + |
| 29 | + |
| 30 | +class VideoSupported(BaseModel): |
| 31 | + #Note that frames are only allowed as top level inferences for video |
| 32 | + frames: Optional[List[FrameLocation]] = None |
| 33 | + |
| 34 | + def dict(self, *args, **kwargs): |
| 35 | + res = super().dict(*args, **kwargs) |
| 36 | + # This means these are no video frames .. |
| 37 | + if self.frames is None: |
| 38 | + res.pop('frames') |
| 39 | + return res |
| 40 | + |
| 41 | + |
| 42 | +class NDTextSubclass(NDFeature): |
| 43 | + answer: str |
| 44 | + |
| 45 | + def to_common(self) -> Text: |
| 46 | + return Text(answer=self.answer) |
| 47 | + |
| 48 | + @classmethod |
| 49 | + def from_common(cls, annotation) -> "NDTextSubclass": |
| 50 | + return cls(answer=annotation.value.answer, |
| 51 | + schema_id=annotation.schema_id) |
| 52 | + |
| 53 | + |
| 54 | +class NDChecklistSubclass(NDFeature): |
| 55 | + answer: List[NDFeature] |
| 56 | + |
| 57 | + def to_common(self) -> Checklist: |
| 58 | + return Checklist(answer=[ |
| 59 | + ClassificationAnswer(schema_id=answer.schema_id) |
| 60 | + for answer in self.answer |
| 61 | + ]) |
| 62 | + |
| 63 | + @classmethod |
| 64 | + def from_common(cls, annotation) -> "NDChecklistSubclass": |
| 65 | + return cls(answer=[ |
| 66 | + NDFeature(schema_id=answer.schema_id) |
| 67 | + for answer in annotation.value.answer |
| 68 | + ], |
| 69 | + schema_id=annotation.schema_id) |
| 70 | + |
| 71 | + |
| 72 | +class NDRadioSubclass(NDFeature): |
| 73 | + answer: NDFeature |
| 74 | + |
| 75 | + def to_common(self) -> Radio: |
| 76 | + return Radio(answer=ClassificationAnswer( |
| 77 | + schema_id=self.answer.schema_id)) |
| 78 | + |
| 79 | + @classmethod |
| 80 | + def from_common(cls, annotation) -> "NDRadioSubclass": |
| 81 | + return cls( |
| 82 | + answer=NDFeature(schema_id=annotation.value.answer.schema_id), |
| 83 | + schema_id=annotation.schema_id) |
| 84 | + |
| 85 | + |
| 86 | +### ====== End of subclasses |
| 87 | + |
| 88 | + |
| 89 | +class NDText(NDAnnotation, NDTextSubclass): |
| 90 | + |
| 91 | + @classmethod |
| 92 | + def from_common(cls, annotation, data) -> "NDText": |
| 93 | + return cls( |
| 94 | + answer=annotation.value.answer, |
| 95 | + dataRow={'id': data.uid}, |
| 96 | + schema_id=annotation.schema_id, |
| 97 | + uuid=annotation.extra.get('uuid'), |
| 98 | + ) |
| 99 | + |
| 100 | + |
| 101 | +class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported): |
| 102 | + |
| 103 | + @classmethod |
| 104 | + def from_common(cls, annotation, data) -> "NDChecklist": |
| 105 | + return cls(answer=[ |
| 106 | + NDFeature(schema_id=answer.schema_id) |
| 107 | + for answer in annotation.value.answer |
| 108 | + ], |
| 109 | + dataRow={'id': data.uid}, |
| 110 | + schema_id=annotation.schema_id, |
| 111 | + uuid=annotation.extra.get('uuid'), |
| 112 | + frames=annotation.extra.get('frames')) |
| 113 | + |
| 114 | + |
| 115 | +class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported): |
| 116 | + |
| 117 | + @classmethod |
| 118 | + def from_common(cls, annotation, data) -> "NDRadio": |
| 119 | + return cls( |
| 120 | + answer=NDFeature(schema_id=annotation.value.answer.schema_id), |
| 121 | + dataRow={'id': data.uid}, |
| 122 | + schema_id=annotation.schema_id, |
| 123 | + uuid=annotation.extra.get('uuid'), |
| 124 | + frames=annotation.extra.get('frames')) |
| 125 | + |
| 126 | + |
| 127 | +class NDSubclassification: |
| 128 | + # TODO: Create a type for these unions |
| 129 | + @classmethod |
| 130 | + def from_common( |
| 131 | + cls, annotation: AnnotationType |
| 132 | + ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: |
| 133 | + classify_obj = cls.lookup_subclassification(annotation) |
| 134 | + if classify_obj is None: |
| 135 | + raise TypeError( |
| 136 | + f"Unable to convert object to MAL format. `{type(annotation.value)}`" |
| 137 | + ) |
| 138 | + return classify_obj.from_common(annotation) |
| 139 | + |
| 140 | + @staticmethod |
| 141 | + def to_common(annotation: AnnotationType) -> ClassificationAnnotation: |
| 142 | + return ClassificationAnnotation(value=annotation.to_common(), |
| 143 | + schema_id=annotation.schema_id) |
| 144 | + |
| 145 | + @staticmethod |
| 146 | + def lookup_subclassification(annotation: AnnotationType): |
| 147 | + if isinstance(annotation, Dropdown): |
| 148 | + raise TypeError("Dropdowns are not supported for MAL") |
| 149 | + return { |
| 150 | + Text: NDTextSubclass, |
| 151 | + Checklist: NDChecklistSubclass, |
| 152 | + Radio: NDRadioSubclass, |
| 153 | + }.get(type(annotation.value)) |
| 154 | + |
| 155 | + |
| 156 | +class NDClassification: |
| 157 | + |
| 158 | + @staticmethod |
| 159 | + def to_common( |
| 160 | + annotation: AnnotationType |
| 161 | + ) -> Union[ClassificationAnnotation, VideoClassificationAnnotation]: |
| 162 | + common = ClassificationAnnotation(value=annotation.to_common(), |
| 163 | + schema_id=annotation.schema_id, |
| 164 | + extra={'uuid': annotation.uuid}) |
| 165 | + if getattr(annotation, 'frames', None) is None: |
| 166 | + return [common] |
| 167 | + results = [] |
| 168 | + for frame in annotation.frames: |
| 169 | + for idx in range(frame.start, frame.end + 1, 1): |
| 170 | + results.append( |
| 171 | + VideoClassificationAnnotation(frame=idx, **common.dict())) |
| 172 | + return results |
| 173 | + |
| 174 | + @classmethod |
| 175 | + def from_common( |
| 176 | + cls, annotation: AnnotationType, data |
| 177 | + ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: |
| 178 | + classify_obj = cls.lookup_classification(annotation) |
| 179 | + if classify_obj is None: |
| 180 | + raise TypeError( |
| 181 | + f"Unable to convert object to MAL format. `{type(annotation.value)}`" |
| 182 | + ) |
| 183 | + if len(annotation.classifications): |
| 184 | + raise ValueError( |
| 185 | + "Nested classifications not supported by this format") |
| 186 | + return classify_obj.from_common(annotation, data) |
| 187 | + |
| 188 | + @staticmethod |
| 189 | + def lookup_classification(annotation: AnnotationType): |
| 190 | + if isinstance(annotation, Dropdown): |
| 191 | + raise TypeError("Dropdowns are not supported for MAL") |
| 192 | + return { |
| 193 | + Text: NDText, |
| 194 | + Checklist: NDChecklist, |
| 195 | + Radio: NDRadio, |
| 196 | + Dropdown: NDChecklist, |
| 197 | + }.get(type(annotation.value)) |
| 198 | + |
| 199 | + |
| 200 | +NDSubclassificationType = Union[NDRadioSubclass, NDChecklistSubclass, |
| 201 | + NDTextSubclass] |
| 202 | +NDClassificationType = Union[NDRadio, NDChecklist, NDText] |
0 commit comments