@@ -666,16 +666,20 @@ def setup(self, labeling_frontend, labeling_frontend_options) -> None:
666666 timestamp = datetime .now (timezone .utc ).strftime ("%Y-%m-%dT%H:%M:%SZ" )
667667 self .update (setup_complete = timestamp )
668668
669- def create_batch (self ,
670- name : str ,
671- data_rows : List [Union [str , DataRow ]],
672- priority : int = 5 ,
673- consensus_settings : Optional [Dict [str , float ]] = None ):
674- """Create a new batch for a project. Batches is in Beta and subject to change
669+ def create_batch (
670+ self ,
671+ name : str ,
672+ data_rows : Optional [List [Union [str , DataRow ]]] = None ,
673+ priority : int = 5 ,
674+ consensus_settings : Optional [Dict [str , float ]] = None ,
675+ global_keys : Optional [List [str ]] = None ,
676+ ):
677+ """Create a new batch for a project. One of `global_keys` or `data_rows` must be provided but not both.
675678
676679 Args:
677680 name: a name for the batch, must be unique within a project
678- data_rows: Either a list of `DataRows` or Data Row ids
681+ data_rows: Either a list of `DataRows` or Data Row ids.
682+ global_keys: global keys for data rows to add to the batch.
679683 priority: An optional priority for the Data Rows in the Batch. 1 highest -> 5 lowest
680684 consensus_settings: An optional dictionary with consensus settings: {'number_of_labels': 3, 'coverage_percentage': 0.1}
681685 """
@@ -685,35 +689,45 @@ def create_batch(self,
685689 raise ValueError ("Project must be in batch mode" )
686690
687691 dr_ids = []
688- for dr in data_rows :
689- if isinstance (dr , Entity .DataRow ):
690- dr_ids .append (dr .uid )
691- elif isinstance (dr , str ):
692- dr_ids .append (dr )
693- else :
694- raise ValueError ("You can DataRow ids or DataRow objects" )
692+ if data_rows is not None :
693+ for dr in data_rows :
694+ if isinstance (dr , Entity .DataRow ):
695+ dr_ids .append (dr .uid )
696+ elif isinstance (dr , str ):
697+ dr_ids .append (dr )
698+ else :
699+ raise ValueError (
700+ "`data_rows` must be DataRow ids or DataRow objects" )
701+
702+ if data_rows is not None :
703+ row_count = len (data_rows )
704+ elif global_keys is not None :
705+ row_count = len (global_keys )
706+ else :
707+ row_count = 0
695708
696- if len ( dr_ids ) > 100_000 :
709+ if row_count > 100_000 :
697710 raise ValueError (
698711 f"Batch exceeds max size, break into smaller batches" )
699- if not len ( dr_ids ) :
712+ if not row_count :
700713 raise ValueError ("You need at least one data row in a batch" )
701714
702715 self ._wait_until_data_rows_are_processed (
703- dr_ids , self ._wait_processing_max_seconds )
716+ dr_ids , global_keys , self ._wait_processing_max_seconds )
704717
705718 if consensus_settings :
706719 consensus_settings = ConsensusSettings (** consensus_settings ).dict (
707720 by_alias = True )
708721
709722 if len (dr_ids ) >= 10_000 :
710- return self ._create_batch_async (name , dr_ids , priority ,
723+ return self ._create_batch_async (name , dr_ids , global_keys , priority ,
711724 consensus_settings )
712725 else :
713- return self ._create_batch_sync (name , dr_ids , priority ,
726+ return self ._create_batch_sync (name , dr_ids , global_keys , priority ,
714727 consensus_settings )
715728
716- def _create_batch_sync (self , name , dr_ids , priority , consensus_settings ):
729+ def _create_batch_sync (self , name , dr_ids , global_keys , priority ,
730+ consensus_settings ):
717731 method = 'createBatchV2'
718732 query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) {
719733 project(where: {id: $projectId}) {
@@ -731,6 +745,7 @@ def _create_batch_sync(self, name, dr_ids, priority, consensus_settings):
731745 "batchInput" : {
732746 "name" : name ,
733747 "dataRowIds" : dr_ids ,
748+ "globalKeys" : global_keys ,
734749 "priority" : priority ,
735750 "consensusSettings" : consensus_settings
736751 }
@@ -748,7 +763,8 @@ def _create_batch_sync(self, name, dr_ids, priority, consensus_settings):
748763
749764 def _create_batch_async (self ,
750765 name : str ,
751- dr_ids : List [str ],
766+ dr_ids : Optional [List [str ]] = None ,
767+ global_keys : Optional [List [str ]] = None ,
752768 priority : int = 5 ,
753769 consensus_settings : Optional [Dict [str ,
754770 float ]] = None ):
@@ -791,6 +807,7 @@ def _create_batch_async(self,
791807 "input" : {
792808 "batchId" : batch_id ,
793809 "dataRowIds" : dr_ids ,
810+ "globalKeys" : global_keys ,
794811 "priority" : priority ,
795812 }
796813 }
@@ -1257,38 +1274,50 @@ def _is_url_valid(url: Union[str, Path]) -> bool:
12571274 raise ValueError (
12581275 f'Invalid annotations given of type: { type (annotations )} ' )
12591276
def _wait_until_data_rows_are_processed(
        self,
        data_row_ids: Optional[List[str]] = None,
        global_keys: Optional[List[str]] = None,
        wait_processing_max_seconds: int = _wait_processing_max_seconds,
        sleep_interval: int = 30) -> None:
    """Block until all the specified data rows are processed.

    Polls the backend via ``__check_data_rows_have_been_processed`` until it
    reports that every requested data row has finished processing.

    Args:
        data_row_ids: Data row ids to wait for (used when non-empty).
        global_keys: Global keys identifying data rows to wait for; consulted
            by the check query when `data_row_ids` is empty or None.
        wait_processing_max_seconds: Total time budget, in seconds, before
            giving up.
        sleep_interval: Seconds to sleep between polls.

    Raises:
        ProcessingWaitTimeout: If the time budget is exhausted before all
            data rows are processed.
    """
    # Use a monotonic clock for the timeout: unlike the wall clock
    # (datetime.now()), it cannot jump backward/forward due to NTP or
    # DST adjustments, which could otherwise stretch or cut the wait.
    start_time = time.monotonic()

    while True:
        # Timeout is checked first, preserving the original semantics:
        # a zero/negative budget raises without polling.
        if time.monotonic() - start_time >= wait_processing_max_seconds:
            raise ProcessingWaitTimeout(
                "Maximum wait time exceeded while waiting for data rows to be processed. Try creating a batch a bit later"
            )

        all_good = self.__check_data_rows_have_been_processed(
            data_row_ids, global_keys)
        if all_good:
            return

        logger.debug(
            'Some of the data rows are still being processed, waiting...')
        time.sleep(sleep_interval)
12801301
def __check_data_rows_have_been_processed(
        self,
        data_row_ids: Optional[List[str]] = None,
        global_keys: Optional[List[str]] = None):
    """Ask the backend whether all given data rows are done processing.

    A non-empty `data_row_ids` list takes precedence; otherwise the query
    is issued against `global_keys` (a None value is sent as an empty list).
    Returns the server-reported boolean.
    """
    # Choose which identifier set to query by; the GraphQL argument name
    # doubles as the variable name in the query document below.
    use_ids = bool(data_row_ids)
    param_name = "dataRowIds" if use_ids else "globalKeys"
    values = data_row_ids if use_ids else (global_keys or [])
    params = {param_name: values}

    query_str = """query CheckAllDataRowsHaveBeenProcessedPyApi($%s: [ID!]) {
        queryAllDataRowsHaveBeenProcessed(%s:$%s) {
            allDataRowsHaveBeenProcessed
        }
    }""" % (param_name, param_name, param_name)

    response = self.client.execute(query_str, params)
    return response["queryAllDataRowsHaveBeenProcessed"][
        "allDataRowsHaveBeenProcessed"]
0 commit comments