@@ -1,4 +1,5 @@
-from typing import Dict, Generator, List, Optional, Union, Any
+from datetime import datetime
+from typing import Dict, Generator, List, Optional, Any, Final
 import os
 import json
 import logging
@@ -14,18 +15,19 @@
 from io import StringIO
 import requests
 
-from labelbox import pagination
 from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError
 from labelbox.orm.comparison import Comparison
 from labelbox.orm.db_object import DbObject, Updateable, Deletable, experimental
 from labelbox.orm.model import Entity, Field, Relationship
 from labelbox.orm import query
 from labelbox.exceptions import MalformedQueryException
 from labelbox.pagination import PaginatedCollection
+from labelbox.pydantic_compat import BaseModel
 from labelbox.schema.data_row import DataRow
 from labelbox.schema.export_filters import DatasetExportFilters, build_filters
 from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
 from labelbox.schema.export_task import ExportTask
+from labelbox.schema.identifiable import UniqueId, GlobalKey
 from labelbox.schema.task import Task
 from labelbox.schema.user import User
 
@@ -34,6 +36,11 @@
 MAX_DATAROW_PER_API_OPERATION = 150_000
 
 
+class DataRowUpsertItem(BaseModel):
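+    """Internal wrapper pairing an upsert key spec with a data row payload.
+
+    ``id`` holds the key spec built by ``_convert_items_to_upsert_format``
+    ({'type': 'ID' | 'GKEY' | 'AUTO', 'value': ...}); ``payload`` holds the
+    data row fields to create or update.
+    """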
+    id: dict
+    payload: dict
+
+
 class Dataset(DbObject, Updateable, Deletable):
     """ A Dataset is a collection of DataRows.
 
@@ -47,6 +54,8 @@ class Dataset(DbObject, Updateable, Deletable):
         created_by (Relationship): `ToOne` relationship to User
         organization (Relationship): `ToOne` relationship to Organization
     """
+    __upsert_chunk_size: Final = 10_000
+
     name = Field.String("name")
     description = Field.String("description")
     updated_at = Field.DateTime("updated_at")
@@ -64,16 +73,16 @@ def data_rows(
         from_cursor: Optional[str] = None,
         where: Optional[Comparison] = None,
     ) -> PaginatedCollection:
-        """
+        """
         Custom method to paginate data_rows via cursor.
 
         Args:
             from_cursor (str): Cursor (data row id) to start from, if none, will start from the beginning
-            where (dict(str,str)): Filter to apply to data rows. Where value is a data row column name and key is the value to filter on.
+            where (dict(str,str)): Filter to apply to data rows. The key is a data row column name and the value is the value to filter on.
                 example: {'external_id': 'my_external_id'} to get a data row with external_id = 'my_external_id'
 
 
-        NOTE:
+        NOTE:
             Order of retrieval is newest data row first.
             Deleted data rows are not retrieved.
             Failed data rows are not retrieved.
@@ -293,7 +302,10 @@ def create_data_rows(self, items) -> "Task":
         task._user = user
         return task
 
-    def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
+    def _create_descriptor_file(self,
+                                items,
+                                max_attachments_per_data_row=None,
+                                is_upsert=False):
         """
         This function is shared by both `Dataset.create_data_rows` and `Dataset.create_data_rows_sync`
         to prepare the input file. The user defined input is validated, processed, and json stringified.
@@ -346,6 +358,9 @@ def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
         AssetAttachment = Entity.AssetAttachment
 
         def upload_if_necessary(item):
+            if is_upsert and 'row_data' not in item:
+                # When upserting, row_data is not required
+                return item
             row_data = item['row_data']
             if isinstance(row_data, str) and os.path.exists(row_data):
                 item_url = self.client.upload_file(row_data)
@@ -425,17 +440,17 @@ def format_row(item):
             return item
 
         def validate_keys(item):
-            if 'row_data' not in item:
+            if not is_upsert and 'row_data' not in item:
                 raise InvalidQueryError(
                     "`row_data` missing when creating DataRow.")
 
             if isinstance(item.get('row_data'),
                           str) and item.get('row_data').startswith("s3:/"):
                 raise InvalidQueryError(
                     "row_data: s3 assets must start with 'https'.")
-            invalid_keys = set(item) - {
-                *{f.name for f in DataRow.fields()}, 'attachments', 'media_type'
-            }
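+            # 'dataset_id' is allowed here because upsert specs carry the
+            # current dataset's id (injected in _convert_items_to_upsert_format).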
+            allowed_extra_fields = {'attachments', 'media_type', 'dataset_id'}
+            invalid_keys = set(item) - {f.name for f in DataRow.fields()
+                                       } - allowed_extra_fields
             if invalid_keys:
                 raise InvalidAttributeError(DataRow, invalid_keys)
             return item
@@ -460,7 +475,12 @@ def formatLegacyConversationalData(item):
             item["row_data"] = one_conversation
             return item
 
-        def convert_item(item):
+        def convert_item(data_row_item):
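+            # data_row_item may be a plain dict or a DataRowUpsertItem; unwrap
+            # the payload so the validation below operates on the raw fields.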
+            if isinstance(data_row_item, DataRowUpsertItem):
+                item = data_row_item.payload
+            else:
+                item = data_row_item
+
             if "tileLayerUrl" in item:
                 validate_attachments(item)
                 return item
@@ -478,7 +498,11 @@ def convert_item(item):
             parse_metadata_fields(item)
             # Upload any local file paths
             item = upload_if_necessary(item)
-            return item
+
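+            # For upsert items, re-attach the key so each spec is returned as
+            # an {'id': ..., 'payload': ...} pair for the descriptor file.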
+            if isinstance(data_row_item, DataRowUpsertItem):
+                return {'id': data_row_item.id, 'payload': item}
+            else:
+                return item
 
         if not isinstance(items, Iterable):
             raise ValueError(
@@ -638,13 +662,13 @@ def export_v2(
     ) -> Task:
         """
         Creates a dataset export task with the given params and returns the task.
-
+
         >>> dataset = client.get_dataset(DATASET_ID)
         >>> task = dataset.export_v2(
         >>>     filters={
         >>>         "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
         >>>         "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
-        >>>         "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] # or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...]
+        >>>         "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] # or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...]
         >>>     },
         >>>     params={
         >>>         "performance_details": False,
@@ -749,3 +773,100 @@ def _export(
         res = res[mutation_name]
         task_id = res["taskId"]
         return Task.get_task(self.client, task_id)
+
+    def upsert_data_rows(self, items, file_upload_thread_count=20) -> "Task":
+        """
+        Upserts data rows in this dataset. When a "key" is provided and it references an existing data row,
+        an update will be performed. When a "key" is not provided, a new data row will be created.
+
+        >>> task = dataset.upsert_data_rows([
+        >>>     # create new data row
+        >>>     {
+        >>>         "row_data": "http://my_site.com/photos/img_01.jpg",
+        >>>         "global_key": "global_key1",
+        >>>         "external_id": "ex_id1",
+        >>>         "attachments": [
+        >>>             {"type": AttachmentType.RAW_TEXT, "name": "att1", "value": "test1"}
+        >>>         ],
+        >>>         "metadata": [
+        >>>             {"name": "tag", "value": "tag value"},
+        >>>         ]
+        >>>     },
+        >>>     # update global key of data row by existing global key
+        >>>     {
+        >>>         "key": GlobalKey("global_key1"),
+        >>>         "global_key": "global_key1_updated"
+        >>>     },
+        >>>     # update data row by ID
+        >>>     {
+        >>>         "key": UniqueId(dr.uid),
+        >>>         "external_id": "ex_id1_updated"
+        >>>     },
+        >>> ])
+        >>> task.wait_till_done()
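+
+        Args:
+            items (list of dict): Data row specs to create or update. Each dict
+                may include an optional "key" (UniqueId or GlobalKey) plus DataRow fields.
+            file_upload_thread_count (int): Number of threads used to upload
+                descriptor-file chunks in parallel.
+        Returns:
+            Task: The backend upsert task; call task.wait_till_done() to block until it finishes.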
807+ """
808+ if len (items ) > MAX_DATAROW_PER_API_OPERATION :
809+ raise MalformedQueryException (
810+ f"Cannot upsert more than { MAX_DATAROW_PER_API_OPERATION } DataRows per function call."
811+ )
812+
813+ specs = self ._convert_items_to_upsert_format (items )
814+ chunks = [
815+ specs [i :i + self .__upsert_chunk_size ]
816+ for i in range (0 , len (specs ), self .__upsert_chunk_size )
817+ ]
818+
819+ def _upload_chunk (_chunk ):
820+ return self ._create_descriptor_file (_chunk , is_upsert = True )
821+
822+ with ThreadPoolExecutor (file_upload_thread_count ) as executor :
823+ futures = [
824+ executor .submit (_upload_chunk , chunk ) for chunk in chunks
825+ ]
826+ chunk_uris = [future .result () for future in as_completed (futures )]
827+
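+        # The manifest references every uploaded chunk URI; it is uploaded as
+        # well and then handed to the upsert mutation below.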
+        manifest = {
+            "source": "SDK",
+            "item_count": len(specs),
+            "chunk_uris": chunk_uris
+        }
+        data = json.dumps(manifest).encode("utf-8")
+        manifest_uri = self.client.upload_data(data,
+                                               content_type="application/json",
+                                               filename="manifest.json")
+
+        query_str = """
+            mutation UpsertDataRowsPyApi($manifestUri: String!) {
+                upsertDataRows(data: { manifestUri: $manifestUri }) {
+                    id createdAt updatedAt name status completionPercentage result errors type metadata
+                }
+            }
+        """
+
+        res = self.client.execute(query_str, {"manifestUri": manifest_uri})
+        res = res["upsertDataRows"]
+        task = Task(self.client, res)
+        task._user = self.client.get_user()
+        return task
+
+    def _convert_items_to_upsert_format(self, _items):
+        _upsert_items: List[DataRowUpsertItem] = []
+        for item in _items:
+            # enforce current dataset's id for all specs
+            item['dataset_id'] = self.uid
+            key = item.pop('key', None)
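+            # Translate the optional key into the backend key spec:
+            # no key -> AUTO (create), UniqueId -> ID, GlobalKey -> GKEY.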
+            if not key:
+                key = {'type': 'AUTO', 'value': ''}
+            elif isinstance(key, UniqueId):
+                key = {'type': 'ID', 'value': key.key}
+            elif isinstance(key, GlobalKey):
+                key = {'type': 'GKEY', 'value': key.key}
+            else:
+                raise ValueError(
+                    f"Key must be an instance of UniqueId or GlobalKey, got: {type(key).__name__}"
+                )
+            item = {
+                k: v for k, v in item.items() if v is not None
+            }  # remove None values
+            _upsert_items.append(DataRowUpsertItem(payload=item, id=key))
+        return _upsert_items