Skip to content

Commit 405290f

Browse files
author
Matt Sokoloff
committed
add missing files
1 parent 1a7c21f commit 405290f

File tree

2 files changed

+315
-0
lines changed

2 files changed

+315
-0
lines changed
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from typing import List, Union
2+
3+
from pydantic.main import BaseModel
4+
5+
from ...annotation_types.annotation import AnnotationType, ClassificationAnnotation
6+
from ...annotation_types.classification import Checklist, ClassificationAnswer, Radio, Text, Dropdown
7+
from .feature import LBV1Feature
8+
9+
10+
class LBV1ClassificationAnswer(LBV1Feature):
11+
...
12+
13+
14+
class LBV1Radio(LBV1Feature):
15+
answer: LBV1ClassificationAnswer
16+
17+
def to_common(self):
18+
return Radio(answer=ClassificationAnswer(
19+
schema_id=self.answer.schema_id,
20+
name=self.answer.title,
21+
extra={
22+
'feature_id': self.answer.feature_id,
23+
'value': self.answer.value
24+
}))
25+
26+
@classmethod
27+
def from_common(cls, radio: Radio, schema_id: str, **extra) -> "LBV1Radio":
28+
return cls(schema_id=schema_id,
29+
answer=LBV1ClassificationAnswer(
30+
schema_id=radio.answer.schema_id,
31+
title=radio.answer.name,
32+
value=radio.answer.extra['value'],
33+
feature_id=radio.answer.extra['feature_id']),
34+
**extra)
35+
36+
37+
class LBV1Checklist(LBV1Feature):
38+
answers: List[LBV1ClassificationAnswer]
39+
40+
def to_common(self):
41+
return Checklist(answer=[
42+
ClassificationAnswer(schema_id=answer.schema_id,
43+
name=answer.title,
44+
extra={
45+
'feature_id': answer.feature_id,
46+
'value': answer.value
47+
}) for answer in self.answers
48+
])
49+
50+
@classmethod
51+
def from_common(cls, checklist: Checklist, schema_id: str,
52+
**extra) -> "LBV1Checklist":
53+
return cls(schema_id=schema_id,
54+
answers=[
55+
LBV1ClassificationAnswer(
56+
schema_id=answer.schema_id,
57+
title=answer.name,
58+
value=answer.extra['value'],
59+
feature_id=answer.extra['feature_id'])
60+
for answer in checklist.answer
61+
],
62+
**extra)
63+
64+
65+
class LBV1Text(LBV1Feature):
66+
answer: str
67+
68+
def to_common(self):
69+
return Text(answer=self.answer)
70+
71+
@classmethod
72+
def from_common(cls, text: Text, schema_id: str, **extra) -> "LBV1Text":
73+
return cls(schema_id=schema_id, answer=text.answer, **extra)
74+
75+
76+
classification_mapping = {
77+
Text: LBV1Text,
78+
Dropdown: LBV1Checklist,
79+
Checklist: LBV1Checklist,
80+
Radio: LBV1Radio
81+
}
82+
83+
84+
class LBV1Classifications(BaseModel):
85+
classifications: List[Union[LBV1Radio, LBV1Checklist, LBV1Text]] = []
86+
87+
def to_common(self):
88+
classifications = [
89+
ClassificationAnnotation(value=classification.to_common(),
90+
classifications=[],
91+
name=classification.title,
92+
extra={
93+
'value': classification.value,
94+
'feature_id': classification.feature_id
95+
})
96+
for classification in self.classifications
97+
]
98+
return classifications
99+
100+
@classmethod
101+
def from_common(cls,
102+
annotations: List[AnnotationType]) -> "LBV1Classifications":
103+
classifications = []
104+
for annotation in annotations:
105+
classification = classification_mapping.get(type(annotation.value))
106+
if classification is not None:
107+
classifications.append(
108+
classification.from_common(annotation.value,
109+
annotation.schema_id,
110+
**annotation.extra))
111+
else:
112+
raise TypeError(f"Unexpected type {type(annotation.value)}")
113+
return cls(classifications=classifications)
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
from typing import List, Union, Optional
2+
3+
from pydantic import BaseModel, validator, Field
4+
5+
from ...annotation_types.annotation import AnnotationType, ClassificationAnnotation, VideoClassificationAnnotation
6+
from ...annotation_types.classification.classification import ClassificationAnswer, Dropdown, Text, Checklist, Radio
7+
from .base import NDAnnotation
8+
9+
10+
class NDFeature(BaseModel):
11+
schema_id: str = Field(..., alias='schemaId')
12+
13+
@validator('schema_id', pre=True, always=True)
14+
def validate_id(cls, v):
15+
if v is None:
16+
raise ValueError(
17+
"Schema ids are not set. Use `LabelGenerator.assign_schema_ids`, `LabelCollection.assign_schema_ids`, or `Label.assign_schema_ids`."
18+
)
19+
return v
20+
21+
class Config:
22+
allow_population_by_field_name = True
23+
24+
25+
class FrameLocation(BaseModel):
26+
end: int
27+
start: int
28+
29+
30+
class VideoSupported(BaseModel):
31+
#Note that frames are only allowed as top level inferences for video
32+
frames: Optional[List[FrameLocation]] = None
33+
34+
def dict(self, *args, **kwargs):
35+
res = super().dict(*args, **kwargs)
36+
# This means these are no video frames ..
37+
if self.frames is None:
38+
res.pop('frames')
39+
return res
40+
41+
42+
class NDTextSubclass(NDFeature):
43+
answer: str
44+
45+
def to_common(self) -> Text:
46+
return Text(answer=self.answer)
47+
48+
@classmethod
49+
def from_common(cls, annotation) -> "NDTextSubclass":
50+
return cls(answer=annotation.value.answer,
51+
schema_id=annotation.schema_id)
52+
53+
54+
class NDChecklistSubclass(NDFeature):
55+
answer: List[NDFeature]
56+
57+
def to_common(self) -> Checklist:
58+
return Checklist(answer=[
59+
ClassificationAnswer(schema_id=answer.schema_id)
60+
for answer in self.answer
61+
])
62+
63+
@classmethod
64+
def from_common(cls, annotation) -> "NDChecklistSubclass":
65+
return cls(answer=[
66+
NDFeature(schema_id=answer.schema_id)
67+
for answer in annotation.value.answer
68+
],
69+
schema_id=annotation.schema_id)
70+
71+
72+
class NDRadioSubclass(NDFeature):
73+
answer: NDFeature
74+
75+
def to_common(self) -> Radio:
76+
return Radio(answer=ClassificationAnswer(
77+
schema_id=self.answer.schema_id))
78+
79+
@classmethod
80+
def from_common(cls, annotation) -> "NDRadioSubclass":
81+
return cls(
82+
answer=NDFeature(schema_id=annotation.value.answer.schema_id),
83+
schema_id=annotation.schema_id)
84+
85+
86+
### ====== End of subclasses
87+
88+
89+
class NDText(NDAnnotation, NDTextSubclass):
90+
91+
@classmethod
92+
def from_common(cls, annotation, data) -> "NDText":
93+
return cls(
94+
answer=annotation.value.answer,
95+
dataRow={'id': data.uid},
96+
schema_id=annotation.schema_id,
97+
uuid=annotation.extra.get('uuid'),
98+
)
99+
100+
101+
class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported):
102+
103+
@classmethod
104+
def from_common(cls, annotation, data) -> "NDChecklist":
105+
return cls(answer=[
106+
NDFeature(schema_id=answer.schema_id)
107+
for answer in annotation.value.answer
108+
],
109+
dataRow={'id': data.uid},
110+
schema_id=annotation.schema_id,
111+
uuid=annotation.extra.get('uuid'),
112+
frames=annotation.extra.get('frames'))
113+
114+
115+
class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported):
116+
117+
@classmethod
118+
def from_common(cls, annotation, data) -> "NDRadio":
119+
return cls(
120+
answer=NDFeature(schema_id=annotation.value.answer.schema_id),
121+
dataRow={'id': data.uid},
122+
schema_id=annotation.schema_id,
123+
uuid=annotation.extra.get('uuid'),
124+
frames=annotation.extra.get('frames'))
125+
126+
127+
class NDSubclassification:
128+
# TODO: Create a type for these unions
129+
@classmethod
130+
def from_common(
131+
cls, annotation: AnnotationType
132+
) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]:
133+
classify_obj = cls.lookup_subclassification(annotation)
134+
if classify_obj is None:
135+
raise TypeError(
136+
f"Unable to convert object to MAL format. `{type(annotation.value)}`"
137+
)
138+
return classify_obj.from_common(annotation)
139+
140+
@staticmethod
141+
def to_common(annotation: AnnotationType) -> ClassificationAnnotation:
142+
return ClassificationAnnotation(value=annotation.to_common(),
143+
schema_id=annotation.schema_id)
144+
145+
@staticmethod
146+
def lookup_subclassification(annotation: AnnotationType):
147+
if isinstance(annotation, Dropdown):
148+
raise TypeError("Dropdowns are not supported for MAL")
149+
return {
150+
Text: NDTextSubclass,
151+
Checklist: NDChecklistSubclass,
152+
Radio: NDRadioSubclass,
153+
}.get(type(annotation.value))
154+
155+
156+
class NDClassification:
157+
158+
@staticmethod
159+
def to_common(
160+
annotation: AnnotationType
161+
) -> Union[ClassificationAnnotation, VideoClassificationAnnotation]:
162+
common = ClassificationAnnotation(value=annotation.to_common(),
163+
schema_id=annotation.schema_id,
164+
extra={'uuid': annotation.uuid})
165+
if getattr(annotation, 'frames', None) is None:
166+
return [common]
167+
results = []
168+
for frame in annotation.frames:
169+
for idx in range(frame.start, frame.end + 1, 1):
170+
results.append(
171+
VideoClassificationAnnotation(frame=idx, **common.dict()))
172+
return results
173+
174+
@classmethod
175+
def from_common(
176+
cls, annotation: AnnotationType, data
177+
) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]:
178+
classify_obj = cls.lookup_classification(annotation)
179+
if classify_obj is None:
180+
raise TypeError(
181+
f"Unable to convert object to MAL format. `{type(annotation.value)}`"
182+
)
183+
if len(annotation.classifications):
184+
raise ValueError(
185+
"Nested classifications not supported by this format")
186+
return classify_obj.from_common(annotation, data)
187+
188+
@staticmethod
189+
def lookup_classification(annotation: AnnotationType):
190+
if isinstance(annotation, Dropdown):
191+
raise TypeError("Dropdowns are not supported for MAL")
192+
return {
193+
Text: NDText,
194+
Checklist: NDChecklist,
195+
Radio: NDRadio,
196+
Dropdown: NDChecklist,
197+
}.get(type(annotation.value))
198+
199+
200+
NDSubclassificationType = Union[NDRadioSubclass, NDChecklistSubclass,
201+
NDTextSubclass]
202+
NDClassificationType = Union[NDRadio, NDChecklist, NDText]

0 commit comments

Comments
 (0)