Skip to content

Commit fe3af35

Browse files
authored
Merge pull request #634 from Labelbox/mno/AL-2849
[AL-2849] Add upsert_predictions_and_send_to_project method
2 parents 937599b + 2eafbb3 commit fe3af35

File tree

5 files changed

+291
-9
lines changed

5 files changed

+291
-9
lines changed

labelbox/orm/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ class Entity(metaclass=EntityMeta):
371371
LabelingFrontendOptions: Type[labelbox.LabelingFrontendOptions]
372372
Label: Type[labelbox.Label]
373373
MEAPredictionImport: Type[labelbox.MEAPredictionImport]
374+
MALPredictionImport: Type[labelbox.MALPredictionImport]
374375
Invite: Type[labelbox.Invite]
375376
InviteLimit: Type[labelbox.InviteLimit]
376377
ProjectRole: Type[labelbox.ProjectRole]

labelbox/schema/annotation_import.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,88 @@ def _create_mea_import_from_bytes(
318318
return cls(client, res["createModelErrorAnalysisPredictionImport"])
319319

320320

321+
class MEAToMALPredictionImport(AnnotationImport):
    """Import job that copies a model run's (MEA) predictions into a project
    as MAL pre-labels."""
    project = Relationship.ToOne("Project", cache=True)

    @property
    def parent_id(self) -> str:
        """
        Identifier for this import. Used to refresh the status
        """
        return self.project().uid

    @classmethod
    def create_for_model_run_data_rows(cls, client: "labelbox.Client",
                                       model_run_id: str,
                                       data_row_ids: List[str], project_id: str,
                                       name: str) -> "MEAToMALPredictionImport":
        """
        Create an MEA to MAL prediction import job from a list of data row ids of a specific model run

        Args:
            client: Labelbox Client for executing queries
            model_run_id: model run id
            data_row_ids: A list of data row ids
            project_id: id of the project to import the predictions into
            name: name of the import job
        Returns:
            MEAToMALPredictionImport
        """
        query_str = cls._get_model_run_data_rows_mutation()
        return cls(
            client,
            client.execute(query_str,
                           params={
                               "dataRowIds": data_row_ids,
                               "modelRunId": model_run_id,
                               "projectId": project_id,
                               "name": name
                           })["createMalPredictionImportForModelRunDataRows"])

    @classmethod
    def from_name(cls,
                  client: "labelbox.Client",
                  project_id: str,
                  name: str,
                  as_json: bool = False) -> "MEAToMALPredictionImport":
        """
        Retrieves an MEA to MAL import job.

        Args:
            client: Labelbox Client for executing queries
            project_id: ID used for querying import jobs
            name: Name of the import job.
            as_json: if True, return the raw response dict instead of an instance
        Returns:
            MEAToMALPredictionImport
        Raises:
            labelbox.exceptions.ResourceNotFoundError: if no matching import job exists
        """
        query_str = """query getMEAToMALPredictionImportPyApi($projectId : ID!, $name: String!) {
            meaToMalPredictionImport(
                where: {projectId: $projectId, name: $name}){
                    %s
                }}""" % query.results_query_part(cls)
        params = {
            "projectId": project_id,
            "name": name,
        }
        response = client.execute(query_str, params)
        if response is None:
            # Report this class (not MALPredictionImport) as the missing resource.
            raise labelbox.exceptions.ResourceNotFoundError(
                MEAToMALPredictionImport, params)
        response = response["meaToMalPredictionImport"]
        if as_json:
            return response
        return cls(client, response)

    @classmethod
    def _get_model_run_data_rows_mutation(cls) -> str:
        # Mutation that creates the MAL prediction import from model-run data rows.
        return """mutation createMalPredictionImportForModelRunDataRowsPyApi($dataRowIds: [ID!]!, $name: String!, $modelRunId: ID!, $projectId:ID!) {
            createMalPredictionImportForModelRunDataRows(data: {
                name: $name
                modelRunId: $modelRunId
                dataRowIds: $dataRowIds
                projectId: $projectId
            }) {%s}
        }""" % query.results_query_part(cls)
321403
class MALPredictionImport(AnnotationImport):
322404
project = Relationship.ToOne("Project", cache=True)
323405

labelbox/schema/model_run.py

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
logger = logging.getLogger(__name__)
2020

21+
DATAROWS_IMPORT_LIMIT = 25000
22+
2123

2224
class DataSplit(Enum):
2325
TRAINING = "TRAINING"
@@ -123,6 +125,56 @@ def _wait_until_done(self, status_fn, timeout_seconds=120, sleep_time=5):
123125
)
124126
time.sleep(sleep_time)
125127

128+
def upsert_predictions_and_send_to_project(
129+
self,
130+
name: str,
131+
predictions: Union[str, Path, Iterable[Dict]],
132+
project_id: str,
133+
priority: Optional[int] = 5,
134+
) -> 'MEAPredictionImport': # type: ignore
135+
""" Upload predictions and create a batch import to project.
136+
Args:
137+
name (str): name of the AnnotationImport job as well as the name of the batch import
138+
predictions (Iterable):
139+
iterable of annotation rows
140+
project_id (str): id of the project to import into
141+
priority (int): priority of the job
142+
Returns:
143+
(MEAPredictionImport, Batch, MEAToMALPredictionImport)
144+
"""
145+
kwargs = dict(client=self.client, model_run_id=self.uid, name=name)
146+
project = self.client.get_project(project_id)
147+
import_job = self.add_predictions(name, predictions)
148+
prediction_statuses = import_job.statuses
149+
mea_to_mal_data_rows_set = set([
150+
row['dataRow']['id']
151+
for row in prediction_statuses
152+
if row['status'] == 'SUCCESS'
153+
])
154+
mea_to_mal_data_rows = list(
155+
mea_to_mal_data_rows_set)[:DATAROWS_IMPORT_LIMIT]
156+
157+
if len(mea_to_mal_data_rows) >= DATAROWS_IMPORT_LIMIT:
158+
159+
logger.warning(
160+
f"Got {len(mea_to_mal_data_rows_set)} data rows to import, trimmed down to {DATAROWS_IMPORT_LIMIT} data rows"
161+
)
162+
if len(mea_to_mal_data_rows) == 0:
163+
return import_job, None, None
164+
165+
try:
166+
batch = project.create_batch(name, mea_to_mal_data_rows, priority)
167+
try:
168+
mal_prediction_import = Entity.MEAToMALPredictionImport.create_for_model_run_data_rows(
169+
data_row_ids=mea_to_mal_data_rows,
170+
project_id=project_id,
171+
**kwargs)
172+
return import_job, batch, mal_prediction_import
173+
except:
174+
return import_job, batch, None
175+
except:
176+
return import_job, None, None
177+
126178
def add_predictions(
127179
self,
128180
name: str,
@@ -264,11 +316,11 @@ def update_status(self,
264316

265317
@experimental
266318
def update_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
267-
"""
319+
"""
268320
Updates the Model Run's training metadata config
269-
Args:
321+
Args:
270322
config (dict): A dictionary of keys and values
271-
Returns:
323+
Returns:
272324
Model Run id and updated training metadata
273325
"""
274326
data: Dict[str, Any] = {'config': config}
@@ -285,9 +337,9 @@ def update_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
285337

286338
@experimental
287339
def reset_config(self) -> Dict[str, Any]:
288-
"""
340+
"""
289341
Resets Model Run's training metadata config
290-
Returns:
342+
Returns:
291343
Model Run id and reset training metadata
292344
"""
293345
res = self.client.execute(
@@ -300,10 +352,10 @@ def reset_config(self) -> Dict[str, Any]:
300352

301353
@experimental
302354
def get_config(self) -> Dict[str, Any]:
303-
"""
304-
Gets Model Run's training metadata
305-
Returns:
306-
training metadata as a dictionary
355+
"""
356+
Gets Model Run's training metadata
357+
Returns:
358+
training metadata as a dictionary
307359
"""
308360
res = self.client.execute("""query ModelRunPyApi($modelRunId: ID!){
309361
modelRun(where: {id : $modelRunId}){trainingMetadata}

tests/integration/annotation_import/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,18 @@ def configured_project_pdf(client, ontology, rand_gen, pdf_url):
157157
dataset.delete()
158158

159159

160+
@pytest.fixture
def configured_project_without_data_rows(client, configured_project, rand_gen):
    """Batch-mode project sharing `configured_project`'s ontology but with no
    attached data rows; deleted on teardown."""
    project = client.create_project(name=rand_gen(str))
    # NOTE: the original fixture also fetched the "editor" labeling frontend
    # into an unused local; that dead API call is removed here.
    project.setup_editor(configured_project.ontology())
    project.update(queue_mode=project.QueueMode.Batch)
    yield project
    project.delete()
170+
171+
160172
@pytest.fixture
161173
def prediction_id_mapping(configured_project):
162174
#Maps tool types to feature schema ids
@@ -422,6 +434,7 @@ def model_run_with_model_run_data_rows(client, configured_project,
422434
model_run.upsert_labels(label_ids)
423435
time.sleep(3)
424436
yield model_run
437+
model_run.delete()
425438
# TODO: Delete resources when that is possible ..
426439

427440

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import uuid
2+
import ndjson
3+
import pytest
4+
5+
from labelbox.schema.annotation_import import AnnotationImportState, MEAPredictionImport
6+
"""
7+
- Here we only want to check that the uploads are calling the validation
8+
- Then with unit tests we can check the types of errors raised
9+
10+
"""
11+
12+
13+
def test_create_from_url(client, tmp_path, object_predictions,
                         model_run_with_model_run_data_rows,
                         configured_project_without_data_rows,
                         annotation_import_test_helpers):
    """Upsert predictions from a signed URL and send them to a project."""
    job_name = str(uuid.uuid4())
    ndjson_name = f"{job_name}.json"
    ndjson_path = tmp_path / ndjson_name

    model_run = model_run_with_model_run_data_rows
    project = configured_project_without_data_rows

    # Keep only predictions whose data rows belong to this model run.
    valid_row_ids = {
        mrdr.data_row().uid for mrdr in model_run.model_run_data_rows()
    }
    relevant_predictions = [
        pred for pred in object_predictions
        if pred['dataRow']['id'] in valid_row_ids
    ]
    with ndjson_path.open("w") as out:
        ndjson.dump(relevant_predictions, out)

    # The uploaded payload needs to carry data row ids.
    with open(ndjson_path, "r") as src:
        url = client.upload_data(content=src.read(),
                                 filename=ndjson_name,
                                 sign=True,
                                 content_type="application/json")

    annotation_import, batch, mal_prediction_import = (
        model_run.upsert_predictions_and_send_to_project(
            name=job_name,
            predictions=url,
            project_id=project.uid,
            priority=5))

    assert annotation_import.model_run_id == model_run.uid
    annotation_import.wait_until_done()
    assert not annotation_import.errors
    assert annotation_import.statuses

    assert batch
    assert batch.project().uid == project.uid

    assert mal_prediction_import
    mal_prediction_import.wait_until_done()
    assert not mal_prediction_import.errors
    assert mal_prediction_import.statuses
59+
60+
61+
def test_create_from_objects(model_run_with_model_run_data_rows,
                             configured_project_without_data_rows,
                             object_predictions,
                             annotation_import_test_helpers):
    """Upsert predictions passed as in-memory objects and send to a project."""
    job_name = str(uuid.uuid4())
    model_run = model_run_with_model_run_data_rows
    project = configured_project_without_data_rows

    # Keep only predictions whose data rows belong to this model run.
    valid_row_ids = {
        mrdr.data_row().uid for mrdr in model_run.model_run_data_rows()
    }
    relevant_predictions = [
        pred for pred in object_predictions
        if pred['dataRow']['id'] in valid_row_ids
    ]

    annotation_import, batch, mal_prediction_import = (
        model_run.upsert_predictions_and_send_to_project(
            name=job_name,
            predictions=relevant_predictions,
            project_id=project.uid,
            priority=5))

    assert annotation_import.model_run_id == model_run.uid
    annotation_import.wait_until_done()
    assert not annotation_import.errors
    assert annotation_import.statuses

    assert batch
    assert batch.project().uid == project.uid

    assert mal_prediction_import
    mal_prediction_import.wait_until_done()
    assert not mal_prediction_import.errors
    assert mal_prediction_import.statuses
93+
94+
95+
def test_create_from_local_file(tmp_path, model_run_with_model_run_data_rows,
                                configured_project_without_data_rows,
                                object_predictions,
                                annotation_import_test_helpers):
    """Upsert predictions from a local ndjson file and send to a project."""
    job_name = str(uuid.uuid4())
    ndjson_path = tmp_path / f"{job_name}.ndjson"

    model_run = model_run_with_model_run_data_rows
    project = configured_project_without_data_rows

    # Keep only predictions whose data rows belong to this model run.
    valid_row_ids = {
        mrdr.data_row().uid for mrdr in model_run.model_run_data_rows()
    }
    relevant_predictions = [
        pred for pred in object_predictions
        if pred['dataRow']['id'] in valid_row_ids
    ]
    with ndjson_path.open("w") as out:
        ndjson.dump(relevant_predictions, out)

    annotation_import, batch, mal_prediction_import = (
        model_run.upsert_predictions_and_send_to_project(
            name=job_name,
            predictions=str(ndjson_path),
            project_id=project.uid,
            priority=5))

    assert annotation_import.model_run_id == model_run.uid
    annotation_import.wait_until_done()
    assert not annotation_import.errors
    assert annotation_import.statuses

    assert batch
    assert batch.project().uid == project.uid

    assert mal_prediction_import
    mal_prediction_import.wait_until_done()
    assert not mal_prediction_import.errors
    assert mal_prediction_import.statuses

0 commit comments

Comments
 (0)