Skip to content

Commit 86112a4

Browse files
author
Richard Sun
committed
[QQC-1484] Support move to task action
1 parent 8113a8c commit 86112a4

File tree

6 files changed

+245
-44
lines changed

6 files changed

+245
-44
lines changed

labelbox/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,4 @@
2929
from labelbox.schema.media_type import MediaType
3030
from labelbox.schema.slice import Slice, CatalogSlice
3131
from labelbox.schema.queue_mode import QueueMode
32+
from labelbox.schema.task_queue import TaskQueue

labelbox/orm/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ class Entity(metaclass=EntityMeta):
378378
Project: Type[labelbox.Project]
379379
Batch: Type[labelbox.Batch]
380380
CatalogSlice: Type[labelbox.CatalogSlice]
381+
TaskQueue: Type[labelbox.TaskQueue]
381382

382383
@classmethod
383384
def _attributes_of_type(cls, attr_type):

labelbox/schema/project.py

Lines changed: 118 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from labelbox.schema.resource_tag import ResourceTag
2727
from labelbox.schema.task import Task
2828
from labelbox.schema.user import User
29+
from labelbox.schema.task_queue import TaskQueue
2930

3031
if TYPE_CHECKING:
3132
from labelbox import BulkImportRequest
@@ -69,6 +70,7 @@ class Project(DbObject, Updateable, Deletable):
6970
webhooks (Relationship): `ToMany` relationship to Webhook
7071
benchmarks (Relationship): `ToMany` relationship to Benchmark
7172
ontology (Relationship): `ToOne` relationship to Ontology
73+
task_queues (Relationship): `ToMany` relationship to TaskQueue
7274
"""
7375

7476
name = Field.String("name")
@@ -794,54 +796,33 @@ def _create_batch_async(self,
794796

795797
task_id = res['taskId']
796798

797-
timeout_seconds = 600
798-
sleep_time = 2
799-
get_task_query_str = """query %s($taskId: ID!) {
800-
task(where: {id: $taskId}) {
801-
status
799+
status = self._wait_for_task(task_id)
800+
if status != "COMPLETE":
801+
raise LabelboxError(f"Batch was not created successfully.")
802+
803+
# obtain batch entity to return
804+
get_batch_str = """query %s($projectId: ID!, $batchId: ID!) {
805+
project(where: {id: $projectId}) {
806+
batches(where: {id: $batchId}) {
807+
nodes {
808+
%s
809+
}
810+
}
802811
}
803812
}
804-
""" % "getTaskPyApi"
813+
""" % ("getProjectBatchPyApi",
814+
query.results_query_part(Entity.Batch))
805815

806-
while True:
807-
task_status = self.client.execute(
808-
get_task_query_str, {'taskId': task_id},
809-
experimental=True)['task']['status']
816+
batch = self.client.execute(
817+
get_batch_str, {
818+
"projectId": self.uid,
819+
"batchId": batch_id
820+
},
821+
timeout=180.0,
822+
experimental=True)["project"]["batches"]["nodes"][0]
810823

811-
if task_status == "COMPLETE":
812-
# obtain batch entity to return
813-
get_batch_str = """query %s($projectId: ID!, $batchId: ID!) {
814-
project(where: {id: $projectId}) {
815-
batches(where: {id: $batchId}) {
816-
nodes {
817-
%s
818-
}
819-
}
820-
}
821-
}
822-
""" % ("getProjectBatchPyApi",
823-
query.results_query_part(Entity.Batch))
824-
825-
batch = self.client.execute(
826-
get_batch_str, {
827-
"projectId": self.uid,
828-
"batchId": batch_id
829-
},
830-
timeout=180.0,
831-
experimental=True)["project"]["batches"]["nodes"][0]
832-
833-
# TODO async endpoints currently do not provide failed_data_row_ids in response
834-
return Entity.Batch(self.client, self.uid, batch)
835-
elif task_status == "IN_PROGRESS":
836-
timeout_seconds -= sleep_time
837-
if timeout_seconds <= 0:
838-
raise LabelboxError(
839-
f"Timed out while waiting for batch to be created.")
840-
logger.debug("Creating batch, waiting for server...", self.uid)
841-
time.sleep(sleep_time)
842-
continue
843-
else:
844-
raise LabelboxError(f"Batch was not created successfully.")
824+
# TODO async endpoints currently do not provide failed_data_row_ids in response
825+
return Entity.Batch(self.client, self.uid, batch)
845826

846827
def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode":
847828
"""
@@ -1127,6 +1108,99 @@ def batches(self) -> PaginatedCollection:
11271108
cursor_path=['project', 'batches', 'pageInfo', 'endCursor'],
11281109
experimental=True)
11291110

1111+
def task_queues(self) -> List[TaskQueue]:
1112+
""" Fetch all task queues that belong to this project
1113+
1114+
Returns:
1115+
A `List of `TaskQueue`s
1116+
"""
1117+
query_str = """query GetProjectTaskQueuesPyApi($projectId: ID!) {
1118+
project(where: {id: $projectId}) {
1119+
taskQueues {
1120+
%s
1121+
}
1122+
}
1123+
}
1124+
""" % (query.results_query_part(Entity.TaskQueue))
1125+
1126+
task_queue_values = self.client.execute(
1127+
query_str, {"projectId": self.uid},
1128+
timeout=180.0,
1129+
experimental=True)["project"]["taskQueues"]
1130+
1131+
return [
1132+
Entity.TaskQueue(self.client, field_values)
1133+
for field_values in task_queue_values
1134+
]
1135+
1136+
def move_data_rows_to_task(self, data_row_ids: List[str],
1137+
task_queue_id: str):
1138+
"""
1139+
1140+
Moves data rows to the specified task queue.
1141+
1142+
Args:
1143+
data_row_ids: a list of data row ids to be moved
1144+
task_queue_id: the task queue id to be moved to, or None to specify the "Done" queue
1145+
1146+
Returns:
1147+
None if successful, or a raised error on failure
1148+
1149+
"""
1150+
method = "createBulkAddRowsToQueueTask"
1151+
query_str = """mutation AddDataRowsToTaskQueueAsyncPyApi(
1152+
$projectId: ID!
1153+
$queueId: ID
1154+
$dataRowIds: [ID!]!
1155+
) {
1156+
project(where: { id: $projectId }) {
1157+
%s(
1158+
data: { queueId: $queueId, dataRowIds: $dataRowIds }
1159+
) {
1160+
taskId
1161+
}
1162+
}
1163+
}
1164+
""" % method
1165+
1166+
task_id = self.client.execute(
1167+
query_str, {
1168+
"projectId": self.uid,
1169+
"queueId": task_queue_id,
1170+
"dataRowIds": data_row_ids
1171+
},
1172+
timeout=180.0,
1173+
experimental=True)["project"][method]["taskId"]
1174+
1175+
status = self._wait_for_task(task_id)
1176+
if status != "COMPLETE":
1177+
raise LabelboxError(f"Data rows were not moved successfully")
1178+
1179+
def _wait_for_task(self, task_id: str):
1180+
timeout_seconds = 600
1181+
sleep_time = 2
1182+
get_task_query_str = """query %s($taskId: ID!) {
1183+
task(where: {id: $taskId}) {
1184+
status
1185+
}
1186+
}
1187+
""" % "getTaskPyApi"
1188+
1189+
while True:
1190+
task_status = self.client.execute(
1191+
get_task_query_str, {'taskId': task_id},
1192+
experimental=True)['task']['status']
1193+
1194+
if task_status == "IN_PROGRESS":
1195+
timeout_seconds -= sleep_time
1196+
if timeout_seconds <= 0:
1197+
raise LabelboxError(
1198+
f"Timed out while waiting for task to be completed.")
1199+
time.sleep(sleep_time)
1200+
continue
1201+
1202+
return task_status
1203+
11301204
def upload_annotations(
11311205
self,
11321206
name: str,

labelbox/schema/task_queue.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from labelbox.orm.db_object import DbObject
2+
from labelbox.orm.model import Field
3+
4+
5+
class TaskQueue(DbObject):
6+
"""
7+
a task queue
8+
9+
Attributes
10+
name
11+
description
12+
queue_type
13+
data_row_count
14+
15+
Relationships
16+
project
17+
organization
18+
pass_queue
19+
fail_queue
20+
"""
21+
22+
name = Field.String("name")
23+
description = Field.String("description")
24+
queue_type = Field.String("queue_type")
25+
data_row_count = Field.Int("data_row_count")
26+
27+
def __init__(self, client, *args, **kwargs):
28+
super().__init__(client, *args, **kwargs)

tests/integration/conftest.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,66 @@ def create_label():
359359
label.delete()
360360

361361

362+
@pytest.fixture
363+
def configured_batch_project_with_label(client, rand_gen, image_url,
364+
batch_project, dataset, datarow,
365+
wait_for_label_processing):
366+
"""Project with a batch having one datarow
367+
Project contains an ontology with 1 bbox tool
368+
Additionally includes a create_label method for any needed extra labels
369+
One label is already created and yielded when using fixture
370+
"""
371+
data_rows = [dr.uid for dr in list(dataset.data_rows())]
372+
batch_project.create_batch("test-batch", data_rows)
373+
editor = list(
374+
batch_project.client.get_labeling_frontends(
375+
where=LabelingFrontend.name == "editor"))[0]
376+
377+
ontology_builder = OntologyBuilder(tools=[
378+
Tool(tool=Tool.Type.BBOX, name="test-bbox-class"),
379+
])
380+
batch_project.setup(editor, ontology_builder.asdict())
381+
# TODO: ontology may not be synchronous after setup. remove sleep when api is more consistent
382+
time.sleep(2)
383+
384+
ontology = ontology_builder.from_project(batch_project)
385+
predictions = [{
386+
"uuid": str(uuid.uuid4()),
387+
"schemaId": ontology.tools[0].feature_schema_id,
388+
"dataRow": {
389+
"id": datarow.uid
390+
},
391+
"bbox": {
392+
"top": 20,
393+
"left": 20,
394+
"height": 50,
395+
"width": 50
396+
}
397+
}]
398+
399+
def create_label():
400+
""" Ad-hoc function to create a LabelImport
401+
Creates a LabelImport task which will create a label
402+
"""
403+
upload_task = LabelImport.create_from_objects(
404+
client, batch_project.uid, f'label-import-{uuid.uuid4()}',
405+
predictions)
406+
upload_task.wait_until_done(sleep_time_seconds=5)
407+
assert upload_task.state == AnnotationImportState.FINISHED, "Label Import did not finish"
408+
assert len(
409+
upload_task.errors
410+
) == 0, f"Label Import {upload_task.name} failed with errors {upload_task.errors}"
411+
412+
batch_project.create_label = create_label
413+
batch_project.create_label()
414+
label = wait_for_label_processing(batch_project)[0]
415+
416+
yield [batch_project, dataset, datarow, label]
417+
418+
for label in batch_project.labels():
419+
label.delete()
420+
421+
362422
@pytest.fixture
363423
def configured_project_with_complex_ontology(client, rand_gen, image_url):
364424
project = client.create_project(name=rand_gen(str),
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import time
2+
3+
from labelbox import Project
4+
5+
6+
def test_get_task_queue(batch_project: Project):
7+
task_queues = batch_project.task_queues()
8+
assert len(task_queues) == 3
9+
review_queue = next(
10+
tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE")
11+
assert review_queue
12+
13+
14+
def test_move_to_task(configured_batch_project_with_label: Project):
15+
project, _, data_row, label = configured_batch_project_with_label
16+
task_queues = project.task_queues()
17+
18+
review_queue = next(
19+
tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE")
20+
project.move_data_rows_to_task([data_row.uid], review_queue.uid)
21+
22+
timeout_seconds = 30
23+
sleep_time = 2
24+
while True:
25+
task_queues = project.task_queues()
26+
review_queue = next(
27+
tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE")
28+
29+
if review_queue.data_row_count == 1:
30+
break
31+
32+
if timeout_seconds <= 0:
33+
raise AssertionError(
34+
"Timed out expecting data_row_count of 1 in the review queue")
35+
36+
timeout_seconds -= sleep_time
37+
time.sleep(sleep_time)

0 commit comments

Comments
 (0)