Skip to content

Commit deac9c4

Browse files
authored
Add global key support for Project move_data_rows_to_task_queue (#1284)
2 parents b29eff3 + d61b973 commit deac9c4

File tree

7 files changed

+162
-27
lines changed

7 files changed

+162
-27
lines changed

labelbox/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,4 @@
3333
from labelbox.schema.slice import Slice, CatalogSlice, ModelSlice
3434
from labelbox.schema.queue_mode import QueueMode
3535
from labelbox.schema.task_queue import TaskQueue
36+
from labelbox.schema.identifiables import UniqueIds, GlobalKeys

labelbox/schema/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@
2121
import labelbox.schema.batch
2222
import labelbox.schema.iam_integration
2323
import labelbox.schema.media_type
24+
import labelbox.schema.identifiables

labelbox/schema/identifiable.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from abc import ABC, abstractmethod
2+
from typing import List, Union
3+
4+
5+
class Identifiable(ABC):
6+
"""
7+
Base class for any object representing a unique identifier.
8+
"""
9+
10+
def __init__(self, key: str):
11+
self._key = key
12+
13+
@property
14+
def key(self):
15+
return self.key
16+
17+
def __eq__(self, other):
18+
return other.key == self.key
19+
20+
def __hash__(self):
21+
hash(self.key)
22+
23+
def __str__(self):
24+
return self.key.__str__()
25+
26+
27+
class UniqueId(Identifiable):
28+
"""
29+
Represents a unique, internally generated id.
30+
"""
31+
pass
32+
33+
34+
class GlobalKey(Identifiable):
35+
"""
36+
Represents a user generated id.
37+
"""
38+
pass

labelbox/schema/identifiables.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from enum import Enum
2+
from typing import List, Union
3+
4+
5+
class IdType(str, Enum):
6+
"""
7+
The type of id used to identify a data row.
8+
Currently supported types are:
9+
- DataRowId: The id assigned to a data row by Labelbox.
10+
- GlobalKey: The id assigned to a data row by the user.
11+
"""
12+
DataRowId = "ID"
13+
GlobalKey = "GKEY"
14+
15+
16+
class Identifiables:
17+
18+
def __init__(self, iterable, id_type: IdType):
19+
"""
20+
Args:
21+
iterable: Iterable of ids (unique or global keys)
22+
id_type: The type of id used to identify a data row.
23+
"""
24+
self._iterable = iterable
25+
self._index = 0
26+
self._id_type = id_type
27+
28+
def __iter__(self):
29+
return iter(self._iterable)
30+
31+
32+
class UniqueIds(Identifiables):
33+
"""
34+
Represents a collection of unique, internally generated ids.
35+
"""
36+
37+
def __init__(self, iterable: List[str]):
38+
super().__init__(iterable, IdType.DataRowId)
39+
40+
41+
class GlobalKeys(Identifiables):
42+
"""
43+
Represents a collection of user generated ids.
44+
"""
45+
46+
def __init__(self, iterable: List[str]):
47+
super().__init__(iterable, IdType.GlobalKey)
48+
49+
50+
DataRowIdentifiers = Union[UniqueIds, GlobalKeys]

labelbox/schema/project.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,19 @@
55
from collections import namedtuple
66
from datetime import datetime, timezone
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
8+
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar, Union, overload
99
from urllib.parse import urlparse
1010

1111
import requests
1212

1313
from labelbox import parser
1414
from labelbox import utils
15-
from labelbox.exceptions import (InvalidQueryError, LabelboxError,
16-
ProcessingWaitTimeout, ResourceConflict,
17-
ResourceNotFoundError)
15+
from labelbox.exceptions import (
16+
InvalidQueryError,
17+
LabelboxError,
18+
ProcessingWaitTimeout,
19+
ResourceConflict,
20+
)
1821
from labelbox.orm import query
1922
from labelbox.orm.db_object import DbObject, Deletable, Updateable, experimental
2023
from labelbox.orm.model import Entity, Field, Relationship
@@ -25,23 +28,16 @@
2528
from labelbox.schema.export_filters import ProjectExportFilters, validate_datetime, build_filters
2629
from labelbox.schema.export_params import ProjectExportParams
2730
from labelbox.schema.export_task import ExportTask
31+
from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds
2832
from labelbox.schema.media_type import MediaType
2933
from labelbox.schema.queue_mode import QueueMode
3034
from labelbox.schema.resource_tag import ResourceTag
3135
from labelbox.schema.task import Task
3236
from labelbox.schema.task_queue import TaskQueue
33-
from labelbox.schema.user import User
3437

3538
if TYPE_CHECKING:
3639
from labelbox import BulkImportRequest
3740

38-
try:
39-
datetime.fromisoformat # type: ignore[attr-defined]
40-
except AttributeError:
41-
from backports.datetime_fromisoformat import MonkeyPatch
42-
43-
MonkeyPatch.patch_fromisoformat()
44-
4541
try:
4642
from labelbox.data.serialization import LBV1Converter
4743
except ImportError:
@@ -643,7 +639,7 @@ def labeler_performance(self) -> PaginatedCollection:
643639
def create_labeler_performance(client, result):
644640
result["user"] = Entity.User(client, result["user"])
645641
# python isoformat doesn't accept Z as utc timezone
646-
result["lastActivityTime"] = datetime.fromisoformat(
642+
result["lastActivityTime"] = utils.format_iso_from_string(
647643
result["lastActivityTime"].replace('Z', '+00:00'))
648644
return LabelerPerformance(**{
649645
utils.snake_case(key): value for key, value in result.items()
@@ -1382,29 +1378,44 @@ def task_queues(self) -> List[TaskQueue]:
13821378
for field_values in task_queue_values
13831379
]
13841380

1381+
@overload
1382+
def move_data_rows_to_task_queue(self, data_row_ids: DataRowIdentifiers,
1383+
task_queue_id: str):
1384+
pass
1385+
1386+
@overload
13851387
def move_data_rows_to_task_queue(self, data_row_ids: List[str],
13861388
task_queue_id: str):
1389+
pass
1390+
1391+
def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str):
13871392
"""
13881393
13891394
Moves data rows to the specified task queue.
13901395
13911396
Args:
1392-
data_row_ids: a list of data row ids to be moved
1397+
data_row_ids: a list of data row ids to be moved. This can be a list of strings or a DataRowIdentifiers object
1398+
DataRowIdentifier objects are lists of ids or global keys. A DataIdentifier object can be a UniqueIds or GlobalKeys class.
13931399
task_queue_id: the task queue id to be moved to, or None to specify the "Done" queue
13941400
13951401
Returns:
13961402
None if successful, or a raised error on failure
13971403
13981404
"""
1405+
if isinstance(data_row_ids, list):
1406+
data_row_ids = UniqueIds(data_row_ids)
1407+
warnings.warn("Using data row ids will be deprecated. Please use "
1408+
"UniqueIds or GlobalKeys instead.")
1409+
13991410
method = "createBulkAddRowsToQueueTask"
14001411
query_str = """mutation AddDataRowsToTaskQueueAsyncPyApi(
14011412
$projectId: ID!
14021413
$queueId: ID
1403-
$dataRowIds: [ID!]!
1414+
$dataRowIdentifiers: AddRowsToTaskQueueViaDataRowIdentifiersInput!
14041415
) {
14051416
project(where: { id: $projectId }) {
14061417
%s(
1407-
data: { queueId: $queueId, dataRowIds: $dataRowIds }
1418+
data: { queueId: $queueId, dataRowIdentifiers: $dataRowIdentifiers }
14081419
) {
14091420
taskId
14101421
}
@@ -1416,7 +1427,10 @@ def move_data_rows_to_task_queue(self, data_row_ids: List[str],
14161427
query_str, {
14171428
"projectId": self.uid,
14181429
"queueId": task_queue_id,
1419-
"dataRowIds": data_row_ids
1430+
"dataRowIdentifiers": {
1431+
"ids": [id for id in data_row_ids],
1432+
"idType": data_row_ids._id_type,
1433+
},
14201434
},
14211435
timeout=180.0,
14221436
experimental=True)["project"][method]["taskId"]
Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import time
22

33
from labelbox import Project
4+
from labelbox.schema.identifiables import GlobalKeys, UniqueIds
45

56

67
def test_get_task_queue(project: Project):
@@ -11,22 +12,15 @@ def test_get_task_queue(project: Project):
1112
assert review_queue
1213

1314

14-
def test_move_to_task(configured_batch_project_with_label: Project):
15-
project, _, data_row, label = configured_batch_project_with_label
16-
task_queues = project.task_queues()
17-
18-
review_queue = next(
19-
tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE")
20-
project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid)
21-
15+
def _validate_moved(project, queue_name, data_row_count):
2216
timeout_seconds = 30
2317
sleep_time = 2
2418
while True:
2519
task_queues = project.task_queues()
2620
review_queue = next(
27-
tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE")
21+
tq for tq in task_queues if tq.queue_type == queue_name)
2822

29-
if review_queue.data_row_count == 1:
23+
if review_queue.data_row_count == data_row_count:
3024
break
3125

3226
if timeout_seconds <= 0:
@@ -35,3 +29,25 @@ def test_move_to_task(configured_batch_project_with_label: Project):
3529

3630
timeout_seconds -= sleep_time
3731
time.sleep(sleep_time)
32+
33+
34+
def test_move_to_task(configured_batch_project_with_label):
35+
project, _, data_row, _ = configured_batch_project_with_label
36+
task_queues = project.task_queues()
37+
38+
review_queue = next(
39+
tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE")
40+
project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid)
41+
_validate_moved(project, "MANUAL_REVIEW_QUEUE", 1)
42+
43+
review_queue = next(
44+
tq for tq in task_queues if tq.queue_type == "MANUAL_REWORK_QUEUE")
45+
project.move_data_rows_to_task_queue(GlobalKeys([data_row.global_key]),
46+
review_queue.uid)
47+
_validate_moved(project, "MANUAL_REWORK_QUEUE", 1)
48+
49+
review_queue = next(
50+
tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE")
51+
project.move_data_rows_to_task_queue(UniqueIds([data_row.uid]),
52+
review_queue.uid)
53+
_validate_moved(project, "MANUAL_REVIEW_QUEUE", 1)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from labelbox.schema.identifiables import GlobalKeys, UniqueIds
2+
3+
4+
def test_unique_ids():
5+
ids = ["a", "b", "c"]
6+
identifiables = UniqueIds(ids)
7+
assert [i for i in identifiables] == ids
8+
assert identifiables._id_type == "ID"
9+
10+
11+
def test_global_keys():
12+
ids = ["a", "b", "c"]
13+
identifiables = GlobalKeys(ids)
14+
assert [i for i in identifiables] == ids
15+
assert identifiables._id_type == "GKEY"

0 commit comments

Comments
 (0)