Skip to content

Commit 36f4a23

Browse files
author
Val Brodsky
committed
Add a _GenericDataType and allow passing a simple dict to Label
1 parent e3bd0c8 commit 36f4a23

File tree

6 files changed

+468
-57
lines changed

6 files changed

+468
-57
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Callable, Literal, Optional
2+
3+
from labelbox import pydantic_compat
4+
from labelbox.data.annotation_types.data.base_data import BaseData
5+
from labelbox.utils import _NoCoercionMixin
6+
7+
8+
class GenericDataRowData(BaseData, _NoCoercionMixin):
    """Data row reference that is not tied to a specific media type.

    Declares an optional ``url`` and a fixed ``class_name`` discriminator.
    Construction is rejected unless exactly one of ``external_id``,
    ``global_key`` or ``uid`` is supplied with a truthy value.
    NOTE(review): those three fields are presumably declared on ``BaseData``;
    confirm against that class.
    """
    url: Optional[str] = None
    class_name: Literal["GenericDataRowData"] = "GenericDataRowData"

    def create_url(self, signer: Callable[[bytes], str]) -> None:
        """Do nothing and return ``None``; ``signer`` is not used."""
        return None

    @pydantic_compat.root_validator(pre=True)
    def validate_one_datarow_key_present(cls, data):
        """Require exactly one truthy data row identifier in the raw input.

        Args:
            data: Raw input mapping, inspected before field validation.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: If none, or more than one, of the identifier keys
                holds a truthy value.
        """
        keys = ['external_id', 'global_key', 'uid']
        # Truthiness, not presence: an empty string or None counts as absent.
        supplied = len([key for key in keys if data.get(key)])
        if supplied < 1:
            raise ValueError(f"Exactly one of {keys} must be present.")
        if supplied > 1:
            raise ValueError(f"Only one of {keys} can be present.")
        return data

labelbox/data/annotation_types/label.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
from labelbox import pydantic_compat
66

77
import labelbox
8+
from labelbox.data.annotation_types.data.generic_data_row_data import GenericDataRowData
89
from labelbox.data.annotation_types.data.tiled_image import TiledImageData
910
from labelbox.schema import ontology
1011
from .annotation import ClassificationAnnotation, ObjectAnnotation
1112
from .relationship import RelationshipAnnotation
1213
from .classification import ClassificationAnswer
13-
from .data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, MaskData, TextData, VideoData, LlmPromptCreationData, LlmPromptResponseCreationData, LlmResponseCreationData
14+
from .data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, TextData, VideoData, LlmPromptCreationData, LlmPromptResponseCreationData, LlmResponseCreationData
1415
from .geometry import Mask
1516
from .metrics import ScalarMetric, ConfusionMatrixMetric
1617
from .types import Cuid
@@ -21,7 +22,7 @@
2122
DataType = Union[VideoData, ImageData, TextData, TiledImageData, AudioData,
2223
ConversationData, DicomData, DocumentData, HTMLData,
2324
LlmPromptCreationData, LlmPromptResponseCreationData,
24-
LlmResponseCreationData]
25+
LlmResponseCreationData, GenericDataRowData]
2526

2627

2728
class Label(pydantic_compat.BaseModel):
@@ -51,6 +52,18 @@ class Label(pydantic_compat.BaseModel):
5152
RelationshipAnnotation]] = []
5253
extra: Dict[str, Any] = {}
5354

55+
@staticmethod
56+
def is_data_type(data: Union[Dict[str, Any], DataType]) -> bool:
    """Return True when ``data`` is an instance of a ``DataType`` member class.

    Args:
        data: Candidate value, typically a plain dict or a data object.

    Returns:
        True if ``data`` is an instance of any class listed in the
        ``DataType`` union, otherwise False.
    """
    # isinstance() raises TypeError for a subscripted typing.Union on
    # Python < 3.10, so test against the union's member classes instead.
    return isinstance(data, DataType.__args__)
60+
61+
@pydantic_compat.root_validator(pre=True)
62+
def validate_data(cls, label):
    """Tag a plain-dict ``data`` payload as ``GenericDataRowData``.

    Args:
        label: Raw input mapping for the model, before field validation.

    Returns:
        ``label``; when ``label["data"]`` is a dict it is replaced by a copy
        that carries ``"class_name": "GenericDataRowData"``.
    """
    data = label.get("data")
    # Only a real dict can be tagged. A missing/None value or an already
    # constructed data object is left for field validation to accept or
    # reject, instead of failing here on item assignment.
    if isinstance(data, dict):
        # Copy so the dict the caller passed in is not modified.
        label["data"] = {**data, "class_name": "GenericDataRowData"}
    return label
66+
5467
def object_annotations(self) -> List[ObjectAnnotation]:
5568
return self._get_annotations_by_type(ObjectAnnotation)
5669

tests/data/annotation_import/conftest.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1885,3 +1885,57 @@ def bbox_video_annotation_objects():
18851885
]
18861886

18871887
return bbox_annotation
1888+
1889+
1890+
class Helpers:
    """Static helpers for normalising nested dict/list structures in tests.

    Both helpers modify their argument in place and return ``None``.
    """

    @staticmethod
    def remove_keys_recursive(d, keys):
        """Delete every key in ``keys`` from ``d`` and from all nested dicts.

        Args:
            d: Dict or list to clean in place. Lists are walked to any depth;
                values that are neither dict nor list are ignored.
            keys: Iterable of keys to delete wherever they occur.
        """
        if isinstance(d, list):
            for item in d:
                Helpers.remove_keys_recursive(item, keys)
            return
        if not isinstance(d, dict):
            return
        for key in keys:
            d.pop(key, None)
        for value in d.values():
            Helpers.remove_keys_recursive(value, keys)

    @staticmethod
    def rename_cuid_key_recursive(d):
        """Rename every cuid-looking key to ``"<cuid>"``, at every depth.

        NOTE: the cuid check is deliberately primitive (a 25-character string
        key that is not purely alphabetic). If one dict holds several such
        keys they all map to ``"<cuid>"`` and the last one renamed wins.
        This is not written with performance in mind; it is meant for the
        small to mid size dicts used in tests.

        Args:
            d: Dict or list to rewrite in place. Lists are walked to any
                depth; values that are neither dict nor list are ignored.
        """
        if isinstance(d, list):
            for item in d:
                Helpers.rename_cuid_key_recursive(item)
            return
        if not isinstance(d, dict):
            return
        new_key = "<cuid>"
        # Snapshot the keys: the dict is modified while being scanned.
        for k in list(d):
            # Non-string keys have no usable len()/isalpha(), so skip them.
            if isinstance(k, str) and len(k) == 25 and not k.isalpha():
                d[new_key] = d.pop(k)
        for value in d.values():
            Helpers.rename_cuid_key_recursive(value)
1920+
1921+
1922+
@pytest.fixture
1923+
def helpers():
    """Return the ``Helpers`` class itself (not an instance)."""
    helper_cls = Helpers
    return helper_cls
1925+
1926+
1927+
@pytest.fixture
1928+
def create_data_row_for_project(project, dataset, data_row_ndjson, batch_name):
    """Create a data row, batch it into ``project``, yield it, then clean up.

    Args:
        project: Project that receives the batch; its ``data_row_ids`` list
            is appended with the new row's uid.
        dataset: Dataset in which the data row is created.
        data_row_ndjson: Payload passed to ``dataset.create_data_row``.
        batch_name: Name of the batch created on ``project``.

    Yields:
        The data row returned by ``dataset.create_data_row``.
    """
    data_row = dataset.create_data_row(data_row_ndjson)

    # Once the row exists, cleanup must run even if batching fails;
    # otherwise the row would be left behind.
    try:
        project.create_batch(
            batch_name,
            [data_row.uid],  # sample of data row objects
            5,  # priority between 1(Highest) - 5(lowest)
        )
        project.data_row_ids.append(data_row.uid)

        yield data_row
    finally:
        data_row.delete()
        project.delete()

tests/data/annotation_import/test_data_types.py

Lines changed: 11 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import labelbox as lb
77
from labelbox.data.annotation_types.data.video import VideoData
8-
from labelbox.schema.data_row import DataRow
98
from labelbox.schema.media_type import MediaType
109
import labelbox.types as lb_types
1110
from labelbox.data.annotation_types.data import (
@@ -70,35 +69,6 @@
7069
]
7170

7271

73-
def remove_keys_recursive(d, keys):
74-
for k in keys:
75-
if k in d:
76-
del d[k]
77-
for k, v in d.items():
78-
if isinstance(v, dict):
79-
remove_keys_recursive(v, keys)
80-
elif isinstance(v, list):
81-
for i in v:
82-
if isinstance(i, dict):
83-
remove_keys_recursive(i, keys)
84-
85-
86-
# NOTE this uses quite a primitive check for cuids but I do not think it is worth coming up with a better one
87-
# Also this function is NOT written with performance in mind, good for small to mid size dicts like we have in our test
88-
def rename_cuid_key_recursive(d):
89-
new_key = "<cuid>"
90-
for k in list(d.keys()):
91-
if len(k) == 25 and not k.isalpha(): # primitive check for cuid
92-
d[new_key] = d.pop(k)
93-
for k, v in d.items():
94-
if isinstance(v, dict):
95-
rename_cuid_key_recursive(v)
96-
elif isinstance(v, list):
97-
for i in v:
98-
if isinstance(i, dict):
99-
rename_cuid_key_recursive(i)
100-
101-
10272
def get_annotation_comparison_dicts_from_labels(labels):
10373
labels_ndjson = list(NDJsonConverter.serialize(labels))
10474
for annotation in labels_ndjson:
@@ -161,19 +131,6 @@ def get_annotation_comparison_dicts_from_export(export_result, data_row_id,
161131
return converted_annotations
162132

163133

164-
def create_data_row_for_project(project, dataset, data_row_ndjson, batch_name):
165-
data_row = dataset.create_data_row(data_row_ndjson)
166-
167-
project.create_batch(
168-
batch_name,
169-
[data_row.uid], # sample of data row objects
170-
5, # priority between 1(Highest) - 5(lowest)
171-
)
172-
project.data_row_ids.append(data_row.uid)
173-
174-
return data_row
175-
176-
177134
# TODO: Add VideoData. Currently label import job finishes without errors but project.export_labels() returns empty list.
178135
@pytest.mark.parametrize(
179136
"data_type_class",
@@ -190,15 +147,10 @@ def create_data_row_for_project(project, dataset, data_row_ndjson, batch_name):
190147
LlmResponseCreationData,
191148
],
192149
)
193-
def test_import_data_types(
194-
client,
195-
configured_project,
196-
initial_dataset,
197-
rand_gen,
198-
data_row_json_by_data_type,
199-
annotations_by_data_type,
200-
data_type_class,
201-
):
150+
def test_import_data_types(client, configured_project, initial_dataset,
151+
rand_gen, data_row_json_by_data_type,
152+
annotations_by_data_type, data_type_class,
153+
create_data_row_for_project):
202154
project = configured_project
203155
project_id = project.uid
204156
dataset = initial_dataset
@@ -241,6 +193,7 @@ def test_import_data_types_by_global_key(
241193
rand_gen,
242194
data_row_json_by_data_type,
243195
annotations_by_data_type,
196+
create_data_row_for_project,
244197
):
245198
project = configured_project
246199
project_id = project.uid
@@ -331,6 +284,8 @@ def test_import_data_types_v2(
331284
exports_v2_by_data_type,
332285
export_v2_test_helpers,
333286
rand_gen,
287+
helpers,
288+
create_data_row_for_project,
334289
):
335290
project = configured_project
336291
dataset = initial_dataset
@@ -381,9 +336,9 @@ def test_import_data_types_v2(
381336
exported_project_labels = exported_project["labels"][0]
382337
exported_annotations = exported_project_labels["annotations"]
383338

384-
remove_keys_recursive(exported_annotations,
385-
["feature_id", "feature_schema_id"])
386-
rename_cuid_key_recursive(exported_annotations)
339+
helpers.remove_keys_recursive(exported_annotations,
340+
["feature_id", "feature_schema_id"])
341+
helpers.rename_cuid_key_recursive(exported_annotations)
387342
assert exported_annotations == exports_v2_by_data_type[data_type_string]
388343

389344
data_row = client.get_data_row(data_row.uid)
@@ -400,6 +355,7 @@ def test_import_label_annotations(
400355
data_class,
401356
annotations,
402357
rand_gen,
358+
create_data_row_for_project,
403359
):
404360
project = configured_project_with_one_data_row
405361
dataset = initial_dataset

0 commit comments

Comments
 (0)