Skip to content

Commit 69c131d

Browse files
committed
Use new resolver
1 parent 5a9b34c commit 69c131d

File tree

3 files changed

+54
-12
lines changed

3 files changed

+54
-12
lines changed

labelbox/orm/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ class Entity(metaclass=EntityMeta):
371371
LabelingFrontendOptions: Type[labelbox.LabelingFrontendOptions]
372372
Label: Type[labelbox.Label]
373373
MEAPredictionImport: Type[labelbox.MEAPredictionImport]
374+
MALPredictionImport: Type[labelbox.MALPredictionImport]
374375
Invite: Type[labelbox.Invite]
375376
InviteLimit: Type[labelbox.InviteLimit]
376377
ProjectRole: Type[labelbox.ProjectRole]

labelbox/schema/annotation_import.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,32 @@ def create_from_url(cls, client: "labelbox.Client", project_id: str,
401401
else:
402402
raise ValueError(f"Url {url} is not reachable")
403403

404+
@classmethod
405+
def create_for_model_run_data_rows(cls, client: "labelbox.Client",
406+
model_run_id: str,
407+
data_row_ids: List[str], project_id: str,
408+
name: str) -> "MALPredictionImport":
409+
"""
410+
Create an MAL prediction import job from a list of data row ids of a specific model run
411+
412+
Args:
413+
client: Labelbox Client for executing queries
414+
data_row_ids: A list of data row ids
415+
model_run_id: model run id
416+
Returns:
417+
MALPredictionImport
418+
"""
419+
query_str = cls._get_model_run_data_rows_mutation()
420+
return cls(
421+
client,
422+
client.execute(query_str,
423+
params={
424+
"dataRowIds": data_row_ids,
425+
"modelRunId": model_run_id,
426+
"projectId": project_id,
427+
"name": name
428+
})["createMalPredictionImportForModelRunDataRows"])
429+
404430
@classmethod
405431
def from_name(cls,
406432
client: "labelbox.Client",
@@ -445,6 +471,17 @@ def _get_url_mutation(cls) -> str:
445471
}) {%s}
446472
}""" % query.results_query_part(cls)
447473

474+
@classmethod
475+
def _get_model_run_data_rows_mutation(cls) -> str:
476+
return """mutation createMalPredictionImportForModelRunDataRows($projectId : ID!, $dataRowIds: [ID!]!, $name: String!, $modelRunId: ID!) {
477+
createMalPredictionImportForModelRunDataRows(data: {
478+
name: $name
479+
modelRunId: $modelRunId
480+
dataRowIds: $dataRowIds
481+
projectId: $projectId
482+
}) {%s}
483+
}""" % query.results_query_part(cls)
484+
448485
@classmethod
449486
def _get_file_mutation(cls) -> str:
450487
return """mutation createMALPredictionImportByFilePyApi($projectId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) {

labelbox/schema/model_run.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from labelbox.orm.query import results_query_part
1313
from labelbox.orm.model import Field, Relationship, Entity
1414
from labelbox.orm.db_object import DbObject, experimental
15-
from labelbox.data.annotation_types import LabelList
1615

1716
if TYPE_CHECKING:
1817
from labelbox import MEAPredictionImport
@@ -128,11 +127,11 @@ def _wait_until_done(self, status_fn, timeout_seconds=120, sleep_time=5):
128127
def upsert_predictions_and_send_to_project(
129128
self,
130129
name: str,
131-
predictions: LabelList,
130+
predictions: Union[str, Path, Iterable[Dict]],
132131
project_id: str,
133132
priority: Optional[int] = 5,
134133
) -> 'MEAPredictionImport': # type: ignore
135-
""" Upload predictions and creates a batch import to project.
134+
""" Upload predictions and create a batch import to project.
136135
Args:
137136
name (str): name of the AnnotationImport job as well as the name of the batch import
138137
predictions (Iterable):
@@ -142,16 +141,21 @@ def upsert_predictions_and_send_to_project(
142141
Returns:
143142
(AnnotationImport, Project)
144143
"""
145-
data_rows_set = set(
146-
map(lambda x: x.data_row.id, predictions)[:DATAROWS_IMPORT_LIMIT])
147-
data_rows = list(data_rows_set)
144+
kwargs = dict(client=self.client, model_run_id=self.uid, name=name)
148145
project = self.client.get_project(project_id)
149-
batch = project.create_batch(name, data_rows, priority)
150-
151-
predictions_for_data_rows = filter(
152-
lambda x: x.data_row.id in data_rows_set, predictions)
153-
154-
return (self.add_predictions(name, predictions_for_data_rows), batch)
146+
import_job = self.add_predictions(name, predictions)
147+
prediction_statuses = import_job.statuses()
148+
mea_to_mal_data_rows_set = set([
149+
row['dataRow']['id']
150+
for row in prediction_statuses
151+
if row['status'] == 'SUCCESS'
152+
])
153+
mea_to_mal_data_rows = list(
154+
mea_to_mal_data_rows_set)[:DATAROWS_IMPORT_LIMIT]
155+
batch = project.create_batch(name, mea_to_mal_data_rows, priority)
156+
mal_prediction_import = Entity.MALPredictionImport.create_for_model_run_data_rows(
157+
data_row_ids=mea_to_mal_data_rows, project_id=project_id, **kwargs)
158+
return mea_prediction_import, batch, mal_prediction_import
155159

156160
def add_predictions(
157161
self,

0 commit comments

Comments
 (0)