Skip to content

Commit 44f843a

Browse files
Merge pull request #986 from Labelbox/kkim/AL-5172
2 parents: 857607e + 2ba04b9 · commit 44f843a

23 files changed

+1338 -69 lines changed

labelbox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from labelbox.schema.project import Project
66
from labelbox.schema.model import Model
77
from labelbox.schema.bulk_import_request import BulkImportRequest
8-
from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport, LabelImport
8+
from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport, LabelImport, MEAToMALPredictionImport
99
from labelbox.schema.dataset import Dataset
1010
from labelbox.schema.data_row import DataRow
1111
from labelbox.schema.label import Label

labelbox/data/annotation_types/data/base_data.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,5 +11,6 @@ class BaseData(BaseModel, ABC):
1111
"""
1212
external_id: Optional[str] = None
1313
uid: Optional[str] = None
14+
global_key: Optional[str] = None
1415
media_attributes: Optional[Dict[str, Any]] = None
1516
metadata: Optional[List[Dict[str, Any]]] = None

labelbox/data/annotation_types/data/raster.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -163,9 +163,10 @@ def validate_args(cls, values):
163163
url = values.get("url")
164164
arr = values.get("arr")
165165
uid = values.get('uid')
166-
if uid == file_path == im_bytes == url == None and arr is None:
166+
global_key = values.get('global_key')
167+
if uid == file_path == im_bytes == url == global_key == None and arr is None:
167168
raise ValueError(
168-
"One of `file_path`, `im_bytes`, `url`, `uid` or `arr` required."
169+
"One of `file_path`, `im_bytes`, `url`, `uid`, `global_key` or `arr` required."
169170
)
170171
if arr is not None:
171172
if arr.dtype != np.uint8:

labelbox/data/annotation_types/data/text.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -93,9 +93,11 @@ def validate_date(cls, values):
9393
text = values.get("text")
9494
url = values.get("url")
9595
uid = values.get('uid')
96-
if uid == file_path == text == url == None:
96+
global_key = values.get('global_key')
97+
if uid == file_path == text == url == global_key == None:
9798
raise ValueError(
98-
"One of `file_path`, `text`, `uid`, or `url` required.")
99+
"One of `file_path`, `text`, `uid`, `global_key` or `url` required."
100+
)
99101
return values
100102

101103
def __repr__(self) -> str:

labelbox/data/annotation_types/data/video.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -153,10 +153,12 @@ def validate_data(cls, values):
153153
url = values.get("url")
154154
frames = values.get("frames")
155155
uid = values.get("uid")
156+
global_key = values.get("global_key")
156157

157-
if uid == file_path == frames == url == None:
158+
if uid == file_path == frames == url == global_key == None:
158159
raise ValueError(
159-
"One of `file_path`, `frames`, `uid`, or `url` required.")
160+
"One of `file_path`, `frames`, `uid`, `global_key` or `url` required."
161+
)
160162
return values
161163

162164
def __repr__(self) -> str:

labelbox/data/serialization/ndjson/base.py

Lines changed: 17 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -2,34 +2,37 @@
22
from uuid import uuid4
33
from pydantic import BaseModel, root_validator, validator, Field
44

5-
from labelbox.utils import camel_case
5+
from labelbox.utils import _CamelCaseMixin, camel_case, is_exactly_one_set
66
from ...annotation_types.types import Cuid
77

88

9-
class DataRow(BaseModel):
9+
class DataRow(_CamelCaseMixin):
1010
id: str = None
11+
global_key: str = None
1112

12-
@validator('id', pre=True, always=True)
13-
def validate_id(cls, v):
14-
if v is None:
15-
raise ValueError(
16-
"Data row ids are not set. Use `LabelGenerator.add_to_dataset`,or `Label.create_data_row`. "
17-
"You can also manually assign the id for each `BaseData` object"
18-
)
19-
return v
13+
@root_validator()
14+
def must_set_one(cls, values):
15+
if not is_exactly_one_set(values.get('id'), values.get('global_key')):
16+
raise ValueError("Must set either id or global_key")
17+
return values
2018

2119

22-
class NDJsonBase(BaseModel):
20+
class NDJsonBase(_CamelCaseMixin):
2321
uuid: str = None
2422
data_row: DataRow
2523

2624
@validator('uuid', pre=True, always=True)
2725
def set_id(cls, v):
2826
return v or str(uuid4())
2927

30-
class Config:
31-
allow_population_by_field_name = True
32-
alias_generator = camel_case
28+
def dict(self, *args, **kwargs):
29+
""" Pop missing id or missing globalKey from dataRow """
30+
res = super().dict(*args, **kwargs)
31+
if not self.data_row.id:
32+
res['dataRow'].pop('id')
33+
if not self.data_row.global_key:
34+
res['dataRow'].pop('globalKey')
35+
return res
3336

3437

3538
class NDAnnotation(NDJsonBase):

labelbox/data/serialization/ndjson/classification.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from ...annotation_types.classification.classification import ClassificationAnswer, Dropdown, Text, Checklist, Radio
99
from ...annotation_types.types import Cuid
1010
from ...annotation_types.data import TextData, VideoData, ImageData
11-
from .base import NDAnnotation
11+
from .base import DataRow, NDAnnotation
1212

1313

1414
class NDFeature(ConfidenceMixin):
@@ -125,7 +125,7 @@ def from_common(cls, text: Text, name: str, feature_schema_id: Cuid,
125125
ImageData]) -> "NDText":
126126
return cls(
127127
answer=text.answer,
128-
data_row={'id': data.uid},
128+
data_row=DataRow(id=data.uid, global_key=data.global_key),
129129
name=name,
130130
schema_id=feature_schema_id,
131131
uuid=extra.get('uuid'),
@@ -145,7 +145,7 @@ def from_common(
145145
confidence=answer.confidence)
146146
for answer in checklist.answer
147147
],
148-
data_row={'id': data.uid},
148+
data_row=DataRow(id=data.uid, global_key=data.global_key),
149149
name=name,
150150
schema_id=feature_schema_id,
151151
uuid=extra.get('uuid'),
@@ -161,7 +161,7 @@ def from_common(cls, radio: Radio, name: str, feature_schema_id: Cuid,
161161
return cls(answer=NDFeature(name=radio.answer.name,
162162
schema_id=radio.answer.feature_schema_id,
163163
confidence=radio.answer.confidence),
164-
data_row={'id': data.uid},
164+
data_row=DataRow(id=data.uid, global_key=data.global_key),
165165
name=name,
166166
schema_id=feature_schema_id,
167167
uuid=extra.get('uuid'),

labelbox/data/serialization/ndjson/label.py

Lines changed: 25 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
from itertools import groupby
22
from operator import itemgetter
3-
from typing import Dict, Generator, List, Tuple, Union
3+
from typing import Dict, Generator, List, Optional, Tuple, Union
44
from collections import defaultdict
55
import warnings
66

@@ -17,6 +17,7 @@
1717
from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric
1818
from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass
1919
from .objects import NDObject, NDObjectType, NDSegments
20+
from .base import DataRow
2021

2122

2223
class NDLabel(BaseModel):
@@ -27,7 +28,9 @@ class NDLabel(BaseModel):
2728
def to_common(self) -> LabelGenerator:
2829
grouped_annotations = defaultdict(list)
2930
for annotation in self.annotations:
30-
grouped_annotations[annotation.data_row.id].append(annotation)
31+
grouped_annotations[annotation.data_row.id or
32+
annotation.data_row.global_key].append(
33+
annotation)
3134
return LabelGenerator(
3235
data=self._generate_annotations(grouped_annotations))
3336

@@ -45,9 +48,11 @@ def _generate_annotations(
4548
NDConfusionMatrixMetric,
4649
NDScalarMetric, NDSegments]]]
4750
) -> Generator[Label, None, None]:
48-
for data_row_id, annotations in grouped_annotations.items():
51+
for _, annotations in grouped_annotations.items():
4952
annots = []
53+
data_row = annotations[0].data_row
5054
for annotation in annotations:
55+
5156
if isinstance(annotation, NDSegments):
5257
annots.extend(
5358
NDSegments.to_common(annotation, annotation.name,
@@ -62,22 +67,30 @@ def _generate_annotations(
6267
else:
6368
raise TypeError(
6469
f"Unsupported annotation. {type(annotation)}")
65-
data = self._infer_media_type(annots)(uid=data_row_id)
66-
yield Label(annotations=annots, data=data)
70+
yield Label(annotations=annots,
71+
data=self._infer_media_type(data_row, annots))
6772

6873
def _infer_media_type(
69-
self, annotations: List[Union[TextEntity, VideoClassificationAnnotation,
70-
VideoObjectAnnotation, ObjectAnnotation,
71-
ClassificationAnnotation, ScalarMetric,
72-
ConfusionMatrixMetric]]
74+
self, data_row: DataRow,
75+
annotations: List[Union[TextEntity, VideoClassificationAnnotation,
76+
VideoObjectAnnotation, ObjectAnnotation,
77+
ClassificationAnnotation, ScalarMetric,
78+
ConfusionMatrixMetric]]
7379
) -> Union[TextData, VideoData, ImageData]:
80+
if len(annotations) == 0:
81+
raise ValueError("Missing annotations while inferring media type")
82+
7483
types = {type(annotation) for annotation in annotations}
84+
data = ImageData
7585
if TextEntity in types:
76-
return TextData
86+
data = TextData
7787
elif VideoClassificationAnnotation in types or VideoObjectAnnotation in types:
78-
return VideoData
88+
data = VideoData
89+
90+
if data_row.id:
91+
return data(uid=data_row.id)
7992
else:
80-
return ImageData
93+
return data(global_key=data_row.global_key)
8194

8295
@staticmethod
8396
def _get_consecutive_frames(

labelbox/data/serialization/ndjson/metric.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from typing import Optional, Union, Type
22

33
from labelbox.data.annotation_types.data import ImageData, TextData
4-
from labelbox.data.serialization.ndjson.base import NDJsonBase
4+
from labelbox.data.serialization.ndjson.base import DataRow, NDJsonBase
55
from labelbox.data.annotation_types.metrics.scalar import (
66
ScalarMetric, ScalarMetricAggregation, ScalarMetricValue,
77
ScalarMetricConfidenceValue)
@@ -50,7 +50,7 @@ def from_common(
5050
feature_name=metric.feature_name,
5151
subclass_name=metric.subclass_name,
5252
aggregation=metric.aggregation,
53-
data_row={'id': data.uid})
53+
data_row=DataRow(id=data.uid, global_key=data.global_key))
5454

5555

5656
class NDScalarMetric(BaseNDMetric):
@@ -75,7 +75,7 @@ def from_common(cls, metric: ScalarMetric,
7575
feature_name=metric.feature_name,
7676
subclass_name=metric.subclass_name,
7777
aggregation=metric.aggregation.value,
78-
data_row={'id': data.uid})
78+
data_row=DataRow(id=data.uid, global_key=data.global_key))
7979

8080
def dict(self, *args, **kwargs):
8181
res = super().dict(*args, **kwargs)

labelbox/data/serialization/ndjson/objects.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ def from_common(cls,
6060
'x': point.x,
6161
'y': point.y
6262
},
63-
dataRow=DataRow(id=data.uid),
63+
data_row=DataRow(id=data.uid, global_key=data.global_key),
6464
name=name,
6565
schema_id=feature_schema_id,
6666
uuid=extra.get('uuid'),
@@ -105,7 +105,7 @@ def from_common(cls,
105105
'x': pt.x,
106106
'y': pt.y
107107
} for pt in line.points],
108-
dataRow=DataRow(id=data.uid),
108+
data_row=DataRow(id=data.uid, global_key=data.global_key),
109109
name=name,
110110
schema_id=feature_schema_id,
111111
uuid=extra.get('uuid'),
@@ -154,7 +154,7 @@ def from_common(cls,
154154
'x': pt.x,
155155
'y': pt.y
156156
} for pt in polygon.points],
157-
dataRow=DataRow(id=data.uid),
157+
data_row=DataRow(id=data.uid, global_key=data.global_key),
158158
name=name,
159159
schema_id=feature_schema_id,
160160
uuid=extra.get('uuid'),
@@ -183,7 +183,7 @@ def from_common(cls,
183183
left=rectangle.start.x,
184184
height=rectangle.end.y - rectangle.start.y,
185185
width=rectangle.end.x - rectangle.start.x),
186-
dataRow=DataRow(id=data.uid),
186+
data_row=DataRow(id=data.uid, global_key=data.global_key),
187187
name=name,
188188
schema_id=feature_schema_id,
189189
uuid=extra.get('uuid'),
@@ -280,7 +280,7 @@ def from_common(cls, segments: List[VideoObjectAnnotation], data: VideoData,
280280
segments = [NDSegment.from_common(segment) for segment in segments]
281281

282282
return cls(segments=segments,
283-
dataRow=DataRow(id=data.uid),
283+
data_row=DataRow(id=data.uid, global_key=data.global_key),
284284
name=name,
285285
schema_id=feature_schema_id,
286286
uuid=extra.get('uuid'))
@@ -332,7 +332,7 @@ def from_common(cls,
332332
png=base64.b64encode(im_bytes.getvalue()).decode('utf-8'))
333333

334334
return cls(mask=lbv1_mask,
335-
dataRow=DataRow(id=data.uid),
335+
data_row=DataRow(id=data.uid, global_key=data.global_key),
336336
name=name,
337337
schema_id=feature_schema_id,
338338
uuid=extra.get('uuid'),
@@ -364,7 +364,7 @@ def from_common(cls,
364364
start=text_entity.start,
365365
end=text_entity.end,
366366
),
367-
dataRow=DataRow(id=data.uid),
367+
data_row=DataRow(id=data.uid, global_key=data.global_key),
368368
name=name,
369369
schema_id=feature_schema_id,
370370
uuid=extra.get('uuid'),

0 commit comments

Comments (0)