Skip to content

Commit 8584745

Browse files
authored
Vb/set labeling params overrides gk sdk 412 (#1336)
2 parents fbea8c9 + e55026d commit 8584745

File tree

8 files changed

+166
-54
lines changed

8 files changed

+166
-54
lines changed

labelbox/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,4 @@
3434
from labelbox.schema.queue_mode import QueueMode
3535
from labelbox.schema.task_queue import TaskQueue
3636
from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds
37+
from labelbox.schema.identifiable import UniqueId, GlobalKey

labelbox/schema/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@
2222
import labelbox.schema.iam_integration
2323
import labelbox.schema.media_type
2424
import labelbox.schema.identifiables
25+
import labelbox.schema.identifiable

labelbox/schema/id_type.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from enum import Enum
2+
3+
4+
class IdType(str, Enum):
5+
"""
6+
The type of id used to identify a data row.
7+
8+
Currently supported types are:
9+
- DataRowId: The id assigned to a data row by Labelbox.
10+
- GlobalKey: The id assigned to a data row by the user.
11+
"""
12+
DataRowId = "ID"
13+
GlobalKey = "GKEY"

labelbox/schema/identifiable.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,52 @@
1-
from abc import ABC, abstractmethod
2-
from typing import List, Union
1+
from abc import ABC
2+
from typing import Union
3+
4+
from labelbox.schema.id_type import IdType
35

46

57
class Identifiable(ABC):
68
"""
79
Base class for any object representing a unique identifier.
810
"""
911

10-
def __init__(self, key: str):
12+
def __init__(self, key: str, id_type: IdType):
1113
self._key = key
14+
self._id_type = id_type
1215

1316
@property
1417
def key(self):
15-
return self.key
18+
return self._key
19+
20+
@property
21+
def id_type(self):
22+
return self._id_type
1623

1724
def __eq__(self, other):
18-
return other.key == self.key
25+
return other.key == self.key and other.id_type == self.id_type
1926

2027
def __hash__(self):
21-
hash(self.key)
28+
return hash((self.key, self.id_type))
2229

2330
def __str__(self):
24-
return self.key.__str__()
31+
return f"{self.id_type}:{self.key}"
2532

2633

2734
class UniqueId(Identifiable):
2835
"""
2936
Represents a unique, internally generated id.
3037
"""
31-
pass
38+
39+
def __init__(self, key: str):
40+
super().__init__(key, IdType.DataRowId)
3241

3342

3443
class GlobalKey(Identifiable):
3544
"""
3645
Represents a user generated id.
3746
"""
38-
pass
47+
48+
def __init__(self, key: str):
49+
super().__init__(key, IdType.GlobalKey)
50+
51+
52+
DataRowIdentifier = Union[UniqueId, GlobalKey]

labelbox/schema/identifiables.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,6 @@
1-
from enum import Enum
21
from typing import List, Union
32

4-
5-
class IdType(str, Enum):
6-
"""
7-
The type of id used to identify a data row.
8-
9-
Currently supported types are:
10-
- DataRowId: The id assigned to a data row by Labelbox.
11-
- GlobalKey: The id assigned to a data row by the user.
12-
"""
13-
DataRowId = "ID"
14-
GlobalKey = "GKEY"
3+
from labelbox.schema.id_type import IdType
154

165

176
class Identifiables:

labelbox/schema/project.py

Lines changed: 68 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import json
22
import logging
3+
from string import Template
34
import time
45
import warnings
56
from collections import namedtuple
67
from datetime import datetime, timezone
78
from pathlib import Path
8-
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar, Union, overload
9+
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, TypeVar, Union, overload
910
from urllib.parse import urlparse
1011

1112
import requests
@@ -28,6 +29,8 @@
2829
from labelbox.schema.export_filters import ProjectExportFilters, validate_datetime, build_filters
2930
from labelbox.schema.export_params import ProjectExportParams
3031
from labelbox.schema.export_task import ExportTask
32+
from labelbox.schema.id_type import IdType
33+
from labelbox.schema.identifiable import DataRowIdentifier, GlobalKey, UniqueId
3134
from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds
3235
from labelbox.schema.media_type import MediaType
3336
from labelbox.schema.queue_mode import QueueMode
@@ -43,9 +46,38 @@
4346
except ImportError:
4447
pass
4548

49+
DataRowPriority = int
50+
LabelingParameterOverrideInput = Tuple[Union[DataRow, DataRowIdentifier],
51+
DataRowPriority]
52+
4653
logger = logging.getLogger(__name__)
4754

4855

56+
def validate_labeling_parameter_overrides(
57+
data: List[LabelingParameterOverrideInput]) -> None:
58+
for idx, row in enumerate(data):
59+
if len(row) < 2:
60+
raise TypeError(
61+
f"Data must be a list of tuples each containing two elements: a DataRow or a DataRowIdentifier and priority (int). Found {len(row)} items. Index: {idx}"
62+
)
63+
data_row_identifier = row[0]
64+
priority = row[1]
65+
valid_types = (Entity.DataRow, UniqueId, GlobalKey)
66+
if not isinstance(data_row_identifier, valid_types):
67+
raise TypeError(
68+
f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found {type(data_row_identifier)} for data_row_identifier {data_row_identifier}"
69+
)
70+
71+
if not isinstance(priority, int):
72+
if isinstance(data_row_identifier, Entity.DataRow):
73+
id = data_row_identifier.uid
74+
else:
75+
id = data_row_identifier
76+
raise TypeError(
77+
f"Priority must be an int. Found {type(priority)} for data_row_identifier {id}"
78+
)
79+
80+
4981
class Project(DbObject, Updateable, Deletable):
5082
""" A Project is a container that includes a labeling frontend, an ontology,
5183
datasets and labels.
@@ -1129,36 +1161,25 @@ def get_queue_mode(self) -> "QueueMode":
11291161
else:
11301162
raise ValueError("Status not known")
11311163

1132-
def validate_labeling_parameter_overrides(self, data) -> None:
1133-
for idx, row in enumerate(data):
1134-
if len(row) < 2:
1135-
raise TypeError(
1136-
f"Data must be a list of tuples containing a DataRow and priority (int). Found {len(row)} items. Index: {idx}"
1137-
)
1138-
data_row = row[0]
1139-
priority = row[1]
1140-
if not isinstance(data_row, Entity.DataRow):
1141-
raise TypeError(
1142-
f"data_row should be be of type DataRow. Found {type(data_row)}. Index: {idx}"
1143-
)
1144-
1145-
if not isinstance(priority, int):
1146-
raise TypeError(
1147-
f"Priority must be an int. Found {type(priority)} for data_row {data_row}. Index: {idx}"
1148-
)
1149-
1150-
def set_labeling_parameter_overrides(self, data) -> bool:
1164+
def set_labeling_parameter_overrides(
1165+
self, data: List[LabelingParameterOverrideInput]) -> bool:
11511166
""" Adds labeling parameter overrides to this project.
11521167
11531168
See information on priority here:
11541169
https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system
11551170
11561171
>>> project.set_labeling_parameter_overrides([
1157-
>>> (data_row_1, 2), (data_row_2, 1)])
1172+
>>> (data_row_id1, 2), (data_row_id2, 1)])
1173+
or
1174+
>>> project.set_labeling_parameter_overrides([
1175+
>>> (data_row_gk1, 2), (data_row_gk2, 1)])
11581176
11591177
Args:
11601178
data (iterable): An iterable of tuples. Each tuple must contain
1161-
(DataRow, priority<int>) for the new override.
1179+
either (DataRow, DataRowPriority<int>)
1180+
or (DataRowIdentifier, DataRowPriority<int>) for the new override.
1181+
DataRowIdentifier is an object representing a data row id or a global key. A DataRowIdentifier object can be a UniqueId or a GlobalKey instance.
1182+
NOTE - passing a whole DataRow is deprecated. Please use a DataRowIdentifier instead.
11621183
11631184
Priority:
11641185
* Data will be labeled in priority order.
@@ -1174,15 +1195,31 @@ def set_labeling_parameter_overrides(self, data) -> bool:
11741195
bool, indicates if the operation was a success.
11751196
"""
11761197
data = [t[:2] for t in data]
1177-
self.validate_labeling_parameter_overrides(data)
1178-
data_str = ",\n".join("{dataRow: {id: \"%s\"}, priority: %d }" %
1179-
(data_row.uid, priority)
1180-
for data_row, priority in data)
1181-
id_param = "projectId"
1182-
query_str = """mutation SetLabelingParameterOverridesPyApi($%s: ID!){
1183-
project(where: { id: $%s }) {setLabelingParameterOverrides
1184-
(data: [%s]) {success}}} """ % (id_param, id_param, data_str)
1185-
res = self.client.execute(query_str, {id_param: self.uid})
1198+
validate_labeling_parameter_overrides(data)
1199+
1200+
template = Template(
1201+
"""mutation SetLabelingParameterOverridesPyApi($$projectId: ID!)
1202+
{project(where: { id: $$projectId })
1203+
{setLabelingParameterOverrides
1204+
(dataWithDataRowIdentifiers: [$dataWithDataRowIdentifiers])
1205+
{success}}}
1206+
""")
1207+
1208+
data_rows_with_identifiers = ""
1209+
for data_row, priority in data:
1210+
if isinstance(data_row, DataRow):
1211+
data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.uid}\", idType: {IdType.DataRowId}}}, priority: {priority}}},"
1212+
elif isinstance(data_row, UniqueId) or isinstance(
1213+
data_row, GlobalKey):
1214+
data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.key}\", idType: {data_row.id_type}}}, priority: {priority}}},"
1215+
else:
1216+
raise TypeError(
1217+
f"Data row identifier should be be of type DataRow or Data Row Identifier. Found {type(data_row)}."
1218+
)
1219+
1220+
query_str = template.substitute(
1221+
dataWithDataRowIdentifiers=data_rows_with_identifiers)
1222+
res = self.client.execute(query_str, {"projectId": self.uid})
11861223
return res["project"]["setLabelingParameterOverrides"]["success"]
11871224

11881225
@overload

tests/integration/test_labeling_parameter_overrides.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
from labelbox import DataRow
3+
from labelbox.schema.identifiable import GlobalKey, UniqueId
34
from labelbox.schema.identifiables import GlobalKeys, UniqueIds
45

56

@@ -27,17 +28,36 @@ def test_labeling_parameter_overrides(consensus_project_with_batch):
2728
for override in updated_overrides:
2829
assert isinstance(override.data_row(), DataRow)
2930

31+
data = [(UniqueId(data_rows[0].uid), 1, 2), (UniqueId(data_rows[1].uid), 2),
32+
(UniqueId(data_rows[2].uid), 3)]
33+
success = project.set_labeling_parameter_overrides(data)
34+
assert success
35+
updated_overrides = list(project.labeling_parameter_overrides())
36+
assert len(updated_overrides) == 3
37+
assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1}
38+
assert {o.priority for o in updated_overrides} == {1, 2, 3}
39+
40+
data = [(GlobalKey(data_rows[0].global_key), 2, 2),
41+
(GlobalKey(data_rows[1].global_key), 3, 3),
42+
(GlobalKey(data_rows[2].global_key), 4)]
43+
success = project.set_labeling_parameter_overrides(data)
44+
assert success
45+
updated_overrides = list(project.labeling_parameter_overrides())
46+
assert len(updated_overrides) == 3
47+
assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1}
48+
assert {o.priority for o in updated_overrides} == {2, 3, 4}
49+
3050
with pytest.raises(TypeError) as exc_info:
3151
data = [(data_rows[2], "a_string", 3)]
3252
project.set_labeling_parameter_overrides(data)
3353
assert str(exc_info.value) == \
34-
f"Priority must be an int. Found <class 'str'> for data_row {data_rows[2]}. Index: 0"
54+
f"Priority must be an int. Found <class 'str'> for data_row_identifier {data_rows[2].uid}"
3555

3656
with pytest.raises(TypeError) as exc_info:
3757
data = [(data_rows[2].uid, 1)]
3858
project.set_labeling_parameter_overrides(data)
3959
assert str(exc_info.value) == \
40-
"data_row should be be of type DataRow. Found <class 'str'>. Index: 0"
60+
f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found <class 'str'> for data_row_identifier {data_rows[2].uid}"
4161

4262

4363
def test_set_labeling_priority(consensus_project_with_batch):
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import pytest
2+
from unittest.mock import MagicMock
3+
4+
from labelbox.schema.data_row import DataRow
5+
from labelbox.schema.identifiable import GlobalKey, UniqueId
6+
from labelbox.schema.project import validate_labeling_parameter_overrides
7+
8+
9+
def test_validate_labeling_parameter_overrides_valid_data():
10+
mock_data_row = MagicMock(spec=DataRow)
11+
mock_data_row.uid = "abc"
12+
data = [(mock_data_row, 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)]
13+
validate_labeling_parameter_overrides(data)
14+
15+
16+
def test_validate_labeling_parameter_overrides_invalid_data():
17+
data = [("abc", 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)]
18+
with pytest.raises(TypeError):
19+
validate_labeling_parameter_overrides(data)
20+
21+
22+
def test_validate_labeling_parameter_overrides_invalid_priority():
23+
mock_data_row = MagicMock(spec=DataRow)
24+
mock_data_row.uid = "abc"
25+
data = [(mock_data_row, "invalid"), (UniqueId("efg"), 2),
26+
(GlobalKey("hij"), 3)]
27+
with pytest.raises(TypeError):
28+
validate_labeling_parameter_overrides(data)
29+
30+
31+
def test_validate_labeling_parameter_overrides_invalid_tuple_length():
32+
mock_data_row = MagicMock(spec=DataRow)
33+
mock_data_row.uid = "abc"
34+
data = [(mock_data_row, "invalid"), (UniqueId("efg"), 2),
35+
(GlobalKey("hij"))]
36+
with pytest.raises(TypeError):
37+
validate_labeling_parameter_overrides(data)

0 commit comments

Comments
 (0)