44from collections import namedtuple
55from datetime import datetime , timezone
66from pathlib import Path
7- from typing import TYPE_CHECKING , Dict , Union , Iterable , List , Optional , Any
7+ from typing import TYPE_CHECKING , Any , Dict , Iterable , List , Optional , Union
88from urllib .parse import urlparse
99
1010import ndjson
1111import requests
1212
1313from labelbox import utils
14- from labelbox .exceptions import InvalidQueryError , LabelboxError
14+ from labelbox .exceptions import (InvalidQueryError , LabelboxError ,
15+ ProcessingWaitTimeout )
1516from labelbox .orm import query
16- from labelbox .orm .db_object import DbObject , Updateable , Deletable
17+ from labelbox .orm .db_object import DbObject , Deletable , Updateable
1718from labelbox .orm .model import Entity , Field , Relationship
1819from labelbox .pagination import PaginatedCollection
1920from labelbox .schema .consensus_settings import ConsensusSettings
@@ -90,6 +91,9 @@ class Project(DbObject, Updateable, Deletable):
9091 benchmarks = Relationship .ToMany ("Benchmark" , False )
9192 ontology = Relationship .ToOne ("Ontology" , True )
9293
94+ # Upper bound, in seconds, on how long create_batch waits for data rows to finish processing
95+ _wait_processing_max_seconds = 3600
96+
9397 def update (self , ** kwargs ):
9498 """ Updates this project with the specified attributes
9599
@@ -319,7 +323,7 @@ def _validate_datetime(string_date: str) -> bool:
319323 return True
320324 except ValueError :
321325 pass
322- raise ValueError (f"""Incorrect format for: { string_date } .
326+ raise ValueError (f"""Incorrect format for: { string_date } .
323327 Format must be \" YYYY-MM-DD\" or \" YYYY-MM-DD hh:mm:ss\" """ )
324328 return True
325329
@@ -595,11 +599,16 @@ def create_batch(self,
595599 if not len (dr_ids ):
596600 raise ValueError ("You need at least one data row in a batch" )
597601
598- method = 'createBatch'
602+ self ._wait_until_data_rows_are_processed (
603+ data_rows , self ._wait_processing_max_seconds )
604+ method = 'createBatchV2'
599605 query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) {
600606 project(where: {id: $projectId}) {
601607 %s(input: $batchInput) {
602- %s
608+ batch {
609+ %s
610+ }
611+ failedDataRowIds
603612 }
604613 }
605614 }
@@ -622,9 +631,12 @@ def create_batch(self,
622631 params ,
623632 timeout = 180.0 ,
624633 experimental = True )["project" ][method ]
625-
626- res ['size' ] = len (dr_ids )
627- return Entity .Batch (self .client , self .uid , res )
634+ batch = res ['batch' ]
635+ batch ['size' ] = len (dr_ids )
636+ return Entity .Batch (self .client ,
637+ self .uid ,
638+ batch ,
639+ failed_data_row_ids = res ['failedDataRowIds' ])
628640
629641 def _update_queue_mode (self , mode : "QueueMode" ) -> "QueueMode" :
630642 """
@@ -977,6 +989,42 @@ def _is_url_valid(url: Union[str, Path]) -> bool:
977989 raise ValueError (
978990 f'Invalid annotations given of type: { type (annotations )} ' )
979991
def _wait_until_data_rows_are_processed(self,
                                        data_row_ids: List[str],
                                        wait_processing_max_seconds: int,
                                        sleep_interval=30):
    """ Blocks until all the specified data rows are processed.

    Polls the processing status of the data rows, sleeping between polls,
    until they are all reported as processed or the time limit is reached.
    The status is always polled at least once, and once more after the
    final sleep, so rows that finish during the last wait are not reported
    as timed out.

    Args:
        data_row_ids (List[str]): Ids of the data rows to wait for.
        wait_processing_max_seconds (int): Maximum number of seconds to wait.
        sleep_interval (int): Number of seconds to sleep between polls.
    Raises:
        ProcessingWaitTimeout: If the data rows are still being processed
            once `wait_processing_max_seconds` have elapsed.
    """
    # A monotonic clock keeps the elapsed-time computation immune to
    # system clock adjustments (NTP corrections, DST, manual changes).
    deadline = time.monotonic() + wait_processing_max_seconds
    while True:
        if self.__check_data_rows_have_been_processed(data_row_ids):
            return

        remaining_seconds = deadline - time.monotonic()
        if remaining_seconds <= 0:
            raise ProcessingWaitTimeout(
                "Maximum wait time exceeded while waiting for data rows to be processed. Try creating a batch a bit later"
            )

        logger.debug(
            'Some of the data rows are still being processed, waiting...')
        # Never sleep past the deadline, so the total wait stays bounded by
        # wait_processing_max_seconds (plus the duration of one status poll).
        time.sleep(min(sleep_interval, remaining_seconds))
1012+
def __check_data_rows_have_been_processed(self, data_row_ids: List[str]):
    """ Queries whether all of the given data rows have been processed.

    Args:
        data_row_ids (List[str]): Ids of the data rows to check.
    Returns:
        The `allDataRowsHaveBeenProcessed` value of the query response.
    """
    ids_variable = "data_row_ids"

    # The same variable name is substituted twice: once in the query
    # signature and once where it is passed to the field argument.
    query_str = """query CheckAllDataRowsHaveBeenProcessedPyApi($%s: [ID!]!) {
        queryAllDataRowsHaveBeenProcessed(dataRowIds:$%s) {
            allDataRowsHaveBeenProcessed
        }
    }""" % (ids_variable, ids_variable)

    result = self.client.execute(query_str, {ids_variable: data_row_ids})
    processed_info = result["queryAllDataRowsHaveBeenProcessed"]
    return processed_info["allDataRowsHaveBeenProcessed"]
1027+
9801028
9811029class ProjectMember (DbObject ):
9821030 user = Relationship .ToOne ("User" , cache = True )
0 commit comments