1818from labelbox .orm .model import Entity , Field , Relationship
1919from labelbox .pagination import PaginatedCollection
2020from labelbox .schema .consensus_settings import ConsensusSettings
21+ from labelbox .schema .data_row import DataRow
2122from labelbox .schema .media_type import MediaType
2223from labelbox .schema .queue_mode import QueueMode
2324from labelbox .schema .resource_tag import ResourceTag
24- from labelbox .schema .data_row import DataRow
2525
2626if TYPE_CHECKING :
2727 from labelbox import BulkImportRequest
@@ -608,22 +608,31 @@ def create_batch(self,
608608
609609 self ._wait_until_data_rows_are_processed (
610610 dr_ids , self ._wait_processing_max_seconds )
611- method = 'createBatchV2'
612- query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) {
613- project(where: {id: $projectId}) {
614- %s(input: $batchInput) {
615- batch {
616- %s
617- }
618- failedDataRowIds
619- }
620- }
621- }
622- """ % (method , method , query .results_query_part (Entity .Batch ))
623611
624612 if consensus_settings :
625613 consensus_settings = ConsensusSettings (** consensus_settings ).dict (
626614 by_alias = True )
615+
616+ if len (dr_ids ) >= 10_000 :
617+ return self ._create_batch_async (name , dr_ids , priority ,
618+ consensus_settings )
619+ else :
620+ return self ._create_batch_sync (name , dr_ids , priority ,
621+ consensus_settings )
622+
623+ def _create_batch_sync (self , name , dr_ids , priority , consensus_settings ):
624+ method = 'createBatchV2'
625+ query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) {
626+ project(where: {id: $projectId}) {
627+ %s(input: $batchInput) {
628+ batch {
629+ %s
630+ }
631+ failedDataRowIds
632+ }
633+ }
634+ }
635+ """ % (method , method , query .results_query_part (Entity .Batch ))
627636 params = {
628637 "projectId" : self .uid ,
629638 "batchInput" : {
@@ -633,7 +642,6 @@ def create_batch(self,
633642 "consensusSettings" : consensus_settings
634643 }
635644 }
636-
637645 res = self .client .execute (query_str ,
638646 params ,
639647 timeout = 180.0 ,
@@ -645,6 +653,111 @@ def create_batch(self,
645653 batch ,
646654 failed_data_row_ids = res ['failedDataRowIds' ])
647655
def _create_batch_async(self,
                        name: str,
                        dr_ids: List[str],
                        priority: int = 5,
                        consensus_settings: Optional[Dict[str,
                                                          float]] = None):
    """Create a batch through the asynchronous batch endpoints.

    The work is done in three steps:

    1. ``createEmptyBatch`` creates a batch holding only ``name`` and
       ``consensus_settings``.
    2. ``addDataRowsToBatchAsync`` submits ``dr_ids`` and ``priority`` for
       that batch and returns a task id.
    3. The task is polled every 2 seconds until its status is
       ``COMPLETE``; the batch is then queried and returned.

    Args:
        name: Name sent as the ``name`` of the new batch.
        dr_ids: Data row ids sent as ``dataRowIds``.
        priority: Value sent as ``priority`` alongside the data row ids.
        consensus_settings: Value sent as ``consensusSettings`` when the
            empty batch is created; may be None.

    Returns:
        An ``Entity.Batch`` built from this project's client, this
        project's uid and the queried batch fields.

    Raises:
        LabelboxError: If the task reports a status other than
            ``COMPLETE`` or ``IN_PROGRESS``, if the polling budget
            (600 seconds of accumulated sleep time) is exhausted, or if
            the finished batch cannot be found in the project.
    """
    method = 'createEmptyBatch'
    create_empty_batch_mutation_str = """mutation %sPyApi($projectId: ID!, $input: CreateEmptyBatchInput!) {
          project(where: {id: $projectId}) {
            %s(input: $input) {
              id
            }
          }
        }
    """ % (method, method)

    params = {
        "projectId": self.uid,
        "input": {
            "name": name,
            "consensusSettings": consensus_settings
        }
    }

    res = self.client.execute(create_empty_batch_mutation_str,
                              params,
                              timeout=180.0,
                              experimental=True)["project"][method]
    batch_id = res['id']

    method = 'addDataRowsToBatchAsync'
    add_data_rows_mutation_str = """mutation %sPyApi($projectId: ID!, $input: AddDataRowsToBatchInput!) {
          project(where: {id: $projectId}) {
            %s(input: $input) {
              taskId
            }
          }
        }
    """ % (method, method)

    params = {
        "projectId": self.uid,
        "input": {
            "batchId": batch_id,
            "dataRowIds": dr_ids,
            "priority": priority,
        }
    }

    res = self.client.execute(add_data_rows_mutation_str,
                              params,
                              timeout=180.0,
                              experimental=True)["project"][method]

    task_id = res['taskId']

    # The budget counts only time spent sleeping between polls, not the
    # time spent in the status requests themselves.
    timeout_seconds = 600
    sleep_time = 2
    get_task_query_str = """query %s($taskId: ID!) {
          task(where: {id: $taskId}) {
            status
          }
        }
    """ % "getTaskPyApi"

    while True:
        task_status = self.client.execute(
            get_task_query_str, {'taskId': task_id},
            experimental=True)['task']['status']

        if task_status == "COMPLETE":
            break
        if task_status != "IN_PROGRESS":
            raise LabelboxError("Batch was not created successfully.")

        timeout_seconds -= sleep_time
        if timeout_seconds <= 0:
            raise LabelboxError(
                "Timed out while waiting for batch to be created.")
        # The project uid needs a %s placeholder; passing it without one
        # makes the logging module report a formatting error instead of
        # emitting the message.
        logger.debug("Project %s: creating batch, waiting for server...",
                     self.uid)
        time.sleep(sleep_time)

    # obtain batch entity to return
    get_batch_str = """query %s($projectId: ID!, $batchId: ID!) {
          project(where: {id: $projectId}) {
            batches(where: {id: $batchId}) {
              nodes {
                %s
              }
            }
          }
        }
    """ % ("getProjectBatchPyApi", query.results_query_part(Entity.Batch))

    nodes = self.client.execute(get_batch_str, {
        "projectId": self.uid,
        "batchId": batch_id
    },
                                timeout=180.0,
                                experimental=True)["project"]["batches"]["nodes"]
    if not nodes:
        # Surface a library error rather than a bare IndexError when the
        # lookup returns no batch.
        raise LabelboxError(
            "Batch %s was created but could not be retrieved." % batch_id)

    # TODO async endpoints currently do not provide failed_data_row_ids in response
    return Entity.Batch(self.client, self.uid, nodes[0])
760+
648761 def _update_queue_mode (self , mode : "QueueMode" ) -> "QueueMode" :
649762 """
650763 Updates the queueing mode of this project.
0 commit comments