 import datetime
 from enum import Enum
 from itertools import chain
-from typing import List, Optional, Dict, Union, Callable, Type, Any
+from typing import List, Optional, Dict, Union, Callable, Type, Any, Generator
 
-from pydantic import BaseModel, conlist, constr, root_validator
+from pydantic import BaseModel, conlist, constr
 
 from labelbox.schema.ontology import SchemaId
 from labelbox.utils import camel_case
@@ -41,6 +41,7 @@ class DataRowMetadataSchema(BaseModel):
 
 
 class CamelCaseMixin(BaseModel):
+
     class Config:
         allow_population_by_field_name = True
         alias_generator = camel_case
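
Note on the hunk above: with pydantic v1 semantics, alias_generator = camel_case makes snake_case fields serialize under camelCase aliases for the GraphQL API, while allow_population_by_field_name = True still accepts the Python names as input. A minimal sketch of that behavior (the Example model is hypothetical, not part of this file):

    from pydantic import BaseModel
    from labelbox.utils import camel_case

    class Example(BaseModel):
        data_row_id: str

        class Config:
            allow_population_by_field_name = True
            alias_generator = camel_case

    # Serializes with the camelCase alias for the API...
    assert Example(data_row_id="abc").dict(by_alias=True) == {"dataRowId": "abc"}
    # ...while the snake_case field name still works for construction.
    assert Example(dataRowId="abc").data_row_id == "abc"
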
@@ -65,7 +66,7 @@ class DeleteDataRowMetadata(CamelCaseMixin):
 class DataRowMetadataBatchResponse(CamelCaseMixin):
     data_row_id: str
     error: str
-    fields: List[DataRowMetadataField]
+    fields: List[Union[DataRowMetadataField, SchemaId]]
 
 
 # --- Batch GraphQL Objects ---
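
Note: fields is widened to a Union because the two batch operations report failures differently. Upsert failures carry full field objects (schema id plus value), while delete failures carry bare schema ids (see _batch_delete further down, which maps each field to its schemaId). A hedged sketch of the two shapes, with placeholder ids:

    # Upsert failure: parsed DataRowMetadataField objects.
    DataRowMetadataBatchResponse(
        data_row_id="<data row id>",
        error="<reason>",
        fields=[DataRowMetadataField(schema_id="<schema id>", value="tag")])

    # Delete failure: bare SchemaIds, as _batch_delete constructs them.
    DataRowMetadataBatchResponse(
        data_row_id="<data row id>",
        error="<reason>",
        fields=["<schema id>"])
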
@@ -113,7 +114,7 @@ def __init__(self, client):
         self._raw_ontology = self._get_ontology()
         # all fields
         self.all_fields = self._parse_ontology()
-        self.all_fields_id_index = self._make_id_index(self.all_fields)
+        self.all_fields_id_index = self._make_id_index(self.all_fields)
         # reserved fields
         self.reserved_fields: List[DataRowMetadataSchema] = [
             f for f in self.all_fields if f.reserved
@@ -146,8 +147,16 @@ def _make_id_index(
     def _get_ontology(self) -> Dict[str, Any]:
         query = """query GetMetadataOntologyBetaPyApi {
             customMetadataOntology {
-                id name kind reserved
-                options { id kind name reserved}
+                id
+                name
+                kind
+                reserved
+                options {
+                    id
+                    kind
+                    name
+                    reserved
+                }
             }}
         """
         return self.client.execute(query)["customMetadataOntology"]
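
The reformatted query is semantically identical; GraphQL ignores the extra whitespace. For orientation, the customMetadataOntology payload selected above comes back roughly in this shape (field names follow the query; the values shown are illustrative), which is what _parse_ontology in the next hunk walks:

    # Approximate response shape consumed by _parse_ontology:
    # [{
    #     "id": "<schema id>",
    #     "name": "tag",
    #     "kind": "CustomMetadataString",
    #     "reserved": True,
    #     "options": [
    #         {"id": "<option id>", "kind": "...", "name": "...", "reserved": False},
    #     ],
    # }, ...]
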
@@ -165,7 +174,6 @@ def _parse_ontology(self) -> List[DataRowMetadataSchema]:
                     }
                 }) for option in schema["options"]
             ]
-
             schema["options"] = options
             fields.append(DataRowMetadataSchema(**schema))
 
@@ -221,10 +229,11 @@ def bulk_upsert(
         >>> mdo.bulk_upsert([metadata])
 
         Args:
-            metadata: List of DataRow Metadata
+            metadata: List of DataRow Metadata to upsert
 
         Returns:
-            response: List of response objects containing the status of the upsert, the data row id,
+            list of unsuccessful upserts.
+            An empty list means the upload was successful.
         """
 
         if not (len(metadata)):
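
With the new return contract, callers can treat the result as a failure list. A short usage sketch (mdo is a metadata ontology instance and metadata is built as in the doctest above):

    failures = mdo.bulk_upsert([metadata])
    if failures:
        # Each entry names the data row, the error, and the offending fields.
        for f in failures:
            print(f.data_row_id, f.error)
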
@@ -234,12 +243,23 @@ def _batch_upsert(
             upserts: List[_UpsertBatchDataRowMetadata]
         ) -> List[DataRowMetadataBatchResponse]:
             query = """mutation UpsertDataRowMetadataBetaPyApi($metadata: [DataRowCustomMetadataBatchUpsertInput!]!) {
-                upsertDataRowCustomMetadata(data: $metadata){
-                    dataRowId error fields { value schemaId}}
+                upsertDataRowCustomMetadata(data: $metadata){
+                    dataRowId
+                    error
+                    fields {
+                        value
+                        schemaId
+                    }
+                }
             }"""
             res = self.client.execute(
                 query, {"metadata": upserts})['upsertDataRowCustomMetadata']
-            return [DataRowMetadataBatchResponse(data_row_id=r['dataRowId'], error=r['error'], fields=self.parse_metadata([r])[0].fields) for r in res]
+            return [
+                DataRowMetadataBatchResponse(data_row_id=r['dataRowId'],
+                                             error=r['error'],
+                                             fields=self.parse_metadata(
+                                                 [r])[0].fields) for r in res
+            ]
 
         items = []
         for m in metadata:
@@ -256,7 +276,7 @@ def _batch_upsert(
 
     def bulk_delete(
             self, deletes: List[DeleteDataRowMetadata]
-    ) -> List:
+    ) -> List[DataRowMetadataBatchResponse]:
         """ Delete metadata from a data row by specifying the fields you want to remove
 
         >>> delete = DeleteDataRowMetadata(
@@ -269,12 +289,12 @@ def bulk_delete(
         >>> )
         >>> mdo.bulk_delete([delete])
 
-
         Args:
             deletes: Data row and schema ids to delete
 
         Returns:
-            list of
+            list of unsuccessful deletions.
+            An empty list means all data rows were successfully deleted.
 
         """
 
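
Same pattern as bulk_upsert: an empty list signals success. A hedged usage sketch (ids are placeholders):

    delete = DeleteDataRowMetadata(data_row_id="<data row id>",
                                   fields=["<schema id>"])
    failures = mdo.bulk_delete([delete])
    # Failed rows come back with the schema ids that could not be deleted.
    assert not failures, [(f.data_row_id, f.error) for f in failures]
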
@@ -283,13 +303,25 @@ def bulk_delete(
 
         def _batch_delete(
                 deletes: List[_DeleteBatchDataRowMetadata]
-        ) -> List:
+        ) -> List[DataRowMetadataBatchResponse]:
             query = """mutation DeleteDataRowMetadataBetaPyApi($deletes: [DataRowCustomMetadataBatchDeleteInput!]!) {
-                deleteDataRowCustomMetadata(data: $deletes) {
-                    dataRowId error fields { value schemaId } }}
+                deleteDataRowCustomMetadata(data: $deletes) {
+                    dataRowId
+                    error
+                    fields {
+                        value
+                        schemaId
+                    }
+                }
+            }
             """
-            return self.client.execute(query, {"deletes": deletes})['deleteDataRowCustomMetadata']
-
+            res = self.client.execute(
+                query, {"deletes": deletes})['deleteDataRowCustomMetadata']
+            failures = []
+            for dr in res:
+                dr['fields'] = [f['schemaId'] for f in dr['fields']]
+                failures.append(DataRowMetadataBatchResponse(**dr))
+            return failures
 
         items = [self._validate_delete(m) for m in deletes]
         return _batch_operations(_batch_delete,
@@ -339,12 +371,12 @@ def _validate_delete(self, delete: DeleteDataRowMetadata):
 
             deletes.add(schema.id)
 
-        return _DeleteBatchDataRowMetadata(dataRowId=delete.data_row_id,
-                                           schemaIds=list(
-                                               delete.fields)).dict(by_alias=True)
+        return _DeleteBatchDataRowMetadata(
+            dataRowId=delete.data_row_id,
+            schemaIds=list(delete.fields)).dict(by_alias=True)
 
 
-def _batch_items(iterable, size):
+def _batch_items(iterable: List[Any], size: int) -> Generator[Any, None, None]:
     l = len(iterable)
     for ndx in range(0, l, size):
         yield iterable[ndx:min(ndx + size, l)]
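
The annotated _batch_items still yields successive size-length slices, with the last slice possibly shorter. For example:

    assert list(_batch_items(list(range(5)), size=2)) == [[0, 1], [2, 3], [4]]
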
@@ -362,24 +394,29 @@ def _batch_operations(
     return response
 
 
-def _validate_parse_embedding(field: DataRowMetadataField):
+def _validate_parse_embedding(
+        field: DataRowMetadataField
+) -> List[Dict[str, Union[SchemaId, Embedding]]]:
     return [field.dict(by_alias=True)]
 
 
-def _validate_parse_datetime(field: DataRowMetadataField):
+def _validate_parse_datetime(
+        field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]:
     # TODO: better validate tzinfo
     return [{
         "schemaId": field.schema_id,
         "value": field.value.isoformat() + "Z",  # needs to be UTC
     }]
 
 
-def _validate_parse_text(field: DataRowMetadataField):
+def _validate_parse_text(
+        field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]:
     return [field.dict(by_alias=True)]
 
 
-def _validate_enum_parse(schema: DataRowMetadataSchema,
-                         field: DataRowMetadataField):
+def _validate_enum_parse(
+        schema: DataRowMetadataSchema,
+        field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, dict]]]:
     if schema.options:
         if field.value not in {o.id for o in schema.options}:
             raise ValueError(