Skip to content

Commit 7f27b78

Browse files
author
Val Brodsky
committed
Support set_labeling_parameter_overrides for global keys
1 parent fbea8c9 commit 7f27b78

File tree

6 files changed

+127
-52
lines changed

6 files changed

+127
-52
lines changed

labelbox/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,4 @@
3434
from labelbox.schema.queue_mode import QueueMode
3535
from labelbox.schema.task_queue import TaskQueue
3636
from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds
37+
from labelbox.schema.identifiable import UniqueId, GlobalKey

labelbox/schema/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@
2222
import labelbox.schema.iam_integration
2323
import labelbox.schema.media_type
2424
import labelbox.schema.identifiables
25+
import labelbox.schema.identifiable

labelbox/schema/identifiable.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,52 @@
1-
from abc import ABC, abstractmethod
2-
from typing import List, Union
1+
from abc import ABC
2+
from typing import Union
3+
4+
from labelbox.schema.id_type import IdType
35

46

57
class Identifiable(ABC):
68
"""
79
Base class for any object representing a unique identifier.
810
"""
911

10-
def __init__(self, key: str):
12+
def __init__(self, key: str, id_type: IdType):
1113
self._key = key
14+
self._id_type = id_type
1215

1316
@property
1417
def key(self):
15-
return self.key
18+
return self._key
19+
20+
@property
21+
def id_type(self):
22+
return self._id_type
1623

1724
def __eq__(self, other):
18-
return other.key == self.key
25+
return other.key == self.key and other.id_type == self.id_type
1926

2027
def __hash__(self):
21-
hash(self.key)
28+
return hash((self.key, self.id_type))
2229

2330
def __str__(self):
24-
return self.key.__str__()
31+
return f"{self.id_type}:{self.key}"
2532

2633

2734
class UniqueId(Identifiable):
2835
"""
2936
Represents a unique, internally generated id.
3037
"""
31-
pass
38+
39+
def __init__(self, key: str):
40+
super().__init__(key, IdType.DataRowId)
3241

3342

3443
class GlobalKey(Identifiable):
3544
"""
3645
Represents a user generated id.
3746
"""
38-
pass
47+
48+
def __init__(self, key: str):
49+
super().__init__(key, IdType.GlobalKey)
50+
51+
52+
DataRowIdentifier = Union[UniqueId, GlobalKey]

labelbox/schema/identifiables.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,6 @@
1-
from enum import Enum
21
from typing import List, Union
32

4-
5-
class IdType(str, Enum):
6-
"""
7-
The type of id used to identify a data row.
8-
9-
Currently supported types are:
10-
- DataRowId: The id assigned to a data row by Labelbox.
11-
- GlobalKey: The id assigned to a data row by the user.
12-
"""
13-
DataRowId = "ID"
14-
GlobalKey = "GKEY"
3+
from labelbox.schema.id_type import IdType
154

165

176
class Identifiables:

labelbox/schema/project.py

Lines changed: 64 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import json
22
import logging
3+
from string import Template
34
import time
45
import warnings
56
from collections import namedtuple
67
from datetime import datetime, timezone
78
from pathlib import Path
8-
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar, Union, overload
9+
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, TypeVar, Union, overload
910
from urllib.parse import urlparse
1011

1112
import requests
@@ -28,6 +29,8 @@
2829
from labelbox.schema.export_filters import ProjectExportFilters, validate_datetime, build_filters
2930
from labelbox.schema.export_params import ProjectExportParams
3031
from labelbox.schema.export_task import ExportTask
32+
from labelbox.schema.id_type import IdType
33+
from labelbox.schema.identifiable import DataRowIdentifier, UniqueId
3134
from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds
3235
from labelbox.schema.media_type import MediaType
3336
from labelbox.schema.queue_mode import QueueMode
@@ -43,9 +46,35 @@
4346
except ImportError:
4447
pass
4548

49+
DataRowPriority = int
50+
LabelingParameterOverrideInput = Tuple[Union[DataRow, DataRowIdentifier],
51+
DataRowPriority]
52+
4653
logger = logging.getLogger(__name__)
4754

4855

56+
def validate_labeling_parameter_overrides(
57+
data: List[LabelingParameterOverrideInput]) -> None:
58+
for idx, row in enumerate(data):
59+
if len(row) < 2:
60+
raise TypeError(
61+
f"Data must be a list of tuples each containing two elements: a DataRow or a DataRowIdentifier and priority (int). Found {len(row)} items. Index: {idx}"
62+
)
63+
data_row_identifier = row[0]
64+
priority = row[1]
65+
if not isinstance(data_row_identifier,
66+
Entity.DataRow) and not isinstance(
67+
data_row_identifier, DataRowIdentifier):
68+
raise TypeError(
69+
f"Data row identifier should be be of type DataRow or Data Row Identifier. Found {type(data_row_identifier)}. Index: {idx}"
70+
)
71+
72+
if not isinstance(priority, int):
73+
raise TypeError(
74+
f"Priority must be an int. Found {type(priority)} for data_row_identifier {data_row_identifier}. Index: {idx}"
75+
)
76+
77+
4978
class Project(DbObject, Updateable, Deletable):
5079
""" A Project is a container that includes a labeling frontend, an ontology,
5180
datasets and labels.
@@ -1129,36 +1158,25 @@ def get_queue_mode(self) -> "QueueMode":
11291158
else:
11301159
raise ValueError("Status not known")
11311160

1132-
def validate_labeling_parameter_overrides(self, data) -> None:
1133-
for idx, row in enumerate(data):
1134-
if len(row) < 2:
1135-
raise TypeError(
1136-
f"Data must be a list of tuples containing a DataRow and priority (int). Found {len(row)} items. Index: {idx}"
1137-
)
1138-
data_row = row[0]
1139-
priority = row[1]
1140-
if not isinstance(data_row, Entity.DataRow):
1141-
raise TypeError(
1142-
f"data_row should be be of type DataRow. Found {type(data_row)}. Index: {idx}"
1143-
)
1144-
1145-
if not isinstance(priority, int):
1146-
raise TypeError(
1147-
f"Priority must be an int. Found {type(priority)} for data_row {data_row}. Index: {idx}"
1148-
)
1149-
1150-
def set_labeling_parameter_overrides(self, data) -> bool:
1161+
def set_labeling_parameter_overrides(
1162+
self, data: List[LabelingParameterOverrideInput]) -> bool:
11511163
""" Adds labeling parameter overrides to this project.
11521164
11531165
See information on priority here:
11541166
https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system
11551167
11561168
>>> project.set_labeling_parameter_overrides([
1157-
>>> (data_row_1, 2), (data_row_2, 1)])
1169+
>>> (data_row_id1, 2), (data_row_id2, 1)])
1170+
or
1171+
>>> project.set_labeling_parameter_overrides([
1172+
>>> (data_row_gk1, 2), (data_row_gk2, 1)])
11581173
11591174
Args:
11601175
data (iterable): An iterable of tuples. Each tuple must contain
1161-
(DataRow, priority<int>) for the new override.
1176+
either (DataRow, DataRowPriority<int>)
1177+
or (DataRowIdentifier, priority<int>) for the new override.
1178+
DataRowIdentifier is an object representing a data row id or a global key. A DataIdentifier object can be a UniqueIds or GlobalKeys class.
1179+
NOTE - passing whole DatRow is deprecated. Please use a DataRowIdentifier instead.
11621180
11631181
Priority:
11641182
* Data will be labeled in priority order.
@@ -1174,15 +1192,30 @@ def set_labeling_parameter_overrides(self, data) -> bool:
11741192
bool, indicates if the operation was a success.
11751193
"""
11761194
data = [t[:2] for t in data]
1177-
self.validate_labeling_parameter_overrides(data)
1178-
data_str = ",\n".join("{dataRow: {id: \"%s\"}, priority: %d }" %
1179-
(data_row.uid, priority)
1180-
for data_row, priority in data)
1181-
id_param = "projectId"
1182-
query_str = """mutation SetLabelingParameterOverridesPyApi($%s: ID!){
1183-
project(where: { id: $%s }) {setLabelingParameterOverrides
1184-
(data: [%s]) {success}}} """ % (id_param, id_param, data_str)
1185-
res = self.client.execute(query_str, {id_param: self.uid})
1195+
validate_labeling_parameter_overrides(data)
1196+
1197+
template = Template(
1198+
"""mutation SetLabelingParameterOverridesPyApi($$projectId: ID!)
1199+
{project(where: { id: $$projectId })
1200+
{setLabelingParameterOverrides
1201+
(dataWithDataRowIdentifiers: [$dataWithDataRowIdentifiers])
1202+
{success}}}
1203+
""")
1204+
1205+
data_rows_with_identifiers = ""
1206+
for data_row, priority in data:
1207+
if isinstance(data_row, DataRow):
1208+
data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.uid}\", idType: {IdType.DataRowId}}}, priority: {priority}}},"
1209+
elif isinstance(data_row, DataRowIdentifier):
1210+
data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.key}\", idType: {data_row.id_type}}}, priority: {priority}}},"
1211+
else:
1212+
raise TypeError(
1213+
f"Data row identifier should be be of type DataRow or Data Row Identifier. Found {type(data_row)}."
1214+
)
1215+
1216+
query_str = template.substitute(
1217+
dataWithDataRowIdentifiers=data_rows_with_identifiers)
1218+
res = self.client.execute(query_str, {"projectId": self.uid})
11861219
return res["project"]["setLabelingParameterOverrides"]["success"]
11871220

11881221
@overload
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import pytest
2+
from unittest.mock import MagicMock
3+
4+
from labelbox.schema.data_row import DataRow
5+
from labelbox.schema.identifiable import GlobalKey, UniqueId
6+
from labelbox.schema.project import validate_labeling_parameter_overrides
7+
8+
9+
def test_validate_labeling_parameter_overrides_valid_data():
10+
mock_data_row = MagicMock(spec=DataRow)
11+
mock_data_row.uid = "abc"
12+
data = [(mock_data_row, 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)]
13+
validate_labeling_parameter_overrides(data)
14+
15+
16+
def test_validate_labeling_parameter_overrides_invalid_data():
17+
data = [("abc", 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)]
18+
with pytest.raises(TypeError):
19+
validate_labeling_parameter_overrides(data)
20+
21+
22+
def test_validate_labeling_parameter_overrides_invalid_priority():
23+
mock_data_row = MagicMock(spec=DataRow)
24+
mock_data_row.uid = "abc"
25+
data = [(mock_data_row, "invalid"), (UniqueId("efg"), 2),
26+
(GlobalKey("hij"), 3)]
27+
with pytest.raises(TypeError):
28+
validate_labeling_parameter_overrides(data)
29+
30+
31+
def test_validate_labeling_parameter_overrides_invalid_tuple_length():
32+
mock_data_row = MagicMock(spec=DataRow)
33+
mock_data_row.uid = "abc"
34+
data = [(mock_data_row, "invalid"), (UniqueId("efg"), 2),
35+
(GlobalKey("hij"))]
36+
with pytest.raises(TypeError):
37+
validate_labeling_parameter_overrides(data)

0 commit comments

Comments
 (0)