Skip to content

Commit e55026d

Browse files
author
Val Brodsky
committed
Update integration test
1 parent 7f27b78 commit e55026d

File tree

3 files changed

+46
-9
lines changed

3 files changed

+46
-9
lines changed

labelbox/schema/id_type.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from enum import Enum
2+
3+
4+
class IdType(str, Enum):
5+
"""
6+
The type of id used to identify a data row.
7+
8+
Currently supported types are:
9+
- DataRowId: The id assigned to a data row by Labelbox.
10+
- GlobalKey: The id assigned to a data row by the user.
11+
"""
12+
DataRowId = "ID"
13+
GlobalKey = "GKEY"

labelbox/schema/project.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from labelbox.schema.export_params import ProjectExportParams
3131
from labelbox.schema.export_task import ExportTask
3232
from labelbox.schema.id_type import IdType
33-
from labelbox.schema.identifiable import DataRowIdentifier, UniqueId
33+
from labelbox.schema.identifiable import DataRowIdentifier, GlobalKey, UniqueId
3434
from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds
3535
from labelbox.schema.media_type import MediaType
3636
from labelbox.schema.queue_mode import QueueMode
@@ -62,16 +62,19 @@ def validate_labeling_parameter_overrides(
6262
)
6363
data_row_identifier = row[0]
6464
priority = row[1]
65-
if not isinstance(data_row_identifier,
66-
Entity.DataRow) and not isinstance(
67-
data_row_identifier, DataRowIdentifier):
65+
valid_types = (Entity.DataRow, UniqueId, GlobalKey)
66+
if not isinstance(data_row_identifier, valid_types):
6867
raise TypeError(
69-
f"Data row identifier should be be of type DataRow or Data Row Identifier. Found {type(data_row_identifier)}. Index: {idx}"
68+
f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found {type(data_row_identifier)} for data_row_identifier {data_row_identifier}"
7069
)
7170

7271
if not isinstance(priority, int):
72+
if isinstance(data_row_identifier, Entity.DataRow):
73+
id = data_row_identifier.uid
74+
else:
75+
id = data_row_identifier
7376
raise TypeError(
74-
f"Priority must be an int. Found {type(priority)} for data_row_identifier {data_row_identifier}. Index: {idx}"
77+
f"Priority must be an int. Found {type(priority)} for data_row_identifier {id}"
7578
)
7679

7780

@@ -1206,7 +1209,8 @@ def set_labeling_parameter_overrides(
12061209
for data_row, priority in data:
12071210
if isinstance(data_row, DataRow):
12081211
data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.uid}\", idType: {IdType.DataRowId}}}, priority: {priority}}},"
1209-
elif isinstance(data_row, DataRowIdentifier):
1212+
elif isinstance(data_row, UniqueId) or isinstance(
1213+
data_row, GlobalKey):
12101214
data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.key}\", idType: {data_row.id_type}}}, priority: {priority}}},"
12111215
else:
12121216
raise TypeError(

tests/integration/test_labeling_parameter_overrides.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
from labelbox import DataRow
3+
from labelbox.schema.identifiable import GlobalKey, UniqueId
34
from labelbox.schema.identifiables import GlobalKeys, UniqueIds
45

56

@@ -27,17 +28,36 @@ def test_labeling_parameter_overrides(consensus_project_with_batch):
2728
for override in updated_overrides:
2829
assert isinstance(override.data_row(), DataRow)
2930

31+
data = [(UniqueId(data_rows[0].uid), 1, 2), (UniqueId(data_rows[1].uid), 2),
32+
(UniqueId(data_rows[2].uid), 3)]
33+
success = project.set_labeling_parameter_overrides(data)
34+
assert success
35+
updated_overrides = list(project.labeling_parameter_overrides())
36+
assert len(updated_overrides) == 3
37+
assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1}
38+
assert {o.priority for o in updated_overrides} == {1, 2, 3}
39+
40+
data = [(GlobalKey(data_rows[0].global_key), 2, 2),
41+
(GlobalKey(data_rows[1].global_key), 3, 3),
42+
(GlobalKey(data_rows[2].global_key), 4)]
43+
success = project.set_labeling_parameter_overrides(data)
44+
assert success
45+
updated_overrides = list(project.labeling_parameter_overrides())
46+
assert len(updated_overrides) == 3
47+
assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1}
48+
assert {o.priority for o in updated_overrides} == {2, 3, 4}
49+
3050
with pytest.raises(TypeError) as exc_info:
3151
data = [(data_rows[2], "a_string", 3)]
3252
project.set_labeling_parameter_overrides(data)
3353
assert str(exc_info.value) == \
34-
f"Priority must be an int. Found <class 'str'> for data_row {data_rows[2]}. Index: 0"
54+
f"Priority must be an int. Found <class 'str'> for data_row_identifier {data_rows[2].uid}"
3555

3656
with pytest.raises(TypeError) as exc_info:
3757
data = [(data_rows[2].uid, 1)]
3858
project.set_labeling_parameter_overrides(data)
3959
assert str(exc_info.value) == \
40-
"data_row should be be of type DataRow. Found <class 'str'>. Index: 0"
60+
f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found <class 'str'> for data_row_identifier {data_rows[2].uid}"
4161

4262

4363
def test_set_labeling_priority(consensus_project_with_batch):

0 commit comments

Comments
 (0)