Skip to content

Commit cbc228f

Browse files
author
Matt Sokoloff
committed
add missing type annotations
1 parent 405290f commit cbc228f

File tree

14 files changed

+243
-197
lines changed

14 files changed

+243
-197
lines changed

labelbox/data/annotation_types/data/raster.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,9 @@ def validate_args(cls, values):
118118
return values
119119

120120
class Config:
121-
# TODO: Create a type for numpy arrays
121+
# Required for numpy arrays
122122
arbitrary_types_allowed = True
123+
# Required for sharing references
123124
copy_on_model_validation = False
125+
# Required for discriminating between data types
124126
extra = 'forbid'

labelbox/data/annotation_types/data/text.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,5 @@ def validate_date(cls, values):
7171
return values
7272

7373
class config:
74+
# Required for discriminating between data types
7475
extra = 'forbid'

labelbox/data/annotation_types/data/video.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def frame_generator(
5656
logger.info("Downloading the video locally to %s", file_path)
5757
urllib.request.urlretrieve(self.url, file_path)
5858
self.file_path = file_path
59-
# TODO: If the filepath exists but there was no data we should use the url (and the opposite too)
6059

6160
vidcap = cv2.VideoCapture(self.file_path)
6261

@@ -135,6 +134,7 @@ def validate_data(cls, values):
135134
return values
136135

137136
class Config:
138-
# TODO: Create numpy array type
137+
# Required for numpy arrays
139138
arbitrary_types_allowed = True
139+
# Required for discriminating between data types
140140
extra = 'forbid'

labelbox/data/annotation_types/geometry/polygon.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import Any, Dict, List
1+
from typing import List
22

33
import numpy as np
44
import geojson
55
import cv2
6-
from pydantic import ValidationError, validator
6+
from pydantic import validator
77

88
from .point import Point
99
from .geometry import Geometry
@@ -36,7 +36,7 @@ def raster(self, height: int, width: int, color=255) -> np.ndarray:
3636
@validator('points')
3737
def is_geom_valid(cls, points):
3838
if len(points) < 3:
39-
raise ValidationError(
39+
raise ValueError(
4040
f"A polygon must have at least 3 points to be valid. Found {points}"
4141
)
4242
return points

labelbox/data/annotation_types/geometry/rectangle.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
from typing import Any, Dict
2-
31
import cv2
42
import geojson
53
import numpy as np
4+
65
from .geometry import Geometry
76
from .point import Point
87

@@ -38,5 +37,3 @@ def raster(self, height: int, width: int, color: int = 255) -> np.ndarray:
3837
canvas = np.zeros((height, width), dtype=np.uint8)
3938
pts = np.array(self.geometry['coordinates']).astype(np.int32)
4039
return cv2.fillPoly(canvas, pts=pts, color=color)
41-
42-
# TODO: Validate the start points are less than the end points

labelbox/data/annotation_types/ner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Dict, Any
22

3-
from pydantic import BaseModel, root_validator, ValidationError
3+
from pydantic import BaseModel, root_validator
44

55

66
class TextEntity(BaseModel):

labelbox/data/serialization/labelbox_v1/classification.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,9 @@ def to_common(self):
9898
return classifications
9999

100100
@classmethod
101-
def from_common(cls,
102-
annotations: List[AnnotationType]) -> "LBV1Classifications":
101+
def from_common(
102+
cls, annotations: List[ClassificationAnnotation]
103+
) -> "LBV1Classifications":
103104
classifications = []
104105
for annotation in annotations:
105106
classification = classification_mapping.get(type(annotation.value))

labelbox/data/serialization/labelbox_v1/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Callable, Dict, Generator, Iterable, Optional
1+
from typing import Any, Callable, Dict, Generator, Iterable
22
import logging
33

44
import ndjson

labelbox/data/serialization/labelbox_v1/label.py

Lines changed: 66 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def to_common(self):
3939
# Labelbox doesn't support subclasses on image level classifications
4040
# These are added to top level classifications
4141
classifications=[],
42-
#keyframe = classification.keyframe,
4342
frame=self.frame_number,
4443
name=classification.title)
4544
for classification in self.classifications
@@ -58,7 +57,7 @@ def to_common(self):
5857
'feature_id': cls.feature_id,
5958
'title': cls.title,
6059
'value': cls.value,
61-
'keyframe': getattr(cls, 'keyframe', None)
60+
#'keyframe': getattr(cls, 'keyframe', None)
6261
}) for cls in obj.classifications
6362
],
6463
name=obj.title,
@@ -111,76 +110,63 @@ class LBV1Label(BaseModel):
111110
data_row_id: str = Field(..., alias="DataRow ID")
112111
row_data: str = Field(..., alias="Labeled Data")
113112
external_id: Optional[str] = Field(None, alias="External ID")
114-
created_by: Optional[str] = Field(None, alias='Created By')
115-
116-
id: Optional[str] = Field(None, alias='ID')
117-
project_name: Optional[str] = Field(None, alias='Project Name')
118-
created_at: Optional[str] = Field(None, alias='Created At')
119-
updated_at: Optional[str] = Field(None, alias='Updated At')
120-
seconds_to_label: Optional[float] = Field(None, alias='Seconds to Label')
121-
agreement: Optional[float] = Field(None, alias='Agreement')
122-
benchmark_agreement: Optional[float] = Field(None,
123-
alias='Benchmark Agreement')
124-
benchmark_id: Optional[float] = Field(None, alias='Benchmark ID')
125-
dataset_name: Optional[str] = Field(None, alias='Dataset Name')
126-
reviews: Optional[List[Review]] = Field(None, alias='Reviews')
127-
label_url: Optional[str] = Field(None, alias='View Label')
128-
has_open_issues: Optional[float] = Field(None, alias='Has Open Issues')
129-
skipped: Optional[bool] = Field(None, alias='Skipped')
130-
131-
def construct_data_ref(self, is_video):
132-
# TODO: Let users specify the type ...
133-
keys = {'external_id': self.external_id, 'uid': self.data_row_id}
134113

135-
if is_video:
136-
return VideoData(url=self.row_data, **keys)
137-
if any([x in self.row_data for x in (".jpg", ".png", ".jpeg")
138-
]) and self.row_data.startswith(("http://", "https://")):
139-
return RasterData(url=self.row_data, **keys)
140-
elif any([x in self.row_data for x in (".txt", ".text", ".html")
141-
]) and self.row_data.startswith(("http://", "https://")):
142-
return TextData(url=self.row_data, **keys)
143-
elif isinstance(self.row_data, str):
144-
return TextData(text=self.row_data, **keys)
145-
elif len([
146-
annotation for annotation in self.label.objects
147-
if isinstance(annotation, TextEntity)
148-
]):
149-
return TextData(url=self.row_data, **keys)
150-
else:
151-
raise TypeError("Can't infer data type from row data.")
114+
created_by: Optional[str] = Field(None,
115+
alias='Created By',
116+
extra_field=True)
117+
project_name: Optional[str] = Field(None,
118+
alias='Project Name',
119+
extra_field=True)
120+
id: Optional[str] = Field(None, alias='ID', extra_field=True)
121+
created_at: Optional[str] = Field(None,
122+
alias='Created At',
123+
extra_field=True)
124+
updated_at: Optional[str] = Field(None,
125+
alias='Updated At',
126+
extra_field=True)
127+
seconds_to_label: Optional[float] = Field(None,
128+
alias='Seconds to Label',
129+
extra_field=True)
130+
agreement: Optional[float] = Field(None,
131+
alias='Agreement',
132+
extra_field=True)
133+
benchmark_agreement: Optional[float] = Field(None,
134+
alias='Benchmark Agreement',
135+
extra_field=True)
136+
benchmark_id: Optional[float] = Field(None,
137+
alias='Benchmark ID',
138+
extra_field=True)
139+
dataset_name: Optional[str] = Field(None,
140+
alias='Dataset Name',
141+
extra_field=True)
142+
reviews: Optional[List[Review]] = Field(None,
143+
alias='Reviews',
144+
extra_field=True)
145+
label_url: Optional[str] = Field(None, alias='View Label', extra_field=True)
146+
has_open_issues: Optional[float] = Field(None,
147+
alias='Has Open Issues',
148+
extra_field=True)
149+
skipped: Optional[bool] = Field(None, alias='Skipped', extra_field=True)
152150

153151
def to_common(self) -> Label:
154-
is_video = False
155152
if isinstance(self.label, list):
156153
annotations = []
157154
for lbl in self.label:
158155
annotations.extend(lbl.to_common())
159-
is_video = True
156+
data = VideoData(url=self.row_data,
157+
external_id=self.external_id,
158+
uid=self.data_row_id)
160159
else:
161160
annotations = self.label.to_common()
161+
data = self._infer_media_type()
162162

163-
return Label(
164-
data=self.construct_data_ref(is_video),
165-
annotations=annotations,
166-
extra={
167-
'Created By': self.created_by,
168-
'Project Name': self.project_name,
169-
'ID': self.id,
170-
'Created At': self.created_at,
171-
'Updated At': self.updated_at,
172-
'Seconds to Label': self.seconds_to_label,
173-
'Agreement': self.agreement,
174-
'Benchmark Agreement': self.benchmark_agreement,
175-
'Benchmark ID': self.benchmark_id,
176-
'Dataset Name': self.dataset_name,
177-
'Reviews': [
178-
review.dict(by_alias=True) for review in self.reviews
179-
],
180-
'View Label': self.label_url,
181-
'Has Open Issues': self.has_open_issues,
182-
'Skipped': self.skipped
183-
})
163+
return Label(data=data,
164+
annotations=annotations,
165+
extra={
166+
field.alias: getattr(self, field_name)
167+
for field_name, field in self.__fields__.items()
168+
if field.field_info.extra.get('extra_field')
169+
})
184170

185171
@classmethod
186172
def from_common(cls, label: Label, signer: Callable[[bytes], str]):
@@ -196,5 +182,23 @@ def from_common(cls, label: Label, signer: Callable[[bytes], str]):
196182
external_id=label.data.external_id,
197183
**label.extra)
198184

185+
def _infer_media_type(self):
186+
keys = {'external_id': self.external_id, 'uid': self.data_row_id}
187+
if any([x in self.row_data for x in (".jpg", ".png", ".jpeg")
188+
]) and self.row_data.startswith(("http://", "https://")):
189+
return RasterData(url=self.row_data, **keys)
190+
elif any([x in self.row_data for x in (".txt", ".text", ".html")
191+
]) and self.row_data.startswith(("http://", "https://")):
192+
return TextData(url=self.row_data, **keys)
193+
elif isinstance(self.row_data, str):
194+
return TextData(text=self.row_data, **keys)
195+
elif len([
196+
annotation for annotation in self.label.objects
197+
if isinstance(annotation, TextEntity)
198+
]):
199+
return TextData(url=self.row_data, **keys)
200+
else:
201+
raise TypeError("Can't infer data type from row data.")
202+
199203
class Config:
200204
allow_population_by_field_name = True

labelbox/data/serialization/labelbox_v1/objects.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
from pydantic import BaseModel
44

5-
from ...annotation_types.annotation import (AnnotationType,
6-
ClassificationAnnotation,
5+
from ...annotation_types.annotation import (ClassificationAnnotation,
76
ObjectAnnotation)
87
from ...annotation_types.data import RasterData
98
from ...annotation_types.geometry import Line, Mask, Point, Polygon, Rectangle
@@ -174,16 +173,6 @@ def from_common(cls, text_entity: TextEntity,
174173
**extra)
175174

176175

177-
object_mapping = {
178-
Line: LBV1Line,
179-
Point: LBV1Point,
180-
Polygon: LBV1Polygon,
181-
Rectangle: LBV1Rectangle,
182-
Mask: LBV1Mask,
183-
TextEntity: LBV1TextEntity
184-
}
185-
186-
187176
class LBV1Objects(BaseModel):
188177
objects: List[Union[LBV1Line, LBV1Point, LBV1Polygon, LBV1Rectangle,
189178
LBV1TextEntity, LBV1Mask]]
@@ -215,23 +204,37 @@ def to_common(self) -> List[ObjectAnnotation]:
215204
return objects
216205

217206
@classmethod
218-
def from_common(cls, annotations: List[AnnotationType]) -> "LBV1Objects":
207+
def from_common(cls, annotations: List[ObjectAnnotation]) -> "LBV1Objects":
219208
objects = []
220-
221209
for annotation in annotations:
222-
obj = object_mapping.get(type(annotation.value))
223-
if obj is not None:
224-
subclasses = []
225-
subclasses = LBV1Classifications.from_common(
226-
annotation.classifications).classifications
227-
228-
objects.append(
229-
obj.from_common(
230-
annotation.value, subclasses, annotation.schema_id,
231-
annotation.name, {
232-
'keyframe': getattr(annotation, 'keyframe', None),
233-
**annotation.extra
234-
}))
235-
else:
236-
raise TypeError(f"Unexpected type {type(annotation.value)}")
210+
obj = cls.lookup_object(annotation)
211+
subclasses = []
212+
subclasses = LBV1Classifications.from_common(
213+
annotation.classifications).classifications
214+
215+
objects.append(
216+
obj.from_common(
217+
annotation.value, subclasses, annotation.schema_id,
218+
annotation.name, {
219+
'keyframe': getattr(annotation, 'keyframe', None),
220+
**annotation.extra
221+
}))
237222
return cls(objects=objects)
223+
224+
@staticmethod
225+
def lookup_object(annotation: ObjectAnnotation) -> "LBV1ObjectType":
226+
result = {
227+
Line: LBV1Line,
228+
Point: LBV1Point,
229+
Polygon: LBV1Polygon,
230+
Rectangle: LBV1Rectangle,
231+
Mask: LBV1Mask,
232+
TextEntity: LBV1TextEntity
233+
}.get(type(annotation.value))
234+
if result is None:
235+
raise TypeError(f"Unexpected type {type(annotation.value)}")
236+
return result
237+
238+
239+
LBV1ObjectType = Union[LBV1Line, LBV1Point, LBV1Polygon, LBV1Rectangle,
240+
LBV1Mask, LBV1TextEntity]

0 commit comments

Comments
 (0)