Commit 1823ad6

AL-4081: Extended create_batch method with DRPS logic
1 parent 70484dd commit 1823ad6

7 files changed: +145 −20 lines changed

labelbox/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -27,4 +27,4 @@
 from labelbox.schema.resource_tag import ResourceTag
 from labelbox.schema.project_resource_tag import ProjectResourceTag
 from labelbox.schema.media_type import MediaType
-from labelbox.schema.slice import Slice, CatalogSlice
+from labelbox.schema.slice import Slice, CatalogSlice

labelbox/client.py

Lines changed: 2 additions & 2 deletions
@@ -751,7 +751,7 @@ def get_data_row_ids_for_external_ids(
         for row in self.execute(
                 query_str,
                 {'externalId_in': external_ids[i:i + max_ids_per_request]
-                })['externalIdsToDataRowIds']:
+                })['externalIdsToDataRowIds']:
             result[row['externalId']].append(row['dataRowId'])
         return result

@@ -1058,7 +1058,7 @@ def _format_failed_rows(rows: Dict[str, str],
         result_params = {
             "jobId":
                 assign_global_keys_to_data_rows_job["assignGlobalKeysToDataRows"
-                ]["jobId"]
+                ]["jobId"]
         }

         # Poll job status until finished, then retrieve results

labelbox/schema/batch.py

Lines changed: 10 additions & 5 deletions
@@ -36,9 +36,10 @@ class Batch(DbObject):
     # Relationships
     created_by = Relationship.ToOne("User")

-    def __init__(self, client, project_id, *args, **kwargs):
+    def __init__(self, client, project_id, *args, failed_data_row_ids=None, **kwargs):
         super().__init__(client, *args, **kwargs)
         self.project_id = project_id
+        self._failed_data_row_ids = failed_data_row_ids

     def project(self) -> 'Project':  # type: ignore
         """ Returns Project which this Batch belongs to

@@ -75,7 +76,7 @@ def remove_queued_data_rows(self) -> None:
                 batch_id_param), {
                     project_id_param: self.project_id,
                     batch_id_param: self.uid
-                },
+                },
             experimental=True)

     def export_data_rows(self,

@@ -144,8 +145,8 @@ def delete(self) -> None:
                 batch_id_param), {
                     project_id_param: self.project_id,
                     batch_id_param: self.uid
-                },
-            experimental=True)
+                },
+            experimental=True)

     def delete_labels(self, set_labels_as_template=False) -> None:
         """ Deletes labels that were created for data rows in the batch.

@@ -170,6 +171,10 @@ def delete_labels(self, set_labels_as_template=False) -> None:
                 type_param:
                     "RequeueDataWithLabelAsTemplate"
                     if set_labels_as_template else "RequeueData"
-                },
+                },
             experimental=True)
         return res
+
+    @property
+    def failed_data_row_ids(self):
+        return self._failed_data_row_ids
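
A `Batch` is normally constructed by the SDK from a GraphQL response, so the new keyword argument is mostly internal plumbing. A minimal sketch of the intended pattern, with hypothetical client, IDs, and field values:

    # Hypothetical sketch: the SDK builds a Batch from a response dict and
    # threads the server-reported failures through the new keyword argument.
    batch = Batch(client, "cl0project123",
                  {"id": "cl0batch456", "name": "my batch", "size": 5},
                  failed_data_row_ids=["cl0datarow789"])

    # The property added above exposes the failures read-only.
    print(batch.failed_data_row_ids)  # ['cl0datarow789']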

labelbox/schema/project.py

Lines changed: 44 additions & 10 deletions
@@ -4,16 +4,15 @@
 from collections import namedtuple
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, Union, Iterable, List, Optional, Any
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
 from urllib.parse import urlparse

 import ndjson
 import requests
-
 from labelbox import utils
 from labelbox.exceptions import InvalidQueryError, LabelboxError
 from labelbox.orm import query
-from labelbox.orm.db_object import DbObject, Updateable, Deletable
+from labelbox.orm.db_object import DbObject, Deletable, Updateable
 from labelbox.orm.model import Entity, Field, Relationship
 from labelbox.pagination import PaginatedCollection
 from labelbox.schema.media_type import MediaType

@@ -318,7 +317,7 @@ def _validate_datetime(string_date: str) -> bool:
                 return True
             except ValueError:
                 pass
-        raise ValueError(f"""Incorrect format for: {string_date}.
+        raise ValueError(f"""Incorrect format for: {string_date}.
         Format must be \"YYYY-MM-DD\" or \"YYYY-MM-DD hh:mm:ss\"""")
     return True

@@ -561,7 +560,7 @@ def setup(self, labeling_frontend, labeling_frontend_options) -> None:
         timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
         self.update(setup_complete=timestamp)

-    def create_batch(self, name: str, data_rows: List[str], priority: int = 5):
+    def create_batch(self, name: str, data_rows: List[str], priority: int = 5, wait_processing_max_seconds: int = 5):
         """Create a new batch for a project. Batches is in Beta and subject to change

         Args:

@@ -590,11 +589,18 @@ def create_batch(self, name: str, data_rows: List[str], priority: int = 5):
         if not len(dr_ids):
             raise ValueError("You need at least one data row in a batch")

-        method = 'createBatch'
+        self._wait_until_data_rows_are_processed(
+            data_rows,
+            wait_processing_max_seconds=wait_processing_max_seconds
+        )
+        method = 'createBatchV2'
         query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) {
                   project(where: {id: $projectId}) {
                     %s(input: $batchInput) {
-                      %s
+                      batch {
+                        %s
+                      }
+                      failedDataRowIds
                     }
                   }
                 }

@@ -613,9 +619,9 @@ def create_batch(self, name: str, data_rows: List[str], priority: int = 5):
             params,
             timeout=180.0,
             experimental=True)["project"][method]
-
-        res['size'] = len(dr_ids)
-        return Entity.Batch(self.client, self.uid, res)
+        batch = res['batch']
+        batch['size'] = len(dr_ids)
+        return Entity.Batch(self.client, self.uid, batch, failed_data_row_ids=res['failedDataRowIds'])

     def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode":
         """

@@ -964,6 +970,34 @@ def _is_url_valid(url: Union[str, Path]) -> bool:
         raise ValueError(
             f'Invalid annotations given of type: {type(annotations)}')

+    def _wait_until_data_rows_are_processed(self, data_row_ids: List[str], wait_processing_max_seconds: int, sleep_interval=30):
+        """ Wait until all the specified data rows are processed"""
+        start_time = datetime.now()
+        while True:
+            if (datetime.now() - start_time).total_seconds() >= wait_processing_max_seconds:
+                logger.warning(
+                    """Not all data rows have been processed, proceeding anyway""")
+                return
+
+            all_good = self.__check_data_rows_have_been_processed(data_row_ids)
+            if all_good:
+                return
+            time.sleep(sleep_interval)
+
+    def __check_data_rows_have_been_processed(self, data_row_ids: List[str]):
+        data_row_ids_param = "data_row_ids"
+
+        query_str = """query CheckAllDataRowsHaveBeenProcessedPyApi($%s: [ID!]!) {
+            queryAllDataRowsHaveBeenProcessed(dataRowIds: $%s) {
+                allDataRowsHaveBeenProcessed
+            }
+        }""" % (data_row_ids_param, data_row_ids_param)
+
+        params = {}
+        params[data_row_ids_param] = data_row_ids
+        response = self.client.execute(query_str, params)
+        return response["queryAllDataRowsHaveBeenProcessed"]["allDataRowsHaveBeenProcessed"]


 class ProjectMember(DbObject):
     user = Relationship.ToOne("User", cache=True)
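
Taken together, `create_batch` now waits (up to `wait_processing_max_seconds`) for data row processing before issuing the `createBatchV2` mutation, and rows the server rejects are reported on the batch instead of raising. A short usage sketch; the API key, project ID, and data row IDs below are hypothetical:

    from labelbox import Client

    client = Client(api_key="<YOUR_API_KEY>")
    project = client.get_project("cl0project123")  # hypothetical project ID

    # Give processing up to 60 seconds before creating the batch; on timeout
    # a warning is logged and batch creation proceeds anyway.
    batch = project.create_batch("my-batch",
                                 data_rows=["cl0datarow1", "cl0datarow2"],
                                 priority=3,
                                 wait_processing_max_seconds=60)

    # Rows the server could not add are reported rather than raised.
    print(batch.failed_data_row_ids)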

pytest.ini

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 [pytest]
-addopts = -s -vv -x --reruns 5 --reruns-delay 10 --durations=20
+addopts = -s -vv -x
 markers =
     slow: marks tests as slow (deselect with '-m "not slow"')

tests/integration/conftest.py

Lines changed: 6 additions & 0 deletions
@@ -188,6 +188,12 @@ def dataset(client, rand_gen):
     yield dataset
     dataset.delete()

+@pytest.fixture(scope='function')
+def unique_dataset(client, rand_gen):
+    dataset = client.create_dataset(name=rand_gen(str))
+    yield dataset
+    dataset.delete()
+

 @pytest.fixture
 def datarow(dataset, image_url):

tests/integration/test_batch.py

Lines changed: 81 additions & 1 deletion
@@ -1,5 +1,6 @@
-import pytest
+import warnings

+import pytest
 from labelbox import Dataset, Project
 from labelbox.schema.queue_mode import QueueMode

@@ -32,6 +33,23 @@ def small_dataset(dataset: Dataset):
     yield dataset


+@pytest.fixture(scope='function')
+def dataset_with_invalid_data_rows(unique_dataset: Dataset):
+    upload_invalid_data_rows_for_dataset(unique_dataset)
+
+    yield unique_dataset
+
+
+def upload_invalid_data_rows_for_dataset(dataset: Dataset):
+    task = dataset.create_data_rows([
+        {
+            "row_data": 'https://jakub-da-test-primary.s3.us-east-2.amazonaws.com/dogecoin-whitepaper.pdf',
+            "external_id": "my-pdf"
+        },
+    ] * 2)
+    task.wait_till_done()
+
+
 def test_create_batch(batch_project: Project, big_dataset: Dataset):
     data_rows = [dr.uid for dr in list(big_dataset.export_data_rows())]
     batch = batch_project.create_batch("test-batch", data_rows, 3)

@@ -60,12 +78,74 @@ def test_batch_project(batch_project: Project, small_dataset: Dataset):
     data_rows = [dr.uid for dr in list(small_dataset.export_data_rows())]
     batch = batch_project.create_batch("batch to test project relationship",
                                        data_rows)
+
     project_from_batch = batch.project()

     assert project_from_batch.uid == batch_project.uid
     assert project_from_batch.name == batch_project.name


+def test_batch_creation_for_data_rows_with_issues(
+        batch_project: Project,
+        small_dataset: Dataset,
+        dataset_with_invalid_data_rows: Dataset
+):
+    """
+    Create a batch containing both valid and invalid data rows
+    """
+    valid_data_rows = [dr.uid for dr in list(small_dataset.export_data_rows())]
+    invalid_data_rows = [dr.uid for dr in list(
+        dataset_with_invalid_data_rows.export_data_rows())]
+    data_rows_to_add = valid_data_rows + invalid_data_rows
+
+    assert len(data_rows_to_add) == 5
+    batch = batch_project.create_batch(
+        "batch to test failed data rows",
+        data_rows_to_add
+    )
+
+    assert len(batch.failed_data_row_ids) == 2
+
+    failed_data_row_ids_set = set(batch.failed_data_row_ids)
+    invalid_data_rows_set = set(invalid_data_rows)
+    assert len(failed_data_row_ids_set.intersection(
+        invalid_data_rows_set)) == 2
+
+
+def test_batch_creation_with_processing_timeout(
+        batch_project: Project,
+        small_dataset: Dataset,
+        unique_dataset: Dataset
+):
+    """
+    Create a batch with zero wait time; this means the waiting will
+    terminate instantly
+    """
+    # wait for these data rows to be processed
+    valid_data_rows = [dr.uid for dr in list(small_dataset.export_data_rows())]
+    batch_project._wait_until_data_rows_are_processed(
+        valid_data_rows, wait_processing_max_seconds=3600, sleep_interval=5
+    )
+
+    # upload data rows for this dataset and don't wait for processing
+    upload_invalid_data_rows_for_dataset(unique_dataset)
+    unprocessed_data_rows = [dr.uid for dr in list(
+        unique_dataset.export_data_rows())]
+
+    data_row_ids = valid_data_rows + unprocessed_data_rows
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        batch_project.create_batch(
+            "batch to test failed data rows",
+            data_row_ids,
+            wait_processing_max_seconds=0
+        )
+    assert len(w) == 1
+    assert issubclass(w[-1].category, DeprecationWarning)
+    assert "Not all data rows have been processed, proceeding anyway" in str(
+        w[-1].message)


 def test_export_data_rows(batch_project: Project, dataset: Dataset):
     n_data_rows = 5
     task = dataset.create_data_rows([
