Skip to content

Commit 3e44635

Browse files
author
Matt Sokoloff
committed
minor changes
1 parent 994b0da commit 3e44635

File tree

2 files changed

+66
-35
lines changed

2 files changed

+66
-35
lines changed

labelbox/schema/data_row_metadata.py

Lines changed: 66 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import datetime
44
from enum import Enum
55
from itertools import chain
6-
from typing import List, Optional, Dict, Union, Callable, Type, Any
6+
from typing import List, Optional, Dict, Union, Callable, Type, Any, Generator
77

8-
from pydantic import BaseModel, conlist, constr, root_validator
8+
from pydantic import BaseModel, conlist, constr
99

1010
from labelbox.schema.ontology import SchemaId
1111
from labelbox.utils import camel_case
@@ -41,6 +41,7 @@ class DataRowMetadataSchema(BaseModel):
4141

4242

4343
class CamelCaseMixin(BaseModel):
44+
4445
class Config:
4546
allow_population_by_field_name = True
4647
alias_generator = camel_case
@@ -65,7 +66,7 @@ class DeleteDataRowMetadata(CamelCaseMixin):
6566
class DataRowMetadataBatchResponse(CamelCaseMixin):
6667
data_row_id: str
6768
error: str
68-
fields: List[DataRowMetadataField]
69+
fields: List[Union[DataRowMetadataField, SchemaId]]
6970

7071

7172
# --- Batch GraphQL Objects ---
@@ -113,7 +114,7 @@ def __init__(self, client):
113114
self._raw_ontology = self._get_ontology()
114115
# all fields
115116
self.all_fields = self._parse_ontology()
116-
self.all_fields_id_index= self._make_id_index(self.all_fields)
117+
self.all_fields_id_index = self._make_id_index(self.all_fields)
117118
# reserved fields
118119
self.reserved_fields: List[DataRowMetadataSchema] = [
119120
f for f in self.all_fields if f.reserved
@@ -146,8 +147,16 @@ def _make_id_index(
146147
def _get_ontology(self) -> Dict[str, Any]:
147148
query = """query GetMetadataOntologyBetaPyApi {
148149
customMetadataOntology {
149-
id name kind reserved
150-
options { id kind name reserved}
150+
id
151+
name
152+
kind
153+
reserved
154+
options {
155+
id
156+
kind
157+
name
158+
reserved
159+
}
151160
}}
152161
"""
153162
return self.client.execute(query)["customMetadataOntology"]
@@ -165,7 +174,6 @@ def _parse_ontology(self) -> List[DataRowMetadataSchema]:
165174
}
166175
}) for option in schema["options"]
167176
]
168-
169177
schema["options"] = options
170178
fields.append(DataRowMetadataSchema(**schema))
171179

@@ -221,10 +229,11 @@ def bulk_upsert(
221229
>>> mdo.batch_upsert([metadata])
222230
223231
Args:
224-
metadata: List of DataRow Metadata
232+
metadata: List of DataRow Metadata to upsert
225233
226234
Returns:
227-
response: List of response objects containing the status of the upsert, the data row id,
235+
list of unsuccessful upserts.
236+
An empty list means the upload was successful.
228237
"""
229238

230239
if not (len(metadata)):
@@ -234,12 +243,23 @@ def _batch_upsert(
234243
upserts: List[_UpsertBatchDataRowMetadata]
235244
) -> List[DataRowMetadataBatchResponse]:
236245
query = """mutation UpsertDataRowMetadataBetaPyApi($metadata: [DataRowCustomMetadataBatchUpsertInput!]!) {
237-
upsertDataRowCustomMetadata(data: $metadata){
238-
dataRowId error fields { value schemaId}}
246+
upsertDataRowCustomMetadata(data: $metadata){
247+
dataRowId
248+
error
249+
fields {
250+
value
251+
schemaId
252+
}
253+
}
239254
}"""
240255
res = self.client.execute(
241256
query, {"metadata": upserts})['upsertDataRowCustomMetadata']
242-
return [DataRowMetadataBatchResponse(data_row_id = r['dataRowId'], error = r['error'], fields = self.parse_metadata([r])[0].fields ) for r in res]
257+
return [
258+
DataRowMetadataBatchResponse(data_row_id=r['dataRowId'],
259+
error=r['error'],
260+
fields=self.parse_metadata(
261+
[r])[0].fields) for r in res
262+
]
243263

244264
items = []
245265
for m in metadata:
@@ -256,7 +276,7 @@ def _batch_upsert(
256276

257277
def bulk_delete(
258278
self, deletes: List[DeleteDataRowMetadata]
259-
) -> List:
279+
) -> List[DataRowMetadataBatchResponse]:
260280
""" Delete metadata from a datarow by specifiying the fields you want to remove
261281
262282
>>> delete = DeleteDataRowMetadata(
@@ -269,12 +289,12 @@ def bulk_delete(
269289
>>> )
270290
>>> mdo.batch_delete([metadata])
271291
272-
273292
Args:
274293
deletes: Data row and schema ids to delete
275294
276295
Returns:
277-
list of
296+
list of unsuccessful deletions.
297+
An empty list means all data rows were successfully deleted.
278298
279299
"""
280300

@@ -283,13 +303,25 @@ def bulk_delete(
283303

284304
def _batch_delete(
285305
deletes: List[_DeleteBatchDataRowMetadata]
286-
) -> List:
306+
) -> List[DataRowMetadataBatchResponse]:
287307
query = """mutation DeleteDataRowMetadataBetaPyApi($deletes: [DataRowCustomMetadataBatchDeleteInput!]!) {
288-
deleteDataRowCustomMetadata(data: $deletes) {
289-
dataRowId error fields { value schemaId } }}
308+
deleteDataRowCustomMetadata(data: $deletes) {
309+
dataRowId
310+
error
311+
fields {
312+
value
313+
schemaId
314+
}
315+
}
316+
}
290317
"""
291-
return self.client.execute(query, {"deletes": deletes})['deleteDataRowCustomMetadata']
292-
318+
res = self.client.execute(
319+
query, {"deletes": deletes})['deleteDataRowCustomMetadata']
320+
failures = []
321+
for dr in res:
322+
dr['fields'] = [f['schemaId'] for f in dr['fields']]
323+
failures.append(DataRowMetadataBatchResponse(**dr))
324+
return failures
293325

294326
items = [self._validate_delete(m) for m in deletes]
295327
return _batch_operations(_batch_delete,
@@ -339,12 +371,12 @@ def _validate_delete(self, delete: DeleteDataRowMetadata):
339371

340372
deletes.add(schema.id)
341373

342-
return _DeleteBatchDataRowMetadata(dataRowId=delete.data_row_id,
343-
schemaIds=list(
344-
delete.fields)).dict(by_alias = True)
374+
return _DeleteBatchDataRowMetadata(
375+
dataRowId=delete.data_row_id,
376+
schemaIds=list(delete.fields)).dict(by_alias=True)
345377

346378

347-
def _batch_items(iterable, size):
379+
def _batch_items(iterable: List[Any], size: int) -> Generator[Any, None, None]:
348380
l = len(iterable)
349381
for ndx in range(0, l, size):
350382
yield iterable[ndx:min(ndx + size, l)]
@@ -362,24 +394,29 @@ def _batch_operations(
362394
return response
363395

364396

365-
def _validate_parse_embedding(field: DataRowMetadataField):
397+
def _validate_parse_embedding(
398+
field: DataRowMetadataField
399+
) -> List[Dict[str, Union[SchemaId, Embedding]]]:
366400
return [field.dict(by_alias=True)]
367401

368402

369-
def _validate_parse_datetime(field: DataRowMetadataField):
403+
def _validate_parse_datetime(
404+
field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]:
370405
# TODO: better validate tzinfo
371406
return [{
372407
"schemaId": field.schema_id,
373408
"value": field.value.isoformat() + "Z", # needs to be UTC
374409
}]
375410

376411

377-
def _validate_parse_text(field: DataRowMetadataField):
412+
def _validate_parse_text(
413+
field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]:
378414
return [field.dict(by_alias=True)]
379415

380416

381-
def _validate_enum_parse(schema: DataRowMetadataSchema,
382-
field: DataRowMetadataField):
417+
def _validate_enum_parse(
418+
schema: DataRowMetadataSchema,
419+
field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, dict]]]:
383420
if schema.options:
384421
if field.value not in {o.id for o in schema.options}:
385422
raise ValueError(

tests/integration/test_data_row_metadata.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def test_bulk_delete_datarow_metadata(datarow, mdo):
9898

9999
assert len(datarow.metadata["fields"])
100100

101-
102101
mdo.bulk_delete([
103102
DeleteDataRowMetadata(data_row_id=datarow.uid,
104103
fields=[m.schema_id for m in metadata.fields])
@@ -244,8 +243,3 @@ def test_parse_raw_metadata(mdo):
244243
row = parsed[0]
245244
assert row.data_row_id == example["dataRowId"]
246245
assert len(row.fields) == 3
247-
248-
249-
250-
251-

0 commit comments

Comments
 (0)