Commit 7906f20

Author: Matt Sokoloff
Commit message: wip
1 parent b8ae50e, commit 7906f20

File tree: 4 files changed (+77, -90 lines)


labelbox/orm/model.py

Lines changed: 0 additions & 5 deletions
@@ -43,7 +43,6 @@ class Type(Enum):
         ID = auto()
         DateTime = auto()
         Json = auto()
-        List = auto()
 
     class EnumType:
 
@@ -91,10 +90,6 @@ def Enum(enum_cls: type, *args):
     def Json(*args):
         return Field(Field.Type.Json, *args)
 
-    @staticmethod
-    def List(*args):
-        return Field(Field.Type.List, *args)
-
     def __init__(self,
                  field_type: Union[Type, EnumType],
                  name,

labelbox/schema/data_row_metadata.py

Lines changed: 75 additions & 46 deletions
@@ -5,9 +5,10 @@
 from itertools import chain
 from typing import List, Optional, Dict, Union, Callable, Type
 
-from pydantic import BaseModel, conlist, constr
+from pydantic import BaseModel, conlist, constr, root_validator
 
 from labelbox.schema.ontology import SchemaId
+from labelbox.utils import camel_case
 
 
 class DataRowMetadataKind(Enum):
@@ -33,58 +34,70 @@ class DataRowMetadataSchema(BaseModel):
 # Constraints for metadata values
 Embedding: Type[List[float]] = conlist(float, min_items=128, max_items=128)
 DateTime: Type[datetime.datetime] = datetime.datetime  # must be in UTC
-String: Type[str] = constr(max_length=500)
+String: Type[str] = constr(max_length=1000000)  # 500)
 OptionId: Type[SchemaId] = SchemaId  # enum option
 
 DataRowMetadataValue = Union[Embedding, DateTime, String, OptionId]
 
 
+class CamelCaseMixin(BaseModel):
+
+    class Config:
+        allow_population_by_field_name = True
+        alias_generator = camel_case
+
+
 # Metadata base class
-class DataRowMetadataField(BaseModel):
+class DataRowMetadataField(CamelCaseMixin):
     schema_id: SchemaId
     value: DataRowMetadataValue
 
 
-class DataRowMetadata(BaseModel):
+class DataRowMetadata(CamelCaseMixin):
     data_row_id: str
     fields: List[DataRowMetadataField]
 
 
-class DeleteDataRowMetadata(BaseModel):
+class DeleteDataRowMetadata(CamelCaseMixin):
     data_row_id: str
     fields: List[SchemaId]
 
 
-class DataRowMetadataBatchResponseFields:
+class DataRowMetadataBatchResponseFields(CamelCaseMixin):
     schema_id: str
     value: DataRowMetadataValue
 
 
-class DataRowMetadataBatchResponse:
+class DataRowMetadataBatchResponse(CamelCaseMixin):
     data_row_id: str
     error: str
     fields: List[DataRowMetadataBatchResponseFields]
 
+    @root_validator(pre=True)
+    def handle_enums(cls, values):
+        breakpoint()
+        return values
+
 
 # --- Batch GraphQL Objects ---
 # Don't want to crowd the name space with internals
 
 
 # Bulk upsert values
-class _UpsertDataRowMetadataInput(BaseModel):
-    schemaId: str
+class _UpsertDataRowMetadataInput(CamelCaseMixin):
+    schema_id: str
     value: Union[str, List, dict]
 
 
 # Batch of upsert values for a datarow
-class _UpsertBatchDataRowMetadata(BaseModel):
-    dataRowId: str
+class _UpsertBatchDataRowMetadata(CamelCaseMixin):
+    data_row_id: str
     fields: List[_UpsertDataRowMetadataInput]
 
 
-class _DeleteBatchDataRowMetadata(BaseModel):
-    dataRowId: str
-    schemaIds: List[SchemaId]
+class _DeleteBatchDataRowMetadata(CamelCaseMixin):
+    data_row_id: str
+    schema_ids: List[SchemaId]
 
 
 _BatchInputs = Union[List[_UpsertBatchDataRowMetadata],
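
Note on the new CamelCaseMixin: the idea is that models keep snake_case field names in Python
while serializing to (and accepting) the camelCase keys the GraphQL API uses. A minimal sketch
of that pattern using the pydantic v1 API, with a local stand-in for labelbox.utils.camel_case
rather than the repo code:

    from pydantic import BaseModel

    def camel_case(s: str) -> str:
        # stand-in for labelbox.utils.camel_case: "data_row_id" -> "dataRowId"
        first, *rest = s.split("_")
        return first + "".join(word.capitalize() for word in rest)

    class CamelCaseMixin(BaseModel):
        class Config:
            allow_population_by_field_name = True  # accept snake_case keyword args
            alias_generator = camel_case           # expose camelCase aliases

    class DeleteDataRowMetadata(CamelCaseMixin):   # simplified stand-in for the model above
        data_row_id: str
        fields: list

    d = DeleteDataRowMetadata(data_row_id="abc", fields=["schema-id"])
    assert d.dict(by_alias=True) == {"dataRowId": "abc", "fields": ["schema-id"]}
    # camelCase payloads coming back from the API populate the model directly:
    assert DeleteDataRowMetadata(**{"dataRowId": "abc", "fields": []}).data_row_id == "abc"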
@@ -99,7 +112,7 @@ class DataRowMetadataOntology:
     reserved and custom. Reserved fields are defined by Labelbox and used for creating
     specific experiences in the platform.
 
-    >>> mdo = client.get_datarow_metadata_ontology()
+    >>> mdo = client.get_data_row_metadata_ontology()
 
     """
 
@@ -131,14 +144,14 @@ def __init__(self, client):
         ]
         self.custom_id_index: Dict[SchemaId,
                                    DataRowMetadataSchema] = self._make_id_index(
-                                       self.custom_fields)
+            self.custom_fields)
         self.custom_name_index: Dict[str, DataRowMetadataSchema] = {
             f.name: f for f in self.custom_fields
         }
 
     @staticmethod
     def _make_id_index(
-            fields: List[DataRowMetadataSchema]
+        fields: List[DataRowMetadataSchema]
     ) -> Dict[SchemaId, DataRowMetadataSchema]:
         index = {}
         for f in fields:
@@ -187,9 +200,9 @@ def _parse_ontology(self):
         return fields
 
     def parse_metadata(
-            self, unparsed: List[Dict[str,
-                                      List[Union[str,
-                                                 Dict]]]]) -> List[DataRowMetadata]:
+        self, unparsed: List[Dict[str,
+                                  List[Union[str,
+                                             Dict]]]]) -> List[DataRowMetadata]:
         """ Parse metadata responses
 
         >>> mdo.parse_metadata([datarow.metadata])
@@ -201,7 +214,6 @@ def parse_metadata(
             metadata: List of `DataRowMetadata`
 
         """
-
         parsed = []
         for dr in unparsed:
             fields = []
@@ -219,14 +231,14 @@ def parse_metadata(
                 fields.append(field)
             parsed.append(
                 DataRowMetadata(data_row_id=dr["data_row_id"], fields=fields))
-
         return parsed
 
     def bulk_upsert(
             self, metadata: List[DataRowMetadata]
     ) -> List[DataRowMetadataBatchResponse]:
         """Upsert datarow metadata
 
+
         >>> metadata = DataRowMetadata(
         >>>     data_row_id="datarow-id",
         >>>     fields=[
@@ -240,14 +252,26 @@ def bulk_upsert(
             metadata: List of DataRow Metadata
 
         Returns:
-            response: []
+            response: List of response objects containing the status of the upsert, the data row id,
+
+
+            class DataRowMetadataBatchResponseFields:
+                schema_id: str
+                value: DataRowMetadataValue
+
+
+            class DataRowMetadataBatchResponse:
+                data_row_id: str
+                error: str
+                fields: List[DataRowMetadataBatchResponseFields]
+
         """
 
         if not (len(metadata)):
             raise ValueError("Empty list passed")
 
         def _batch_upsert(
-                upserts: List[_UpsertBatchDataRowMetadata]
+            upserts: List[_UpsertBatchDataRowMetadata]
         ) -> List[DataRowMetadataBatchResponse]:
 
             query = """mutation UpsertDataRowMetadataBetaPyApi($metadata: [DataRowCustomMetadataBatchUpsertInput!]!) {
@@ -260,22 +284,28 @@ def _batch_upsert(
                 }
               }
             }"""
-
-            return self.client.execute(query, {"metadata": upserts})
+            res = self.client.execute(
+                query, {"metadata": upserts})['upsertDataRowCustomMetadata']
+            breakpoint()
+            return [DataRowMetadataBatchResponse(**r) for r in res]
 
         items = []
         for m in metadata:
             items.append(
                 _UpsertBatchDataRowMetadata(
-                    dataRowId=m.data_row_id,
+                    data_row_id=m.data_row_id,
                     fields=list(
                         chain.from_iterable(
-                            self._parse_upsert(m) for m in m.fields))).dict())
+                            self._parse_upsert(m) for m in m.fields))).dict(
+                                by_alias=True))
 
-        return _batch_operations(_batch_upsert, items, self._batch_size)
+        breakpoint()
+        res = _batch_operations(_batch_upsert, items, self._batch_size)
+        breakpoint()
+        return res
 
     def bulk_delete(
-            self, deletes: List[DeleteDataRowMetadata]
+        self, deletes: List[DeleteDataRowMetadata]
     ) -> List[DataRowMetadataBatchResponse]:
         """ Delete metadata from a datarow by specifiying the fields you want to remove
 
@@ -291,17 +321,18 @@ def bulk_delete(
 
 
         Args:
-            deletes:
+            deletes: Data row and schema ids to delete
 
         Returns:
+            list of
 
         """
 
         if not len(deletes):
             raise ValueError("Empty list passed")
 
         def _batch_delete(
-                deletes: List[_DeleteBatchDataRowMetadata]
+            deletes: List[_DeleteBatchDataRowMetadata]
         ) -> List[DataRowMetadataBatchResponse]:
             query = """mutation DeleteDataRowMetadataBetaPyApi($deletes: [DataRowCustomMetadataBatchDeleteInput!]!) {
                 deleteDataRowCustomMetadata(data: $deletes) {
@@ -314,13 +345,15 @@ def _batch_delete(
               }
             }
         """
-            return self.client.execute(query, {"deletes": deletes})
 
-        items = []
-        for m in deletes:
-            items.append(self._validate_delete(m))
+            res = self.client.execute(query, {"deletes": deletes})
+            breakpoint()
+            return res
 
-        return _batch_operations(_batch_delete, items, batch_size=self._batch_size)
+        items = [self._validate_delete(m) for m in deletes]
+        return _batch_operations(_batch_delete,
+                                 items,
+                                 batch_size=self._batch_size)
 
     def _parse_upsert(
             self, metadatum: DataRowMetadataField
@@ -378,17 +411,17 @@ def _batch_items(iterable, size):
 
 
 def _batch_operations(
-        batch_function: _BatchFunction,
-        items: List,
-        batch_size: int = 100,
+    batch_function: _BatchFunction,
+    items: List,
+    batch_size: int = 100,
 ):
     response = []
+
     for batch in _batch_items(items, batch_size):
         response += batch_function(batch)
         # TODO: understand this better
         # if len(response):
         #     raise ValueError(response)
-
     return response
 
 
@@ -410,11 +443,7 @@ def _validate_parse_datetime(field: DataRowMetadataField):
 def _validate_parse_text(field: DataRowMetadataField):
     if not isinstance(field.value, str):
         raise ValueError("Invalid value")
-
-    return [{
-        "schemaId": field.schema_id,
-        "value": field.value,
-    }]
+    return [field.dict(by_alias=True)]
 
 
 def _validate_enum_parse(schema: DataRowMetadataSchema,
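
Two of the edits in this file lean on the same alias machinery: _batch_upsert now serializes each
item with .dict(by_alias=True) so the mutation variables carry camelCase keys, and
_validate_parse_text collapses the hand-built {"schemaId": ..., "value": ...} dict into
[field.dict(by_alias=True)]. A rough illustration of why those are equivalent, using local
stand-in models (same camel_case helper and mixin shape as the sketch above, pydantic v1, not the
repo classes):

    from typing import List, Union
    from pydantic import BaseModel

    def camel_case(s: str) -> str:
        first, *rest = s.split("_")
        return first + "".join(word.capitalize() for word in rest)

    class CamelCaseMixin(BaseModel):
        class Config:
            allow_population_by_field_name = True
            alias_generator = camel_case

    class UpsertInput(CamelCaseMixin):      # stand-in for _UpsertDataRowMetadataInput
        schema_id: str
        value: Union[str, List, dict]

    class UpsertBatch(CamelCaseMixin):      # stand-in for _UpsertBatchDataRowMetadata
        data_row_id: str
        fields: List[UpsertInput]

    batch = UpsertBatch(
        data_row_id="datarow-id",
        fields=[UpsertInput(schema_id="schema-id", value="free text")])

    # by_alias=True restores the camelCase keys the GraphQL mutation expects,
    # which is why the explicit schemaId/dataRowId attribute names could be dropped
    assert batch.dict(by_alias=True) == {
        "dataRowId": "datarow-id",
        "fields": [{"schemaId": "schema-id", "value": "free text"}],
    }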

tests/integration/test_asset_metadata.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

tests/integration/test_data_row_metadata.py

Lines changed: 2 additions & 3 deletions
@@ -36,7 +36,7 @@ def big_dataset(dataset: Dataset):
 
 def make_metadata(dr_id) -> DataRowMetadata:
     embeddings = [0.0] * 128
-    msg = "my-message"
+    msg = "my-message" * 1000
     time = datetime.utcnow()
 
     metadata = DataRowMetadata(
@@ -166,8 +166,7 @@ def test_bulk_delete_datarow_enum_metadata(datarow: DataRow, mdo):
     assert len(datarow.metadata["fields"])
 
     mdo.bulk_delete([
-        DeleteDataRowMetadata(data_row_id=datarow.uid,
-                              fields=[SPLIT_SCHEMA_ID])
+        DeleteDataRowMetadata(data_row_id=datarow.uid, fields=[SPLIT_SCHEMA_ID])
     ])
     assert not len(datarow.metadata["fields"])
 
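
The test tweak above pairs with the String constraint change in data_row_metadata.py:
"my-message" * 1000 is 10,000 characters, which the old constr(max_length=500) would reject and
the new constr(max_length=1000000) accepts. A small standalone check of just that constraint
behavior (plain pydantic v1 models, not the SDK types):

    from pydantic import BaseModel, ValidationError, constr

    msg = "my-message" * 1000                # 10,000 characters

    class OldString(BaseModel):
        value: constr(max_length=500)        # limit before this commit

    class NewString(BaseModel):
        value: constr(max_length=1000000)    # limit after this commit

    try:
        OldString(value=msg)
    except ValidationError:
        print("rejected by the old 500-character limit")

    print(len(NewString(value=msg).value))   # 10000 -> accepted by the new limit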
