Skip to content

Commit c69da34

Browse files
committed
Test&Mutation Fixes
1 parent fcfb3c4 commit c69da34

File tree

4 files changed

+147
-7
lines changed

4 files changed

+147
-7
lines changed

labelbox/schema/annotation_import.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,14 +473,14 @@ def _get_url_mutation(cls) -> str:
473473

474474
@classmethod
475475
def _get_model_run_data_rows_mutation(cls) -> str:
476-
return """mutation createMalPredictionImportForModelRunDataRows($projectId : ID!, $dataRowIds: [ID!]!, $name: String!, $modelRunId: ID!) {
476+
return """mutation createMalPredictionImportForModelRunDataRowsPyApi($dataRowIds: [ID!]!, $name: String!, $modelRunId: ID!, $projectId:ID!) {
477477
createMalPredictionImportForModelRunDataRows(data: {
478478
name: $name
479479
modelRunId: $modelRunId
480480
dataRowIds: $dataRowIds
481481
projectId: $projectId
482-
}) {%s}
483-
}""" % query.results_query_part(cls)
482+
}) {id importType inputFileUrl errorFileUrl project { id name } name statusFileUrl state progress}
483+
}"""
484484

485485
@classmethod
486486
def _get_file_mutation(cls) -> str:

labelbox/schema/model_run.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,17 +144,22 @@ def upsert_predictions_and_send_to_project(
144144
kwargs = dict(client=self.client, model_run_id=self.uid, name=name)
145145
project = self.client.get_project(project_id)
146146
import_job = self.add_predictions(name, predictions)
147-
prediction_statuses = import_job.statuses()
147+
prediction_statuses = import_job.statuses
148148
mea_to_mal_data_rows_set = set([
149149
row['dataRow']['id']
150150
for row in prediction_statuses
151151
if row['status'] == 'SUCCESS'
152152
])
153153
mea_to_mal_data_rows = list(
154154
mea_to_mal_data_rows_set)[:DATAROWS_IMPORT_LIMIT]
155-
logger.warning(
156-
f"Got {len(mea_to_mal_data_rows_set)} data rows to import, trimmed down to {DATAROWS_IMPORT_LIMIT} data rows"
157-
)
155+
156+
if len(mea_to_mal_data_rows) >= DATAROWS_IMPORT_LIMIT:
157+
158+
logger.warning(
159+
f"Got {len(mea_to_mal_data_rows_set)} data rows to import, trimmed down to {DATAROWS_IMPORT_LIMIT} data rows"
160+
)
161+
if len(mea_to_mal_data_rows) == 0:
162+
return import_job, None, None
158163

159164
batch = project.create_batch(name, mea_to_mal_data_rows, priority)
160165
mal_prediction_import = Entity.MALPredictionImport.create_for_model_run_data_rows(

tests/integration/annotation_import/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,20 @@ def configured_project(client, ontology, rand_gen, image_url):
121121
dataset.delete()
122122

123123

124+
@pytest.fixture
125+
def configured_project_without_data_rows(client, ontology, rand_gen):
126+
project = client.create_project(name=rand_gen(str))
127+
dataset = client.create_dataset(name=rand_gen(str))
128+
editor = list(
129+
client.get_labeling_frontends(
130+
where=LabelingFrontend.name == "editor"))[0]
131+
project.setup(editor, ontology)
132+
project.update(queue_mode=project.QueueMode.Batch)
133+
yield project
134+
project.delete()
135+
dataset.delete()
136+
137+
124138
@pytest.fixture
125139
def prediction_id_mapping(configured_project):
126140
#Maps tool types to feature schema ids
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import uuid
2+
import ndjson
3+
import pytest
4+
5+
from labelbox.schema.annotation_import import AnnotationImportState, MEAPredictionImport
6+
"""
7+
- Here we only want to check that the uploads are calling the validation
8+
- Then with unit tests we can check the types of errors raised
9+
10+
"""
11+
12+
13+
def test_create_from_url(model_run_with_model_run_data_rows,
14+
configured_project_without_data_rows,
15+
annotation_import_test_helpers):
16+
name = str(uuid.uuid4())
17+
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
18+
19+
annotation_import, batch, mal_prediction_import = model_run_with_model_run_data_rows.upsert_predictions_and_send_to_project(
20+
name=name,
21+
predictions=url,
22+
project_id=configured_project_without_data_rows.uid,
23+
priority=5)
24+
25+
assert annotation_import.model_run_id == model_run_with_model_run_data_rows.uid
26+
annotation_import_test_helpers.check_running_state(annotation_import, name,
27+
url)
28+
annotation_import.wait_until_done()
29+
30+
if batch:
31+
assert batch.project().uid == configured_project_without_data_rows.uid
32+
if mal_prediction_import:
33+
mal_prediction_import.wait_until_done()
34+
35+
36+
def test_create_from_objects(model_run_with_model_run_data_rows,
37+
configured_project_without_data_rows,
38+
object_predictions,
39+
annotation_import_test_helpers):
40+
name = str(uuid.uuid4())
41+
42+
annotation_import, batch, mal_prediction_import = model_run_with_model_run_data_rows.upsert_predictions_and_send_to_project(
43+
name=name,
44+
predictions=object_predictions,
45+
project_id=configured_project_without_data_rows.uid,
46+
priority=5)
47+
48+
assert annotation_import.model_run_id == model_run_with_model_run_data_rows.uid
49+
annotation_import_test_helpers.check_running_state(annotation_import, name)
50+
annotation_import_test_helpers.assert_file_content(
51+
annotation_import.input_file_url, object_predictions)
52+
annotation_import.wait_until_done()
53+
54+
if batch:
55+
assert batch.project().uid == configured_project_without_data_rows.uid
56+
57+
if mal_prediction_import:
58+
annotation_import_test_helpers.check_running_state(
59+
mal_prediction_import, name)
60+
mal_prediction_import.wait_until_done()
61+
62+
63+
def test_create_from_local_file(tmp_path, model_run_with_model_run_data_rows,
64+
configured_project_without_data_rows,
65+
object_predictions,
66+
annotation_import_test_helpers):
67+
name = str(uuid.uuid4())
68+
file_name = f"{name}.ndjson"
69+
file_path = tmp_path / file_name
70+
with file_path.open("w") as f:
71+
ndjson.dump(object_predictions, f)
72+
73+
annotation_import = model_run_with_model_run_data_rows.add_predictions(
74+
name=name, predictions=str(file_path))
75+
76+
annotation_import, batch, mal_prediction_import = model_run_with_model_run_data_rows.upsert_predictions_and_send_to_project(
77+
name=name,
78+
predictions=str(file_path),
79+
project_id=configured_project_without_data_rows.uid,
80+
priority=5)
81+
82+
assert annotation_import.model_run_id == model_run_with_model_run_data_rows.uid
83+
annotation_import_test_helpers.check_running_state(annotation_import, name)
84+
annotation_import_test_helpers.assert_file_content(
85+
annotation_import.input_file_url, object_predictions)
86+
annotation_import.wait_until_done()
87+
88+
if batch:
89+
assert batch.project().uid == configured_project_without_data_rows.uid
90+
91+
if mal_prediction_import:
92+
annotation_import_test_helpers.check_running_state(
93+
mal_prediction_import, name)
94+
mal_prediction_import.wait_until_done()
95+
96+
97+
@pytest.mark.slow
98+
def test_wait_till_done(model_run_predictions,
99+
model_run_with_model_run_data_rows):
100+
name = str(uuid.uuid4())
101+
annotation_import = model_run_with_model_run_data_rows.add_predictions(
102+
name=name, predictions=model_run_predictions)
103+
104+
assert len(annotation_import.inputs) == len(model_run_predictions)
105+
annotation_import.wait_until_done()
106+
assert annotation_import.state == AnnotationImportState.FINISHED
107+
# Check that the status files are being returned as expected
108+
assert len(annotation_import.errors) == 0
109+
assert len(annotation_import.inputs) == len(model_run_predictions)
110+
input_uuids = [
111+
input_annot['uuid'] for input_annot in annotation_import.inputs
112+
]
113+
inference_uuids = [pred['uuid'] for pred in model_run_predictions]
114+
assert set(input_uuids) == set(inference_uuids)
115+
assert len(annotation_import.statuses) == len(model_run_predictions)
116+
for status in annotation_import.statuses:
117+
assert status['status'] == 'SUCCESS'
118+
status_uuids = [
119+
input_annot['uuid'] for input_annot in annotation_import.statuses
120+
]
121+
assert set(input_uuids) == set(status_uuids)

0 commit comments

Comments
 (0)