Skip to content

Commit 7e66a04

Browse files
Revert "[PLT-150] Add unified create method for AnnotationImport, MEA… (#1546)
1 parent a07d582 commit 7e66a04

File tree

5 files changed

+25
-175
lines changed

5 files changed

+25
-175
lines changed

labelbox/schema/annotation_import.py

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
import os
55
import time
6-
from typing import Any, BinaryIO, Dict, List, Optional, Union, TYPE_CHECKING, cast
6+
from typing import Any, BinaryIO, Dict, List, Union, TYPE_CHECKING, cast
77
from collections import defaultdict
88

99
from google.api_core import retry
@@ -241,47 +241,7 @@ def parent_id(self) -> str:
241241
raise NotImplementedError("Inheriting class must override")
242242

243243

244-
class CreatableAnnotationImport(AnnotationImport):
245-
246-
@classmethod
247-
def create(
248-
cls,
249-
client: "labelbox.Client",
250-
id: str,
251-
name: str,
252-
path: Optional[str] = None,
253-
url: Optional[str] = None,
254-
labels: Union[List[Dict[str, Any]], List["Label"]] = []
255-
) -> "AnnotationImport":
256-
if (not is_exactly_one_set(url, labels, path)):
257-
raise ValueError(
258-
"Must pass in a nonempty argument for one and only one of the following arguments: url, path, predictions"
259-
)
260-
if url:
261-
return cls.create_from_url(client, id, name, url)
262-
if path:
263-
return cls.create_from_file(client, id, name, path)
264-
return cls.create_from_objects(client, id, name, labels)
265-
266-
@classmethod
267-
def create_from_url(cls, client: "labelbox.Client", id: str, name: str,
268-
url: str) -> "AnnotationImport":
269-
raise NotImplementedError("Inheriting class must override")
270-
271-
@classmethod
272-
def create_from_file(cls, client: "labelbox.Client", id: str, name: str,
273-
path: str) -> "AnnotationImport":
274-
raise NotImplementedError("Inheriting class must override")
275-
276-
@classmethod
277-
def create_from_objects(
278-
cls, client: "labelbox.Client", id: str, name: str,
279-
labels: Union[List[Dict[str, Any]],
280-
List["Label"]]) -> "AnnotationImport":
281-
raise NotImplementedError("Inheriting class must override")
282-
283-
284-
class MEAPredictionImport(CreatableAnnotationImport):
244+
class MEAPredictionImport(AnnotationImport):
285245
model_run_id = Field.String("model_run_id")
286246

287247
@property
@@ -518,7 +478,7 @@ def _get_model_run_data_rows_mutation(cls) -> str:
518478
}""" % query.results_query_part(cls)
519479

520480

521-
class MALPredictionImport(CreatableAnnotationImport):
481+
class MALPredictionImport(AnnotationImport):
522482
project = Relationship.ToOne("Project", cache=True)
523483

524484
@property
@@ -678,7 +638,7 @@ def _create_mal_import_from_bytes(
678638
return cls(client, res["createModelAssistedLabelingPredictionImport"])
679639

680640

681-
class LabelImport(CreatableAnnotationImport):
641+
class LabelImport(AnnotationImport):
682642
project = Relationship.ToOne("Project", cache=True)
683643

684644
@property

labelbox/schema/model_run.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -286,17 +286,17 @@ def add_predictions(
286286
Returns:
287287
AnnotationImport
288288
"""
289-
kwargs = dict(client=self.client, id=self.uid, name=name)
289+
kwargs = dict(client=self.client, model_run_id=self.uid, name=name)
290290
if isinstance(predictions, str) or isinstance(predictions, Path):
291291
if os.path.exists(predictions):
292-
return Entity.MEAPredictionImport.create(path=str(predictions),
293-
**kwargs)
292+
return Entity.MEAPredictionImport.create_from_file(
293+
path=str(predictions), **kwargs)
294294
else:
295-
return Entity.MEAPredictionImport.create(url=str(predictions),
296-
**kwargs)
295+
return Entity.MEAPredictionImport.create_from_url(
296+
url=str(predictions), **kwargs)
297297
elif isinstance(predictions, Iterable):
298-
return Entity.MEAPredictionImport.create(labels=predictions,
299-
**kwargs)
298+
return Entity.MEAPredictionImport.create_from_objects(
299+
predictions=predictions, **kwargs)
300300
else:
301301
raise ValueError(
302302
f'Invalid predictions given of type: {type(predictions)}')

labelbox/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ def snake_case(s):
3939
return _convert(s, "_", lambda i: False)
4040

4141

42-
def is_exactly_one_set(*args):
43-
return sum([bool(arg) for arg in args]) == 1
42+
def is_exactly_one_set(x, y):
43+
return not (bool(x) == bool(y))
4444

4545

4646
def is_valid_uri(uri):

tests/data/annotation_import/test_label_import.py

Lines changed: 12 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import uuid
22
import pytest
3-
from labelbox import parser
43

54
from labelbox.schema.annotation_import import AnnotationImportState, LabelImport
65
"""
@@ -10,19 +9,6 @@
109
"""
1110

1211

13-
def test_create_with_url_arg(client, configured_project_with_one_data_row,
14-
annotation_import_test_helpers):
15-
name = str(uuid.uuid4())
16-
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
17-
label_import = LabelImport.create(
18-
client=client,
19-
id=configured_project_with_one_data_row.uid,
20-
name=name,
21-
url=url)
22-
assert label_import.parent_id == configured_project_with_one_data_row.uid
23-
annotation_import_test_helpers.check_running_state(label_import, name, url)
24-
25-
2612
def test_create_from_url(client, configured_project_with_one_data_row,
2713
annotation_import_test_helpers):
2814
name = str(uuid.uuid4())
@@ -36,22 +22,6 @@ def test_create_from_url(client, configured_project_with_one_data_row,
3622
annotation_import_test_helpers.check_running_state(label_import, name, url)
3723

3824

39-
def test_create_with_labels_arg(client, configured_project, object_predictions,
40-
annotation_import_test_helpers):
41-
"""this test should check running state only to validate running, not completed"""
42-
name = str(uuid.uuid4())
43-
44-
label_import = LabelImport.create(client=client,
45-
id=configured_project.uid,
46-
name=name,
47-
labels=object_predictions)
48-
49-
assert label_import.parent_id == configured_project.uid
50-
annotation_import_test_helpers.check_running_state(label_import, name)
51-
annotation_import_test_helpers.assert_file_content(
52-
label_import.input_file_url, object_predictions)
53-
54-
5525
def test_create_from_objects(client, configured_project, object_predictions,
5626
annotation_import_test_helpers):
5727
"""this test should check running state only to validate running, not completed"""
@@ -69,42 +39,20 @@ def test_create_from_objects(client, configured_project, object_predictions,
6939
label_import.input_file_url, object_predictions)
7040

7141

72-
def test_create_with_path_arg(client, tmp_path, project, object_predictions,
73-
annotation_import_test_helpers):
74-
name = str(uuid.uuid4())
75-
file_name = f"{name}.ndjson"
76-
file_path = tmp_path / file_name
77-
with file_path.open("w") as f:
78-
parser.dump(object_predictions, f)
79-
80-
label_import = LabelImport.create(client=client,
81-
id=project.uid,
82-
name=name,
83-
path=str(file_path))
84-
85-
assert label_import.parent_id == project.uid
86-
annotation_import_test_helpers.check_running_state(label_import, name)
87-
annotation_import_test_helpers.assert_file_content(
88-
label_import.input_file_url, object_predictions)
89-
90-
91-
def test_create_from_local_file(client, tmp_path, project, object_predictions,
92-
annotation_import_test_helpers):
93-
name = str(uuid.uuid4())
94-
file_name = f"{name}.ndjson"
95-
file_path = tmp_path / file_name
96-
with file_path.open("w") as f:
97-
parser.dump(object_predictions, f)
42+
# TODO: add me when we add this ability
43+
# def test_create_from_local_file(client, tmp_path, project,
44+
# object_predictions, annotation_import_test_helpers):
45+
# name = str(uuid.uuid4())
46+
# file_name = f"{name}.ndjson"
47+
# file_path = tmp_path / file_name
48+
# with file_path.open("w") as f:
49+
# ndjson.dump(object_predictions, f)
9850

99-
label_import = LabelImport.create_from_url(client=client,
100-
project_id=project.uid,
101-
name=name,
102-
url=str(file_path))
51+
# label_import = LabelImport.create_from_url(client=client, project_id=project.uid, name=name, url=str(file_path))
10352

104-
assert label_import.parent_id == project.uid
105-
annotation_import_test_helpers.check_running_state(label_import, name)
106-
annotation_import_test_helpers.assert_file_content(
107-
label_import.input_file_url, object_predictions)
53+
# assert label_import.parent_id == project.uid
54+
# annotation_import_test_helpers.check_running_state(label_import, name)
55+
# annotation_import_test_helpers.assert_file_content(label_import.input_file_url, object_predictions)
10856

10957

11058
def test_get(client, configured_project_with_one_data_row,

tests/data/annotation_import/test_mal_prediction_import.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)