
Commit 0f38024

Merge pull request #755 from Labelbox/sdubinin/al-4081
[AL-4081] Wait for data rows to be processed when creating a batch

2 parents: 893ed1e + 3b26955

File tree: 6 files changed, +150 / -12 lines

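In effect, Project.create_batch now blocks until Labelbox reports every referenced data row as processed, raising the new ProcessingWaitTimeout once the wait exceeds one hour, and the returned Batch reports any rows the backend rejected. A minimal usage sketch; the API key, project ID, and data row IDs are placeholders:

    from labelbox import Client
    from labelbox.exceptions import ProcessingWaitTimeout

    client = Client(api_key="<YOUR_API_KEY>")      # placeholder credentials
    project = client.get_project("<PROJECT_ID>")   # placeholder project ID
    data_row_ids = ["<DATA_ROW_ID_1>", "<DATA_ROW_ID_2>"]

    try:
        # create_batch now waits (up to 3600 s by default) for the data rows
        # to finish processing before issuing the createBatchV2 mutation.
        batch = project.create_batch("my-batch", data_row_ids)
        # Rows the backend could not add are exposed on the new property.
        print("failed:", list(batch.failed_data_row_ids))
    except ProcessingWaitTimeout:
        # Raised when processing takes longer than _wait_processing_max_seconds.
        print("Data rows are still being processed; try again later.")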

labelbox/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -27,4 +27,4 @@
 from labelbox.schema.resource_tag import ResourceTag
 from labelbox.schema.project_resource_tag import ProjectResourceTag
 from labelbox.schema.media_type import MediaType
-from labelbox.schema.slice import Slice, CatalogSlice
+from labelbox.schema.slice import Slice, CatalogSlice

labelbox/exceptions.py

Lines changed: 5 additions & 0 deletions

@@ -129,3 +129,8 @@ class MALValidationError(LabelboxError):
 class OperationNotAllowedException(Exception):
     """Raised when user does not have permissions to a resource or has exceeded usage limit"""
     pass
+
+
+class ProcessingWaitTimeout(Exception):
+    """Raised when waiting for the data rows to be processed takes longer than allowed"""
+    pass
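Note that ProcessingWaitTimeout subclasses Exception rather than LabelboxError, so a broad except LabelboxError handler will not catch it. A small sketch of the distinction; the wrapper function is hypothetical:

    from labelbox.exceptions import LabelboxError, ProcessingWaitTimeout

    def create_batch_or_hint(project, name, data_row_ids):
        # Hypothetical helper: the two except clauses are disjoint because
        # ProcessingWaitTimeout is not part of the LabelboxError hierarchy.
        try:
            return project.create_batch(name, data_row_ids)
        except ProcessingWaitTimeout:
            raise  # data rows still processing; retry later
        except LabelboxError:
            raise  # API-level failure; never a ProcessingWaitTimeout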

labelbox/schema/batch.py

Lines changed: 11 additions & 1 deletion

@@ -37,9 +37,15 @@ class Batch(DbObject):
     # Relationships
     created_by = Relationship.ToOne("User")

-    def __init__(self, client, project_id, *args, **kwargs):
+    def __init__(self,
+                 client,
+                 project_id,
+                 *args,
+                 failed_data_row_ids=None,
+                 **kwargs):
         super().__init__(client, *args, **kwargs)
         self.project_id = project_id
+        self._failed_data_row_ids = failed_data_row_ids

     def project(self) -> 'Project':  # type: ignore
         """ Returns Project which this Batch belongs to

@@ -174,3 +180,7 @@ def delete_labels(self, set_labels_as_template=False) -> None:
             },
             experimental=True)
         return res
+
+    @property
+    def failed_data_row_ids(self):
+        return (x for x in self._failed_data_row_ids)
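Each access to failed_data_row_ids builds a fresh generator over the stored IDs, and iterating it assumes _failed_data_row_ids was populated (it defaults to None for Batch objects not constructed through create_batch). A short sketch, reusing the batch from the first example:

    # `batch` as returned by Project.create_batch in the sketch above.
    failed = list(batch.failed_data_row_ids)  # materialize the generator once
    if failed:
        print(f"{len(failed)} data rows could not be added:", failed)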

labelbox/schema/project.py

Lines changed: 57 additions & 9 deletions

@@ -4,16 +4,17 @@
 from collections import namedtuple
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, Union, Iterable, List, Optional, Any
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
 from urllib.parse import urlparse

 import ndjson
 import requests

 from labelbox import utils
-from labelbox.exceptions import InvalidQueryError, LabelboxError
+from labelbox.exceptions import (InvalidQueryError, LabelboxError,
+                                 ProcessingWaitTimeout)
 from labelbox.orm import query
-from labelbox.orm.db_object import DbObject, Updateable, Deletable
+from labelbox.orm.db_object import DbObject, Deletable, Updateable
 from labelbox.orm.model import Entity, Field, Relationship
 from labelbox.pagination import PaginatedCollection
 from labelbox.schema.consensus_settings import ConsensusSettings

@@ -90,6 +91,9 @@ class Project(DbObject, Updateable, Deletable):
     benchmarks = Relationship.ToMany("Benchmark", False)
     ontology = Relationship.ToOne("Ontology", True)

+    # Maximum time (in seconds) to wait for data rows to finish processing
+    _wait_processing_max_seconds = 3600
+
     def update(self, **kwargs):
         """ Updates this project with the specified attributes

@@ -319,7 +323,7 @@ def _validate_datetime(string_date: str) -> bool:
                 return True
             except ValueError:
                 pass
-        raise ValueError(f"""Incorrect format for: {string_date}.
+        raise ValueError(f"""Incorrect format for: {string_date}.
         Format must be \"YYYY-MM-DD\" or \"YYYY-MM-DD hh:mm:ss\"""")
     return True

@@ -595,11 +599,16 @@ def create_batch(self,
         if not len(dr_ids):
             raise ValueError("You need at least one data row in a batch")

-        method = 'createBatch'
+        self._wait_until_data_rows_are_processed(
+            data_rows, self._wait_processing_max_seconds)
+        method = 'createBatchV2'
         query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) {
                   project(where: {id: $projectId}) {
                     %s(input: $batchInput) {
-                      %s
+                      batch {
+                        %s
+                      }
+                      failedDataRowIds
                     }
                   }
                 }

@@ -622,9 +631,12 @@ def create_batch(self,
                                       params,
                                       timeout=180.0,
                                       experimental=True)["project"][method]
-
-        res['size'] = len(dr_ids)
-        return Entity.Batch(self.client, self.uid, res)
+        batch = res['batch']
+        batch['size'] = len(dr_ids)
+        return Entity.Batch(self.client,
+                            self.uid,
+                            batch,
+                            failed_data_row_ids=res['failedDataRowIds'])

     def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode":
         """

@@ -977,6 +989,42 @@ def _is_url_valid(url: Union[str, Path]) -> bool:
         raise ValueError(
             f'Invalid annotations given of type: {type(annotations)}')

+    def _wait_until_data_rows_are_processed(self,
+                                            data_row_ids: List[str],
+                                            wait_processing_max_seconds: int,
+                                            sleep_interval=30):
+        """ Wait until all the specified data rows are processed"""
+        start_time = datetime.now()
+        while True:
+            if (datetime.now() -
+                    start_time).total_seconds() >= wait_processing_max_seconds:
+                raise ProcessingWaitTimeout(
+                    "Maximum wait time exceeded while waiting for data rows to be processed. Try creating a batch a bit later"
+                )
+
+            all_good = self.__check_data_rows_have_been_processed(data_row_ids)
+            if all_good:
+                return
+
+            logger.debug(
+                'Some of the data rows are still being processed, waiting...')
+            time.sleep(sleep_interval)
+
+    def __check_data_rows_have_been_processed(self, data_row_ids: List[str]):
+        data_row_ids_param = "data_row_ids"
+
+        query_str = """query CheckAllDataRowsHaveBeenProcessedPyApi($%s: [ID!]!) {
+            queryAllDataRowsHaveBeenProcessed(dataRowIds:$%s) {
+                allDataRowsHaveBeenProcessed
+            }
+        }""" % (data_row_ids_param, data_row_ids_param)
+
+        params = {}
+        params[data_row_ids_param] = data_row_ids
+        response = self.client.execute(query_str, params)
+        return response["queryAllDataRowsHaveBeenProcessed"][
+            "allDataRowsHaveBeenProcessed"]


 class ProjectMember(DbObject):
     user = Relationship.ToOne("User", cache=True)

tests/integration/conftest.py

Lines changed: 7 additions & 0 deletions

@@ -191,6 +191,13 @@ def dataset(client, rand_gen):
     dataset.delete()


+@pytest.fixture(scope='function')
+def unique_dataset(client, rand_gen):
+    dataset = client.create_dataset(name=rand_gen(str))
+    yield dataset
+    dataset.delete()
+
+
 @pytest.fixture
 def datarow(dataset, image_url):
     task = dataset.create_data_rows([

tests/integration/test_batch.py

Lines changed: 69 additions & 1 deletion

@@ -1,5 +1,5 @@
+from labelbox.exceptions import ProcessingWaitTimeout
 import pytest
-
 from labelbox import Dataset, Project

 IMAGE_URL = "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000000034.jpg"

@@ -31,6 +31,23 @@ def small_dataset(dataset: Dataset):
     yield dataset


+@pytest.fixture(scope='function')
+def dataset_with_invalid_data_rows(unique_dataset: Dataset):
+    upload_invalid_data_rows_for_dataset(unique_dataset)
+
+    yield unique_dataset
+
+
+def upload_invalid_data_rows_for_dataset(dataset: Dataset):
+    task = dataset.create_data_rows([
+        {
+            "row_data": 'gs://lb-test-private/mask-2.png',  # forbidden
+            "external_id": "image-without-access.jpg"
+        },
+    ] * 2)
+    task.wait_till_done()
+
+
 def test_create_batch(batch_project: Project, big_dataset: Dataset):
     data_rows = [dr.uid for dr in list(big_dataset.export_data_rows())]
     batch = batch_project.create_batch("test-batch", data_rows, 3)

@@ -72,12 +89,63 @@ def test_batch_project(batch_project: Project, small_dataset: Dataset):
     data_rows = [dr.uid for dr in list(small_dataset.export_data_rows())]
     batch = batch_project.create_batch("batch to test project relationship",
                                        data_rows)
+
     project_from_batch = batch.project()

     assert project_from_batch.uid == batch_project.uid
     assert project_from_batch.name == batch_project.name


+def test_batch_creation_for_data_rows_with_issues(
+        batch_project: Project, small_dataset: Dataset,
+        dataset_with_invalid_data_rows: Dataset):
+    """
+    Create a batch containing both valid and invalid data rows
+    """
+    valid_data_rows = [dr.uid for dr in list(small_dataset.data_rows())]
+    invalid_data_rows = [
+        dr.uid for dr in list(dataset_with_invalid_data_rows.data_rows())
+    ]
+    data_rows_to_add = valid_data_rows + invalid_data_rows
+
+    assert len(data_rows_to_add) == 5
+    batch = batch_project.create_batch("batch to test failed data rows",
+                                       data_rows_to_add)
+    failed_data_row_ids = [x for x in batch.failed_data_row_ids]
+    assert len(failed_data_row_ids) == 2
+
+    failed_data_row_ids_set = set(failed_data_row_ids)
+    invalid_data_rows_set = set(invalid_data_rows)
+    assert len(failed_data_row_ids_set.intersection(invalid_data_rows_set)) == 2
+
+
+def test_batch_creation_with_processing_timeout(batch_project: Project,
+                                                small_dataset: Dataset,
+                                                unique_dataset: Dataset):
+    """
+    Create a batch with zero wait time, so the waiting logic raises immediately
+    """
+    # wait for these data rows to be processed
+    valid_data_rows = [dr.uid for dr in list(small_dataset.data_rows())]
+    batch_project._wait_until_data_rows_are_processed(
+        valid_data_rows, wait_processing_max_seconds=3600, sleep_interval=5)
+
+    # upload data rows for this dataset and don't wait
+    upload_invalid_data_rows_for_dataset(unique_dataset)
+    unprocessed_data_rows = [dr.uid for dr in list(unique_dataset.data_rows())]
+
+    data_row_ids = valid_data_rows + unprocessed_data_rows
+
+    stashed_wait_timeout = batch_project._wait_processing_max_seconds
+    with pytest.raises(ProcessingWaitTimeout):
+        # emulate the situation where there are still some data rows being
+        # processed but the wait timeout is exceeded
+        batch_project._wait_processing_max_seconds = 0
+        batch_project.create_batch("batch to test failed data rows",
+                                   data_row_ids)
+    batch_project._wait_processing_max_seconds = stashed_wait_timeout


 def test_export_data_rows(batch_project: Project, dataset: Dataset):
     n_data_rows = 5
     task = dataset.create_data_rows([
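A side note on test_batch_creation_with_processing_timeout: the stashed timeout is restored by a plain assignment after the with block, so a failure inside the block would leak the zeroed value into later tests sharing the fixture. A sketch of the same scenario using pytest's monkeypatch, which undoes the change automatically; this is an alternative, not the committed test:

    def test_batch_creation_with_processing_timeout_alt(batch_project: Project,
                                                        small_dataset: Dataset,
                                                        unique_dataset: Dataset,
                                                        monkeypatch):
        # monkeypatch restores _wait_processing_max_seconds even if the test
        # fails before reaching a manual cleanup line.
        valid_data_rows = [dr.uid for dr in small_dataset.data_rows()]
        upload_invalid_data_rows_for_dataset(unique_dataset)
        unprocessed = [dr.uid for dr in unique_dataset.data_rows()]

        monkeypatch.setattr(batch_project, "_wait_processing_max_seconds", 0)
        with pytest.raises(ProcessingWaitTimeout):
            batch_project.create_batch("batch to test failed data rows",
                                       valid_data_rows + unprocessed)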
