1212from labelbox .orm .query import results_query_part
1313from labelbox .orm .model import Field , Relationship , Entity
1414from labelbox .orm .db_object import DbObject , experimental
15- from labelbox .data .annotation_types import LabelList
1615
1716if TYPE_CHECKING :
1817 from labelbox import MEAPredictionImport
@@ -128,11 +127,11 @@ def _wait_until_done(self, status_fn, timeout_seconds=120, sleep_time=5):
128127 def upsert_predictions_and_send_to_project (
129128 self ,
130129 name : str ,
131- predictions : LabelList ,
130+ predictions : Union [ str , Path , Iterable [ Dict ]] ,
132131 project_id : str ,
133132 priority : Optional [int ] = 5 ,
134133 ) -> 'MEAPredictionImport' : # type: ignore
135- """ Upload predictions and creates a batch import to project.
134+ """ Upload predictions and create a batch import to project.
136135 Args:
137136 name (str): name of the AnnotationImport job as well as the name of the batch import
138137 predictions (Iterable):
@@ -142,16 +141,21 @@ def upsert_predictions_and_send_to_project(
142141 Returns:
143142 (AnnotationImport, Project)
144143 """
145- data_rows_set = set (
146- map (lambda x : x .data_row .id , predictions )[:DATAROWS_IMPORT_LIMIT ])
147- data_rows = list (data_rows_set )
144+ kwargs = dict (client = self .client , model_run_id = self .uid , name = name )
148145 project = self .client .get_project (project_id )
149- batch = project .create_batch (name , data_rows , priority )
150-
151- predictions_for_data_rows = filter (
152- lambda x : x .data_row .id in data_rows_set , predictions )
153-
154- return (self .add_predictions (name , predictions_for_data_rows ), batch )
146+ import_job = self .add_predictions (name , predictions )
147+ prediction_statuses = import_job .statuses ()
148+ mea_to_mal_data_rows_set = set ([
149+ row ['dataRow' ]['id' ]
150+ for row in prediction_statuses
151+ if row ['status' ] == 'SUCCESS'
152+ ])
153+ mea_to_mal_data_rows = list (
154+ mea_to_mal_data_rows_set )[:DATAROWS_IMPORT_LIMIT ]
155+ batch = project .create_batch (name , mea_to_mal_data_rows , priority )
156+ mal_prediction_import = Entity .MALPredictionImport .create_for_model_run_data_rows (
157+ data_row_ids = mea_to_mal_data_rows , project_id = project_id , ** kwargs )
158+ return mea_prediction_import , batch , mal_prediction_import
155159
156160 def add_predictions (
157161 self ,
0 commit comments