Skip to content

Commit 035f027

Browse files
Revert "Revert "[PLT-150] Add unified create method for AnnotationImp… (#1547)
1 parent d3700ca commit 035f027

File tree

5 files changed

+175
-25
lines changed

5 files changed

+175
-25
lines changed

labelbox/schema/annotation_import.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
import os
55
import time
6-
from typing import Any, BinaryIO, Dict, List, Union, TYPE_CHECKING, cast
6+
from typing import Any, BinaryIO, Dict, List, Optional, Union, TYPE_CHECKING, cast
77
from collections import defaultdict
88

99
from google.api_core import retry
@@ -241,7 +241,47 @@ def parent_id(self) -> str:
241241
raise NotImplementedError("Inheriting class must override")
242242

243243

244-
class MEAPredictionImport(AnnotationImport):
244+
class CreatableAnnotationImport(AnnotationImport):
245+
246+
@classmethod
247+
def create(
248+
cls,
249+
client: "labelbox.Client",
250+
id: str,
251+
name: str,
252+
path: Optional[str] = None,
253+
url: Optional[str] = None,
254+
labels: Union[List[Dict[str, Any]], List["Label"]] = []
255+
) -> "AnnotationImport":
256+
if (not is_exactly_one_set(url, labels, path)):
257+
raise ValueError(
258+
"Must pass in a nonempty argument for one and only one of the following arguments: url, path, predictions"
259+
)
260+
if url:
261+
return cls.create_from_url(client, id, name, url)
262+
if path:
263+
return cls.create_from_file(client, id, name, path)
264+
return cls.create_from_objects(client, id, name, labels)
265+
266+
@classmethod
267+
def create_from_url(cls, client: "labelbox.Client", id: str, name: str,
268+
url: str) -> "AnnotationImport":
269+
raise NotImplementedError("Inheriting class must override")
270+
271+
@classmethod
272+
def create_from_file(cls, client: "labelbox.Client", id: str, name: str,
273+
path: str) -> "AnnotationImport":
274+
raise NotImplementedError("Inheriting class must override")
275+
276+
@classmethod
277+
def create_from_objects(
278+
cls, client: "labelbox.Client", id: str, name: str,
279+
labels: Union[List[Dict[str, Any]],
280+
List["Label"]]) -> "AnnotationImport":
281+
raise NotImplementedError("Inheriting class must override")
282+
283+
284+
class MEAPredictionImport(CreatableAnnotationImport):
245285
model_run_id = Field.String("model_run_id")
246286

247287
@property
@@ -478,7 +518,7 @@ def _get_model_run_data_rows_mutation(cls) -> str:
478518
}""" % query.results_query_part(cls)
479519

480520

481-
class MALPredictionImport(AnnotationImport):
521+
class MALPredictionImport(CreatableAnnotationImport):
482522
project = Relationship.ToOne("Project", cache=True)
483523

484524
@property
@@ -638,7 +678,7 @@ def _create_mal_import_from_bytes(
638678
return cls(client, res["createModelAssistedLabelingPredictionImport"])
639679

640680

641-
class LabelImport(AnnotationImport):
681+
class LabelImport(CreatableAnnotationImport):
642682
project = Relationship.ToOne("Project", cache=True)
643683

644684
@property

labelbox/schema/model_run.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -286,17 +286,17 @@ def add_predictions(
286286
Returns:
287287
AnnotationImport
288288
"""
289-
kwargs = dict(client=self.client, model_run_id=self.uid, name=name)
289+
kwargs = dict(client=self.client, id=self.uid, name=name)
290290
if isinstance(predictions, str) or isinstance(predictions, Path):
291291
if os.path.exists(predictions):
292-
return Entity.MEAPredictionImport.create_from_file(
293-
path=str(predictions), **kwargs)
292+
return Entity.MEAPredictionImport.create(path=str(predictions),
293+
**kwargs)
294294
else:
295-
return Entity.MEAPredictionImport.create_from_url(
296-
url=str(predictions), **kwargs)
295+
return Entity.MEAPredictionImport.create(url=str(predictions),
296+
**kwargs)
297297
elif isinstance(predictions, Iterable):
298-
return Entity.MEAPredictionImport.create_from_objects(
299-
predictions=predictions, **kwargs)
298+
return Entity.MEAPredictionImport.create(labels=predictions,
299+
**kwargs)
300300
else:
301301
raise ValueError(
302302
f'Invalid predictions given of type: {type(predictions)}')

labelbox/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ def snake_case(s):
3939
return _convert(s, "_", lambda i: False)
4040

4141

42-
def is_exactly_one_set(x, y):
43-
return not (bool(x) == bool(y))
42+
def is_exactly_one_set(*args):
43+
return sum([bool(arg) for arg in args]) == 1
4444

4545

4646
def is_valid_uri(uri):

tests/data/annotation_import/test_label_import.py

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import uuid
22
import pytest
3+
from labelbox import parser
34

45
from labelbox.schema.annotation_import import AnnotationImportState, LabelImport
56
"""
@@ -9,6 +10,19 @@
910
"""
1011

1112

13+
def test_create_with_url_arg(client, configured_project_with_one_data_row,
14+
annotation_import_test_helpers):
15+
name = str(uuid.uuid4())
16+
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
17+
label_import = LabelImport.create(
18+
client=client,
19+
id=configured_project_with_one_data_row.uid,
20+
name=name,
21+
url=url)
22+
assert label_import.parent_id == configured_project_with_one_data_row.uid
23+
annotation_import_test_helpers.check_running_state(label_import, name, url)
24+
25+
1226
def test_create_from_url(client, configured_project_with_one_data_row,
1327
annotation_import_test_helpers):
1428
name = str(uuid.uuid4())
@@ -22,6 +36,22 @@ def test_create_from_url(client, configured_project_with_one_data_row,
2236
annotation_import_test_helpers.check_running_state(label_import, name, url)
2337

2438

39+
def test_create_with_labels_arg(client, configured_project, object_predictions,
40+
annotation_import_test_helpers):
41+
"""this test should check running state only to validate running, not completed"""
42+
name = str(uuid.uuid4())
43+
44+
label_import = LabelImport.create(client=client,
45+
id=configured_project.uid,
46+
name=name,
47+
labels=object_predictions)
48+
49+
assert label_import.parent_id == configured_project.uid
50+
annotation_import_test_helpers.check_running_state(label_import, name)
51+
annotation_import_test_helpers.assert_file_content(
52+
label_import.input_file_url, object_predictions)
53+
54+
2555
def test_create_from_objects(client, configured_project, object_predictions,
2656
annotation_import_test_helpers):
2757
"""this test should check running state only to validate running, not completed"""
@@ -39,20 +69,42 @@ def test_create_from_objects(client, configured_project, object_predictions,
3969
label_import.input_file_url, object_predictions)
4070

4171

42-
# TODO: add me when we add this ability
43-
# def test_create_from_local_file(client, tmp_path, project,
44-
# object_predictions, annotation_import_test_helpers):
45-
# name = str(uuid.uuid4())
46-
# file_name = f"{name}.ndjson"
47-
# file_path = tmp_path / file_name
48-
# with file_path.open("w") as f:
49-
# ndjson.dump(object_predictions, f)
72+
def test_create_with_path_arg(client, tmp_path, project, object_predictions,
73+
annotation_import_test_helpers):
74+
name = str(uuid.uuid4())
75+
file_name = f"{name}.ndjson"
76+
file_path = tmp_path / file_name
77+
with file_path.open("w") as f:
78+
parser.dump(object_predictions, f)
79+
80+
label_import = LabelImport.create(client=client,
81+
id=project.uid,
82+
name=name,
83+
path=str(file_path))
84+
85+
assert label_import.parent_id == project.uid
86+
annotation_import_test_helpers.check_running_state(label_import, name)
87+
annotation_import_test_helpers.assert_file_content(
88+
label_import.input_file_url, object_predictions)
89+
90+
91+
def test_create_from_local_file(client, tmp_path, project, object_predictions,
92+
annotation_import_test_helpers):
93+
name = str(uuid.uuid4())
94+
file_name = f"{name}.ndjson"
95+
file_path = tmp_path / file_name
96+
with file_path.open("w") as f:
97+
parser.dump(object_predictions, f)
5098

51-
# label_import = LabelImport.create_from_url(client=client, project_id=project.uid, name=name, url=str(file_path))
99+
label_import = LabelImport.create_from_url(client=client,
100+
project_id=project.uid,
101+
name=name,
102+
url=str(file_path))
52103

53-
# assert label_import.parent_id == project.uid
54-
# annotation_import_test_helpers.check_running_state(label_import, name)
55-
# annotation_import_test_helpers.assert_file_content(label_import.input_file_url, object_predictions)
104+
assert label_import.parent_id == project.uid
105+
annotation_import_test_helpers.check_running_state(label_import, name)
106+
annotation_import_test_helpers.assert_file_content(
107+
label_import.input_file_url, object_predictions)
56108

57109

58110
def test_get(client, configured_project_with_one_data_row,
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import uuid
2+
import pytest
3+
4+
from labelbox import parser
5+
from labelbox.schema.annotation_import import MALPredictionImport
6+
"""
7+
- Here we only want to check that the uploads are calling the validation
8+
- Then with unit tests we can check the types of errors raised
9+
10+
"""
11+
12+
13+
def test_create_with_url_arg(client, configured_project_with_one_data_row,
14+
annotation_import_test_helpers):
15+
name = str(uuid.uuid4())
16+
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
17+
label_import = MALPredictionImport.create(
18+
client=client,
19+
id=configured_project_with_one_data_row.uid,
20+
name=name,
21+
url=url)
22+
assert label_import.parent_id == configured_project_with_one_data_row.uid
23+
annotation_import_test_helpers.check_running_state(label_import, name, url)
24+
25+
26+
def test_create_with_labels_arg(client, configured_project, object_predictions,
27+
annotation_import_test_helpers):
28+
"""this test should check running state only to validate running, not completed"""
29+
name = str(uuid.uuid4())
30+
31+
label_import = MALPredictionImport.create(client=client,
32+
id=configured_project.uid,
33+
name=name,
34+
labels=object_predictions)
35+
36+
assert label_import.parent_id == configured_project.uid
37+
annotation_import_test_helpers.check_running_state(label_import, name)
38+
annotation_import_test_helpers.assert_file_content(
39+
label_import.input_file_url, object_predictions)
40+
41+
42+
def test_create_with_path_arg(client, tmp_path, project, object_predictions,
43+
annotation_import_test_helpers):
44+
name = str(uuid.uuid4())
45+
file_name = f"{name}.ndjson"
46+
file_path = tmp_path / file_name
47+
with file_path.open("w") as f:
48+
parser.dump(object_predictions, f)
49+
50+
label_import = MALPredictionImport.create(client=client,
51+
id=project.uid,
52+
name=name,
53+
path=str(file_path))
54+
55+
assert label_import.parent_id == project.uid
56+
annotation_import_test_helpers.check_running_state(label_import, name)
57+
annotation_import_test_helpers.assert_file_content(
58+
label_import.input_file_url, object_predictions)

0 commit comments

Comments
 (0)