44from collections import namedtuple
55from datetime import datetime , timezone
66from pathlib import Path
7- from typing import TYPE_CHECKING , Dict , Union , Iterable , List , Optional , Any
7+ from typing import TYPE_CHECKING , Any , Dict , Iterable , List , Optional , Union
88from urllib .parse import urlparse
99
1010import ndjson
1111import requests
12-
1312from labelbox import utils
1413from labelbox .exceptions import InvalidQueryError , LabelboxError
1514from labelbox .orm import query
16- from labelbox .orm .db_object import DbObject , Updateable , Deletable
15+ from labelbox .orm .db_object import DbObject , Deletable , Updateable
1716from labelbox .orm .model import Entity , Field , Relationship
1817from labelbox .pagination import PaginatedCollection
1918from labelbox .schema .media_type import MediaType
@@ -318,7 +317,7 @@ def _validate_datetime(string_date: str) -> bool:
318317 return True
319318 except ValueError :
320319 pass
321- raise ValueError (f"""Incorrect format for: { string_date } .
320+ raise ValueError (f"""Incorrect format for: { string_date } .
322321 Format must be \" YYYY-MM-DD\" or \" YYYY-MM-DD hh:mm:ss\" """ )
323322 return True
324323
@@ -561,7 +560,7 @@ def setup(self, labeling_frontend, labeling_frontend_options) -> None:
561560 timestamp = datetime .now (timezone .utc ).strftime ("%Y-%m-%dT%H:%M:%SZ" )
562561 self .update (setup_complete = timestamp )
563562
564- def create_batch (self , name : str , data_rows : List [str ], priority : int = 5 ):
563+ def create_batch (self , name : str , data_rows : List [str ], priority : int = 5 , wait_processing_max_seconds : int = 5 ):
565564 """Create a new batch for a project. Batches is in Beta and subject to change
566565
567566 Args:
@@ -590,11 +589,18 @@ def create_batch(self, name: str, data_rows: List[str], priority: int = 5):
590589 if not len (dr_ids ):
591590 raise ValueError ("You need at least one data row in a batch" )
592591
593- method = 'createBatch'
592+ self ._wait_until_data_rows_are_processed (
593+ data_rows ,
594+ wait_processing_max_seconds = wait_processing_max_seconds
595+ )
596+ method = 'createBatchV2'
594597 query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) {
595598 project(where: {id: $projectId}) {
596599 %s(input: $batchInput) {
597- %s
600+ batch{
601+ %s
602+ }
603+ failedDataRowIds
598604 }
599605 }
600606 }
@@ -613,9 +619,9 @@ def create_batch(self, name: str, data_rows: List[str], priority: int = 5):
613619 params ,
614620 timeout = 180.0 ,
615621 experimental = True )["project" ][method ]
616-
617- res ['size' ] = len (dr_ids )
618- return Entity .Batch (self .client , self .uid , res )
622+ batch = res [ 'batch' ]
623+ batch ['size' ] = len (dr_ids )
624+ return Entity .Batch (self .client , self .uid , batch , failed_data_row_ids = res [ 'failedDataRowIds' ] )
619625
620626 def _update_queue_mode (self , mode : "QueueMode" ) -> "QueueMode" :
621627 """
@@ -964,6 +970,34 @@ def _is_url_valid(url: Union[str, Path]) -> bool:
964970 raise ValueError (
965971 f'Invalid annotations given of type: { type (annotations )} ' )
966972
973+ def _wait_until_data_rows_are_processed (self , data_row_ids : List [str ], wait_processing_max_seconds : int , sleep_interval = 30 ):
974+ """ Wait until all the specified data rows are processed"""
975+ start_time = datetime .now ()
976+ while True :
977+ if (datetime .now () - start_time ).total_seconds () >= wait_processing_max_seconds :
978+ logger .warning (
979+ """Not all data rows have been processed, proceeding anyway""" )
980+ return
981+
982+ all_good = self .__check_data_rows_have_been_processed (data_row_ids )
983+ if all_good :
984+ return
985+ time .sleep (sleep_interval )
986+
987+ def __check_data_rows_have_been_processed (self , data_row_ids : List [str ]):
988+ data_row_ids_param = "data_row_ids"
989+
990+ query_str = """query CheckAllDataRowsHaveBeenProcessedPyApi($%s: [ID!]!) {
991+ queryAllDataRowsHaveBeenProcessed(dataRowIds:$%s) {
992+ allDataRowsHaveBeenProcessed
993+ }
994+ }""" % (data_row_ids_param , data_row_ids_param )
995+
996+ params = {}
997+ params [data_row_ids_param ] = data_row_ids
998+ response = self .client .execute (query_str , params )
999+ return response ["queryAllDataRowsHaveBeenProcessed" ]["allDataRowsHaveBeenProcessed" ]
1000+
9671001
9681002class ProjectMember (DbObject ):
9691003 user = Relationship .ToOne ("User" , cache = True )
0 commit comments