@@ -669,16 +669,20 @@ def setup(self, labeling_frontend, labeling_frontend_options) -> None:
669669 timestamp = datetime .now (timezone .utc ).strftime ("%Y-%m-%dT%H:%M:%SZ" )
670670 self .update (setup_complete = timestamp )
671671
672- def create_batch (self ,
673- name : str ,
674- data_rows : List [Union [str , DataRow ]],
675- priority : int = 5 ,
676- consensus_settings : Optional [Dict [str , float ]] = None ):
677- """Create a new batch for a project. Batches is in Beta and subject to change
672+ def create_batch (
673+ self ,
674+ name : str ,
675+ data_rows : Optional [List [Union [str , DataRow ]]] = None ,
676+ priority : int = 5 ,
677+ consensus_settings : Optional [Dict [str , float ]] = None ,
678+ global_keys : Optional [List [str ]] = None ,
679+ ):
680+ """Create a new batch for a project. One of `global_keys` or `data_rows` must be provided but not both.
678681
679682 Args:
680683 name: a name for the batch, must be unique within a project
681- data_rows: Either a list of `DataRows` or Data Row ids
684+ data_rows: Either a list of `DataRows` or Data Row ids.
685+ global_keys: global keys for data rows to add to the batch.
682686 priority: An optional priority for the Data Rows in the Batch. 1 highest -> 5 lowest
683687 consensus_settings: An optional dictionary with consensus settings: {'number_of_labels': 3, 'coverage_percentage': 0.1}
684688 """
@@ -688,35 +692,45 @@ def create_batch(self,
688692 raise ValueError ("Project must be in batch mode" )
689693
690694 dr_ids = []
691- for dr in data_rows :
692- if isinstance (dr , Entity .DataRow ):
693- dr_ids .append (dr .uid )
694- elif isinstance (dr , str ):
695- dr_ids .append (dr )
696- else :
697- raise ValueError ("You can DataRow ids or DataRow objects" )
695+ if data_rows is not None :
696+ for dr in data_rows :
697+ if isinstance (dr , Entity .DataRow ):
698+ dr_ids .append (dr .uid )
699+ elif isinstance (dr , str ):
700+ dr_ids .append (dr )
701+ else :
702+ raise ValueError (
703+ "`data_rows` must be DataRow ids or DataRow objects" )
704+
705+ if data_rows is not None :
706+ row_count = len (data_rows )
707+ elif global_keys is not None :
708+ row_count = len (global_keys )
709+ else :
710+ row_count = 0
698711
699- if len ( dr_ids ) > 100_000 :
712+ if row_count > 100_000 :
700713 raise ValueError (
701714 f"Batch exceeds max size, break into smaller batches" )
702- if not len ( dr_ids ) :
715+ if not row_count :
703716 raise ValueError ("You need at least one data row in a batch" )
704717
705718 self ._wait_until_data_rows_are_processed (
706- dr_ids , self ._wait_processing_max_seconds )
719+ dr_ids , global_keys , self ._wait_processing_max_seconds )
707720
708721 if consensus_settings :
709722 consensus_settings = ConsensusSettings (** consensus_settings ).dict (
710723 by_alias = True )
711724
712725 if len (dr_ids ) >= 10_000 :
713- return self ._create_batch_async (name , dr_ids , priority ,
726+ return self ._create_batch_async (name , dr_ids , global_keys , priority ,
714727 consensus_settings )
715728 else :
716- return self ._create_batch_sync (name , dr_ids , priority ,
729+ return self ._create_batch_sync (name , dr_ids , global_keys , priority ,
717730 consensus_settings )
718731
719- def _create_batch_sync (self , name , dr_ids , priority , consensus_settings ):
732+ def _create_batch_sync (self , name , dr_ids , global_keys , priority ,
733+ consensus_settings ):
720734 method = 'createBatchV2'
721735 query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) {
722736 project(where: {id: $projectId}) {
@@ -734,6 +748,7 @@ def _create_batch_sync(self, name, dr_ids, priority, consensus_settings):
734748 "batchInput" : {
735749 "name" : name ,
736750 "dataRowIds" : dr_ids ,
751+ "globalKeys" : global_keys ,
737752 "priority" : priority ,
738753 "consensusSettings" : consensus_settings
739754 }
@@ -751,7 +766,8 @@ def _create_batch_sync(self, name, dr_ids, priority, consensus_settings):
751766
752767 def _create_batch_async (self ,
753768 name : str ,
754- dr_ids : List [str ],
769+ dr_ids : Optional [List [str ]] = None ,
770+ global_keys : Optional [List [str ]] = None ,
755771 priority : int = 5 ,
756772 consensus_settings : Optional [Dict [str ,
757773 float ]] = None ):
@@ -794,6 +810,7 @@ def _create_batch_async(self,
794810 "input" : {
795811 "batchId" : batch_id ,
796812 "dataRowIds" : dr_ids ,
813+ "globalKeys" : global_keys ,
797814 "priority" : priority ,
798815 }
799816 }
@@ -1260,38 +1277,50 @@ def _is_url_valid(url: Union[str, Path]) -> bool:
12601277 raise ValueError (
12611278 f'Invalid annotations given of type: { type (annotations )} ' )
12621279
1263- def _wait_until_data_rows_are_processed (self ,
1264- data_row_ids : List [str ],
1265- wait_processing_max_seconds : int ,
1266- sleep_interval = 30 ):
1280+ def _wait_until_data_rows_are_processed (
1281+ self ,
1282+ data_row_ids : Optional [List [str ]] = None ,
1283+ global_keys : Optional [List [str ]] = None ,
1284+ wait_processing_max_seconds : int = _wait_processing_max_seconds ,
1285+ sleep_interval = 30 ):
12671286 """ Wait until all the specified data rows are processed"""
12681287 start_time = datetime .now ()
1288+
12691289 while True :
12701290 if (datetime .now () -
12711291 start_time ).total_seconds () >= wait_processing_max_seconds :
12721292 raise ProcessingWaitTimeout (
12731293 "Maximum wait time exceeded while waiting for data rows to be processed. Try creating a batch a bit later"
12741294 )
12751295
1276- all_good = self .__check_data_rows_have_been_processed (data_row_ids )
1296+ all_good = self .__check_data_rows_have_been_processed (
1297+ data_row_ids , global_keys )
12771298 if all_good :
12781299 return
12791300
12801301 logger .debug (
12811302 'Some of the data rows are still being processed, waiting...' )
12821303 time .sleep (sleep_interval )
12831304
1284- def __check_data_rows_have_been_processed (self , data_row_ids : List [str ]):
1285- data_row_ids_param = "data_row_ids"
1305+ def __check_data_rows_have_been_processed (
1306+ self ,
1307+ data_row_ids : Optional [List [str ]] = None ,
1308+ global_keys : Optional [List [str ]] = None ):
1309+
1310+ if data_row_ids is not None and len (data_row_ids ) > 0 :
1311+ param_name = "dataRowIds"
1312+ params = {param_name : data_row_ids }
1313+ else :
1314+ param_name = "globalKeys"
1315+ global_keys = global_keys if global_keys is not None else []
1316+ params = {param_name : global_keys }
12861317
1287- query_str = """query CheckAllDataRowsHaveBeenProcessedPyApi($%s: [ID!]! ) {
1288- queryAllDataRowsHaveBeenProcessed(dataRowIds :$%s) {
1318+ query_str = """query CheckAllDataRowsHaveBeenProcessedPyApi($%s: [ID!]) {
1319+ queryAllDataRowsHaveBeenProcessed(%s :$%s) {
12891320 allDataRowsHaveBeenProcessed
12901321 }
1291- }""" % (data_row_ids_param , data_row_ids_param )
1322+ }""" % (param_name , param_name , param_name )
12921323
1293- params = {}
1294- params [data_row_ids_param ] = data_row_ids
12951324 response = self .client .execute (query_str , params )
12961325 return response ["queryAllDataRowsHaveBeenProcessed" ][
12971326 "allDataRowsHaveBeenProcessed" ]
0 commit comments