Skip to content

Commit 2eafbb3

Browse files
author
Matt Sokoloff
committed
separate mea to mal import class
1 parent ef3a6de commit 2eafbb3

File tree

4 files changed

+156
-109
lines changed

4 files changed

+156
-109
lines changed

labelbox/schema/annotation_import.py

Lines changed: 82 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,88 @@ def _create_mea_import_from_bytes(
318318
return cls(client, res["createModelErrorAnalysisPredictionImport"])
319319

320320

321+
class MEAToMALPredictionImport(AnnotationImport):
322+
project = Relationship.ToOne("Project", cache=True)
323+
324+
@property
325+
def parent_id(self) -> str:
326+
"""
327+
Identifier for this import. Used to refresh the status
328+
"""
329+
return self.project().uid
330+
331+
@classmethod
332+
def create_for_model_run_data_rows(cls, client: "labelbox.Client",
333+
model_run_id: str,
334+
data_row_ids: List[str], project_id: str,
335+
name: str) -> "MEAToMALPredictionImport":
336+
"""
337+
Create an MEA to MAL prediction import job from a list of data row ids of a specific model run
338+
339+
Args:
340+
client: Labelbox Client for executing queries
341+
data_row_ids: A list of data row ids
342+
model_run_id: model run id
343+
Returns:
344+
MEAToMALPredictionImport
345+
"""
346+
query_str = cls._get_model_run_data_rows_mutation()
347+
return cls(
348+
client,
349+
client.execute(query_str,
350+
params={
351+
"dataRowIds": data_row_ids,
352+
"modelRunId": model_run_id,
353+
"projectId": project_id,
354+
"name": name
355+
})["createMalPredictionImportForModelRunDataRows"])
356+
357+
@classmethod
358+
def from_name(cls,
359+
client: "labelbox.Client",
360+
project_id: str,
361+
name: str,
362+
as_json: bool = False) -> "MEAToMALPredictionImport":
363+
"""
364+
Retrieves an MEA to MAL import job.
365+
366+
Args:
367+
client: Labelbox Client for executing queries
368+
project_id: ID used for querying import jobs
369+
name: Name of the import job.
370+
Returns:
371+
MALPredictionImport
372+
"""
373+
query_str = """query getMEAToMALPredictionImportPyApi($projectId : ID!, $name: String!) {
374+
meaToMalPredictionImport(
375+
where: {projectId: $projectId, name: $name}){
376+
%s
377+
}}""" % query.results_query_part(cls)
378+
params = {
379+
"projectId": project_id,
380+
"name": name,
381+
}
382+
response = client.execute(query_str, params)
383+
if response is None:
384+
raise labelbox.exceptions.ResourceNotFoundError(
385+
MALPredictionImport, params)
386+
response = response["meaToMalPredictionImport"]
387+
if as_json:
388+
return response
389+
return cls(client, response)
390+
391+
@classmethod
392+
def _get_model_run_data_rows_mutation(cls) -> str:
393+
return """mutation createMalPredictionImportForModelRunDataRowsPyApi($dataRowIds: [ID!]!, $name: String!, $modelRunId: ID!, $projectId:ID!) {
394+
createMalPredictionImportForModelRunDataRows(data: {
395+
name: $name
396+
modelRunId: $modelRunId
397+
dataRowIds: $dataRowIds
398+
projectId: $projectId
399+
}) {%s}
400+
}""" % query.results_query_part(cls)
401+
402+
321403
class MALPredictionImport(AnnotationImport):
322404
project = Relationship.ToOne("Project", cache=True)
323405

@@ -401,32 +483,6 @@ def create_from_url(cls, client: "labelbox.Client", project_id: str,
401483
else:
402484
raise ValueError(f"Url {url} is not reachable")
403485

404-
@classmethod
405-
def create_for_model_run_data_rows(cls, client: "labelbox.Client",
406-
model_run_id: str,
407-
data_row_ids: List[str], project_id: str,
408-
name: str) -> "MALPredictionImport":
409-
"""
410-
Create an MAL prediction import job from a list of data row ids of a specific model run
411-
412-
Args:
413-
client: Labelbox Client for executing queries
414-
data_row_ids: A list of data row ids
415-
model_run_id: model run id
416-
Returns:
417-
MALPredictionImport
418-
"""
419-
query_str = cls._get_model_run_data_rows_mutation()
420-
return cls(
421-
client,
422-
client.execute(query_str,
423-
params={
424-
"dataRowIds": data_row_ids,
425-
"modelRunId": model_run_id,
426-
"projectId": project_id,
427-
"name": name
428-
})["createMalPredictionImportForModelRunDataRows"])
429-
430486
@classmethod
431487
def from_name(cls,
432488
client: "labelbox.Client",
@@ -471,17 +527,6 @@ def _get_url_mutation(cls) -> str:
471527
}) {%s}
472528
}""" % query.results_query_part(cls)
473529

474-
@classmethod
475-
def _get_model_run_data_rows_mutation(cls) -> str:
476-
return """mutation createMalPredictionImportForModelRunDataRowsPyApi($dataRowIds: [ID!]!, $name: String!, $modelRunId: ID!, $projectId:ID!) {
477-
createMalPredictionImportForModelRunDataRows(data: {
478-
name: $name
479-
modelRunId: $modelRunId
480-
dataRowIds: $dataRowIds
481-
projectId: $projectId
482-
}) {id importType inputFileUrl errorFileUrl project { id name } name statusFileUrl state progress}
483-
}"""
484-
485530
@classmethod
486531
def _get_file_mutation(cls) -> str:
487532
return """mutation createMALPredictionImportByFilePyApi($projectId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) {

labelbox/schema/model_run.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def upsert_predictions_and_send_to_project(
140140
project_id (str): id of the project to import into
141141
priority (int): priority of the job
142142
Returns:
143-
(AnnotationImport, Project)
143+
(MEAPredictionImport, Batch, MEAToMALPredictionImport)
144144
"""
145145
kwargs = dict(client=self.client, model_run_id=self.uid, name=name)
146146
project = self.client.get_project(project_id)
@@ -165,7 +165,7 @@ def upsert_predictions_and_send_to_project(
165165
try:
166166
batch = project.create_batch(name, mea_to_mal_data_rows, priority)
167167
try:
168-
mal_prediction_import = Entity.MALPredictionImport.create_for_model_run_data_rows(
168+
mal_prediction_import = Entity.MEAToMALPredictionImport.create_for_model_run_data_rows(
169169
data_row_ids=mea_to_mal_data_rows,
170170
project_id=project_id,
171171
**kwargs)
@@ -316,11 +316,11 @@ def update_status(self,
316316

317317
@experimental
318318
def update_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
319-
"""
319+
"""
320320
Updates the Model Run's training metadata config
321-
Args:
321+
Args:
322322
config (dict): A dictionary of keys and values
323-
Returns:
323+
Returns:
324324
Model Run id and updated training metadata
325325
"""
326326
data: Dict[str, Any] = {'config': config}
@@ -337,9 +337,9 @@ def update_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
337337

338338
@experimental
339339
def reset_config(self) -> Dict[str, Any]:
340-
"""
340+
"""
341341
Resets Model Run's training metadata config
342-
Returns:
342+
Returns:
343343
Model Run id and reset training metadata
344344
"""
345345
res = self.client.execute(
@@ -352,10 +352,10 @@ def reset_config(self) -> Dict[str, Any]:
352352

353353
@experimental
354354
def get_config(self) -> Dict[str, Any]:
355-
"""
356-
Gets Model Run's training metadata
357-
Returns:
358-
training metadata as a dictionary
355+
"""
356+
Gets Model Run's training metadata
357+
Returns:
358+
training metadata as a dictionary
359359
"""
360360
res = self.client.execute("""query ModelRunPyApi($modelRunId: ID!){
361361
modelRun(where: {id : $modelRunId}){trainingMetadata}

tests/integration/annotation_import/conftest.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,17 +158,15 @@ def configured_project_pdf(client, ontology, rand_gen, pdf_url):
158158

159159

160160
@pytest.fixture
161-
def configured_project_without_data_rows(client, ontology, rand_gen):
161+
def configured_project_without_data_rows(client, configured_project, rand_gen):
162162
project = client.create_project(name=rand_gen(str))
163-
dataset = client.create_dataset(name=rand_gen(str))
164163
editor = list(
165164
client.get_labeling_frontends(
166165
where=LabelingFrontend.name == "editor"))[0]
167-
project.setup(editor, ontology)
166+
project.setup_editor(configured_project.ontology())
168167
project.update(queue_mode=project.QueueMode.Batch)
169168
yield project
170169
project.delete()
171-
dataset.delete()
172170

173171

174172
@pytest.fixture
@@ -436,6 +434,7 @@ def model_run_with_model_run_data_rows(client, configured_project,
436434
model_run.upsert_labels(label_ids)
437435
time.sleep(3)
438436
yield model_run
437+
model_run.delete()
439438
# TODO: Delete resources when that is possible ..
440439

441440

tests/integration/annotation_import/test_upsert_prediction_import.py

Lines changed: 60 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,19 @@ def test_create_from_url(client, tmp_path, object_predictions,
1717
name = str(uuid.uuid4())
1818
file_name = f"{name}.json"
1919
file_path = tmp_path / file_name
20+
21+
model_run_data_rows = [
22+
mrdr.data_row().uid
23+
for mrdr in model_run_with_model_run_data_rows.model_run_data_rows()
24+
]
25+
predictions = [
26+
p for p in object_predictions
27+
if p['dataRow']['id'] in model_run_data_rows
28+
]
2029
with file_path.open("w") as f:
21-
ndjson.dump(object_predictions, f)
30+
ndjson.dump(predictions, f)
31+
32+
# Needs to have data row ids
2233

2334
with open(file_path, "r") as f:
2435
url = client.upload_data(content=f.read(),
@@ -33,55 +44,74 @@ def test_create_from_url(client, tmp_path, object_predictions,
3344
priority=5)
3445

3546
assert annotation_import.model_run_id == model_run_with_model_run_data_rows.uid
36-
annotation_import_test_helpers.check_running_state(annotation_import, name,
37-
url)
3847
annotation_import.wait_until_done()
48+
assert not annotation_import.errors
49+
assert annotation_import.statuses
50+
51+
assert batch
52+
assert batch.project().uid == configured_project_without_data_rows.uid
53+
54+
assert mal_prediction_import
55+
mal_prediction_import.wait_until_done()
3956

40-
if batch:
41-
assert batch.project().uid == configured_project_without_data_rows.uid
42-
if mal_prediction_import:
43-
mal_prediction_import.wait_until_done()
57+
assert not mal_prediction_import.errors
58+
assert mal_prediction_import.statuses
4459

4560

4661
def test_create_from_objects(model_run_with_model_run_data_rows,
4762
configured_project_without_data_rows,
4863
object_predictions,
4964
annotation_import_test_helpers):
5065
name = str(uuid.uuid4())
51-
66+
model_run_data_rows = [
67+
mrdr.data_row().uid
68+
for mrdr in model_run_with_model_run_data_rows.model_run_data_rows()
69+
]
70+
predictions = [
71+
p for p in object_predictions
72+
if p['dataRow']['id'] in model_run_data_rows
73+
]
5274
annotation_import, batch, mal_prediction_import = model_run_with_model_run_data_rows.upsert_predictions_and_send_to_project(
5375
name=name,
54-
predictions=object_predictions,
76+
predictions=predictions,
5577
project_id=configured_project_without_data_rows.uid,
5678
priority=5)
5779

5880
assert annotation_import.model_run_id == model_run_with_model_run_data_rows.uid
59-
annotation_import_test_helpers.check_running_state(annotation_import, name)
60-
annotation_import_test_helpers.assert_file_content(
61-
annotation_import.input_file_url, object_predictions)
6281
annotation_import.wait_until_done()
82+
assert not annotation_import.errors
83+
assert annotation_import.statuses
6384

64-
if batch:
65-
assert batch.project().uid == configured_project_without_data_rows.uid
85+
assert batch
86+
assert batch.project().uid == configured_project_without_data_rows.uid
6687

67-
if mal_prediction_import:
68-
annotation_import_test_helpers.check_running_state(
69-
mal_prediction_import, name)
70-
mal_prediction_import.wait_until_done()
88+
assert mal_prediction_import
89+
mal_prediction_import.wait_until_done()
90+
91+
assert not mal_prediction_import.errors
92+
assert mal_prediction_import.statuses
7193

7294

7395
def test_create_from_local_file(tmp_path, model_run_with_model_run_data_rows,
7496
configured_project_without_data_rows,
7597
object_predictions,
7698
annotation_import_test_helpers):
99+
77100
name = str(uuid.uuid4())
78101
file_name = f"{name}.ndjson"
79102
file_path = tmp_path / file_name
80-
with file_path.open("w") as f:
81-
ndjson.dump(object_predictions, f)
82103

83-
annotation_import = model_run_with_model_run_data_rows.add_predictions(
84-
name=name, predictions=str(file_path))
104+
model_run_data_rows = [
105+
mrdr.data_row().uid
106+
for mrdr in model_run_with_model_run_data_rows.model_run_data_rows()
107+
]
108+
predictions = [
109+
p for p in object_predictions
110+
if p['dataRow']['id'] in model_run_data_rows
111+
]
112+
113+
with file_path.open("w") as f:
114+
ndjson.dump(predictions, f)
85115

86116
annotation_import, batch, mal_prediction_import = model_run_with_model_run_data_rows.upsert_predictions_and_send_to_project(
87117
name=name,
@@ -90,42 +120,15 @@ def test_create_from_local_file(tmp_path, model_run_with_model_run_data_rows,
90120
priority=5)
91121

92122
assert annotation_import.model_run_id == model_run_with_model_run_data_rows.uid
93-
annotation_import_test_helpers.check_running_state(annotation_import, name)
94-
annotation_import_test_helpers.assert_file_content(
95-
annotation_import.input_file_url, object_predictions)
96123
annotation_import.wait_until_done()
124+
assert not annotation_import.errors
125+
assert annotation_import.statuses
97126

98-
if batch:
99-
assert batch.project().uid == configured_project_without_data_rows.uid
100-
101-
if mal_prediction_import:
102-
annotation_import_test_helpers.check_running_state(
103-
mal_prediction_import, name)
104-
mal_prediction_import.wait_until_done()
127+
assert batch
128+
assert batch.project().uid == configured_project_without_data_rows.uid
105129

130+
assert mal_prediction_import
131+
mal_prediction_import.wait_until_done()
106132

107-
@pytest.mark.slow
108-
def test_wait_till_done(model_run_predictions,
109-
model_run_with_model_run_data_rows):
110-
name = str(uuid.uuid4())
111-
annotation_import = model_run_with_model_run_data_rows.add_predictions(
112-
name=name, predictions=model_run_predictions)
113-
114-
assert len(annotation_import.inputs) == len(model_run_predictions)
115-
annotation_import.wait_until_done()
116-
assert annotation_import.state == AnnotationImportState.FINISHED
117-
# Check that the status files are being returned as expected
118-
assert len(annotation_import.errors) == 0
119-
assert len(annotation_import.inputs) == len(model_run_predictions)
120-
input_uuids = [
121-
input_annot['uuid'] for input_annot in annotation_import.inputs
122-
]
123-
inference_uuids = [pred['uuid'] for pred in model_run_predictions]
124-
assert set(input_uuids) == set(inference_uuids)
125-
assert len(annotation_import.statuses) == len(model_run_predictions)
126-
for status in annotation_import.statuses:
127-
assert status['status'] == 'SUCCESS'
128-
status_uuids = [
129-
input_annot['uuid'] for input_annot in annotation_import.statuses
130-
]
131-
assert set(input_uuids) == set(status_uuids)
133+
assert not mal_prediction_import.errors
134+
assert mal_prediction_import.statuses

0 commit comments

Comments
 (0)