 from itertools import chain
 from typing import List, Optional, Dict, Union, Callable, Type

-from pydantic import BaseModel, conlist, constr
+from pydantic import BaseModel, conlist, constr, root_validator

 from labelbox.schema.ontology import SchemaId
+from labelbox.utils import camel_case


 class DataRowMetadataKind(Enum):
@@ -33,58 +34,70 @@ class DataRowMetadataSchema(BaseModel):
 # Constraints for metadata values
 Embedding: Type[List[float]] = conlist(float, min_items=128, max_items=128)
 DateTime: Type[datetime.datetime] = datetime.datetime  # must be in UTC
-String: Type[str] = constr(max_length=500)
+String: Type[str] = constr(max_length=1000000)  # raised from 500
 OptionId: Type[SchemaId] = SchemaId  # enum option

 DataRowMetadataValue = Union[Embedding, DateTime, String, OptionId]


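+# A minimal sketch of how these constraints behave (illustrative; `_Probe`
+# is a hypothetical model, not part of this module):
+# >>> class _Probe(BaseModel):
+# ...     emb: Embedding
+# >>> _Probe(emb=[0.0] * 127)   # fewer than 128 floats -> pydantic.ValidationError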
+class CamelCaseMixin(BaseModel):
+
+    class Config:
+        allow_population_by_field_name = True
+        alias_generator = camel_case
+
+
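+# What the mixin buys us, sketched (illustrative; `Example` is hypothetical
+# and assumes labelbox.utils.camel_case maps snake_case names to camelCase):
+# >>> class Example(CamelCaseMixin):
+# ...     data_row_id: str
+# >>> Example(data_row_id="abc").dict(by_alias=True)
+# {'dataRowId': 'abc'}
+# >>> Example(dataRowId="abc").data_row_id   # aliases also accepted on input
+# 'abc'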
 # Metadata base class
-class DataRowMetadataField(BaseModel):
+class DataRowMetadataField(CamelCaseMixin):
     schema_id: SchemaId
     value: DataRowMetadataValue


-class DataRowMetadata(BaseModel):
+class DataRowMetadata(CamelCaseMixin):
     data_row_id: str
     fields: List[DataRowMetadataField]


-class DeleteDataRowMetadata(BaseModel):
+class DeleteDataRowMetadata(CamelCaseMixin):
     data_row_id: str
     fields: List[SchemaId]


-class DataRowMetadataBatchResponseFields:
+class DataRowMetadataBatchResponseFields(CamelCaseMixin):
     schema_id: str
     value: DataRowMetadataValue


-class DataRowMetadataBatchResponse:
+class DataRowMetadataBatchResponse(CamelCaseMixin):
     data_row_id: str
     error: str
     fields: List[DataRowMetadataBatchResponseFields]

+    @root_validator(pre=True)
+    def handle_enums(cls, values):
+        # TODO: pass-through for now; enum normalization still to be implemented
+        return values
+

 # --- Batch GraphQL Objects ---
 # Don't want to crowd the namespace with internals


 # Bulk upsert values
-class _UpsertDataRowMetadataInput(BaseModel):
-    schemaId: str
+class _UpsertDataRowMetadataInput(CamelCaseMixin):
+    schema_id: str
     value: Union[str, List, dict]


 # Batch of upsert values for a datarow
-class _UpsertBatchDataRowMetadata(BaseModel):
-    dataRowId: str
+class _UpsertBatchDataRowMetadata(CamelCaseMixin):
+    data_row_id: str
     fields: List[_UpsertDataRowMetadataInput]


-class _DeleteBatchDataRowMetadata(BaseModel):
-    dataRowId: str
-    schemaIds: List[SchemaId]
+class _DeleteBatchDataRowMetadata(CamelCaseMixin):
+    data_row_id: str
+    schema_ids: List[SchemaId]


 _BatchInputs = Union[List[_UpsertBatchDataRowMetadata],
@@ -99,7 +112,7 @@ class DataRowMetadataOntology:
     reserved and custom. Reserved fields are defined by Labelbox and used for creating
     specific experiences in the platform.

-    >>> mdo = client.get_datarow_metadata_ontology()
+    >>> mdo = client.get_data_row_metadata_ontology()

     """

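+    # A usage sketch (illustrative; the index attributes are built in
+    # __init__ below, and "my_field" is a hypothetical custom field name):
+    # >>> mdo = client.get_data_row_metadata_ontology()
+    # >>> schema = mdo.custom_name_index["my_field"]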
@@ -131,14 +144,14 @@ def __init__(self, client):
         ]
         self.custom_id_index: Dict[SchemaId,
                                    DataRowMetadataSchema] = self._make_id_index(
-            self.custom_fields)
+                                        self.custom_fields)
         self.custom_name_index: Dict[str, DataRowMetadataSchema] = {
             f.name: f for f in self.custom_fields
         }

     @staticmethod
     def _make_id_index(
-        fields: List[DataRowMetadataSchema]
+            fields: List[DataRowMetadataSchema]
     ) -> Dict[SchemaId, DataRowMetadataSchema]:
         index = {}
         for f in fields:
@@ -187,9 +200,9 @@ def _parse_ontology(self):
         return fields

     def parse_metadata(
-        self, unparsed: List[Dict[str,
-                                  List[Union[str,
-                                             Dict]]]]) -> List[DataRowMetadata]:
+            self, unparsed: List[Dict[str,
+                                      List[Union[str,
+                                                 Dict]]]]) -> List[DataRowMetadata]:
         """ Parse metadata responses

         >>> mdo.parse_metadata([datarow.metadata])
@@ -201,7 +214,6 @@ def parse_metadata(
             metadata: List of `DataRowMetadata`

         """
-
         parsed = []
         for dr in unparsed:
             fields = []
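+        # The raw payload is expected to look roughly like this (a sketch
+        # inferred from the loop below; the "fields" key is an assumption):
+        # [{"data_row_id": "<data row id>",
+        #   "fields": [{"schema_id": "<schema id>", "value": ...}, ...]}]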
@@ -219,14 +231,14 @@ def parse_metadata(
                 fields.append(field)
             parsed.append(
                 DataRowMetadata(data_row_id=dr["data_row_id"], fields=fields))
-
         return parsed

     def bulk_upsert(
             self, metadata: List[DataRowMetadata]
     ) -> List[DataRowMetadataBatchResponse]:
         """Upsert datarow metadata

+
         >>> metadata = DataRowMetadata(
         >>>     data_row_id="datarow-id",
         >>>     fields=[
@@ -240,14 +252,26 @@ def bulk_upsert(
             metadata: List of DataRow Metadata

         Returns:
-            response: []
+            response: List of response objects containing the data row id,
+                an error message if the upsert failed, and the upserted fields:
+
+
+            class DataRowMetadataBatchResponseFields:
+                schema_id: str
+                value: DataRowMetadataValue
+
+
+            class DataRowMetadataBatchResponse:
+                data_row_id: str
+                error: str
+                fields: List[DataRowMetadataBatchResponseFields]
+
         """

         if not (len(metadata)):
             raise ValueError("Empty list passed")

         def _batch_upsert(
-            upserts: List[_UpsertBatchDataRowMetadata]
+                upserts: List[_UpsertBatchDataRowMetadata]
         ) -> List[DataRowMetadataBatchResponse]:

             query = """mutation UpsertDataRowMetadataBetaPyApi($metadata: [DataRowCustomMetadataBatchUpsertInput!]!) {
@@ -260,22 +284,28 @@ def _batch_upsert(
                     }
                 }
             }"""
-
-            return self.client.execute(query, {"metadata": upserts})
+            res = self.client.execute(
+                query, {"metadata": upserts})['upsertDataRowCustomMetadata']
+            return [DataRowMetadataBatchResponse(**r) for r in res]

         items = []
         for m in metadata:
             items.append(
                 _UpsertBatchDataRowMetadata(
-                    dataRowId=m.data_row_id,
+                    data_row_id=m.data_row_id,
                     fields=list(
                         chain.from_iterable(
-                            self._parse_upsert(m) for m in m.fields))).dict())
+                            self._parse_upsert(m) for m in m.fields))).dict(
+                                by_alias=True))

-        return _batch_operations(_batch_upsert, items, self._batch_size)
+        res = _batch_operations(_batch_upsert, items, self._batch_size)
+        return res

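+        # Illustrative follow-up, not part of this module: collect any
+        # per-datarow failures from the batched responses.
+        # >>> responses = mdo.bulk_upsert(metadata_list)   # hypothetical list
+        # >>> failed = [r.data_row_id for r in responses if r.error]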
     def bulk_delete(
-        self, deletes: List[DeleteDataRowMetadata]
+            self, deletes: List[DeleteDataRowMetadata]
     ) -> List[DataRowMetadataBatchResponse]:
         """ Delete metadata from a datarow by specifying the fields you want to remove

@@ -291,17 +321,18 @@ def bulk_delete(


         Args:
-            deletes:
+            deletes: Data row and schema ids to delete

         Returns:
+            list of DataRowMetadataBatchResponse objects reporting the outcome
+            of each delete

         """

         if not len(deletes):
             raise ValueError("Empty list passed")

         def _batch_delete(
-            deletes: List[_DeleteBatchDataRowMetadata]
+                deletes: List[_DeleteBatchDataRowMetadata]
         ) -> List[DataRowMetadataBatchResponse]:
             query = """mutation DeleteDataRowMetadataBetaPyApi($deletes: [DataRowCustomMetadataBatchDeleteInput!]!) {
                 deleteDataRowCustomMetadata(data: $deletes) {
@@ -314,13 +345,15 @@ def _batch_delete(
                     }
                 }
             """
-            return self.client.execute(query, {"deletes": deletes})

-        items = []
-        for m in deletes:
-            items.append(self._validate_delete(m))
+            res = self.client.execute(query, {"deletes": deletes})
+            return res

-        return _batch_operations(_batch_delete, items, batch_size=self._batch_size)
+        items = [self._validate_delete(m) for m in deletes]
+        return _batch_operations(_batch_delete,
+                                 items,
+                                 batch_size=self._batch_size)
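+        # Illustrative usage (the schema id would come from the ontology
+        # indexes above):
+        # >>> mdo.bulk_delete([DeleteDataRowMetadata(data_row_id="dr-id",
+        # ...                                        fields=[schema_id])])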

     def _parse_upsert(
             self, metadatum: DataRowMetadataField
@@ -378,17 +411,17 @@ def _batch_items(iterable, size):


 def _batch_operations(
-    batch_function: _BatchFunction,
-    items: List,
-    batch_size: int = 100,
+        batch_function: _BatchFunction,
+        items: List,
+        batch_size: int = 100,
 ):
     response = []
+
     for batch in _batch_items(items, batch_size):
         response += batch_function(batch)
     # TODO: understand this better
     # if len(response):
     #     raise ValueError(response)
-
     return response


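+# For example (illustrative): with 250 items and batch_size=100, _batch_items
+# yields chunks of 100, 100 and 50, and the responses are concatenated.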
@@ -410,11 +443,7 @@ def _validate_parse_datetime(field: DataRowMetadataField):
 def _validate_parse_text(field: DataRowMetadataField):
     if not isinstance(field.value, str):
         raise ValueError("Invalid value")
-
-    return [{
-        "schemaId": field.schema_id,
-        "value": field.value,
-    }]
+    return [field.dict(by_alias=True)]


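+# e.g. (illustrative id): DataRowMetadataField(schema_id="ck...xyz",
+# value="hi").dict(by_alias=True) -> {"schemaId": "ck...xyz", "value": "hi"}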
 def _validate_enum_parse(schema: DataRowMetadataSchema,