Commit a40ab66

[AL-4867] Create batch using global keys
1 parent 6816ffa commit a40ab66

File tree

5 files changed: +96, -45 lines


.github/workflows/python-package.yml

Lines changed: 1 addition & 1 deletion

@@ -89,4 +89,4 @@ jobs:
           DA_GCP_LABELBOX_API_KEY: ${{ secrets[matrix.da-test-key] }}
         run: |
-          tox -e py -- -svv --reruns 5 --reruns-delay 10
+          tox -e py -- -n 10 -svv --reruns 5 --reruns-delay 10

labelbox/schema/project.py

Lines changed: 62 additions & 33 deletions

@@ -669,16 +669,20 @@ def setup(self, labeling_frontend, labeling_frontend_options) -> None:
         timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
         self.update(setup_complete=timestamp)
 
-    def create_batch(self,
-                     name: str,
-                     data_rows: List[Union[str, DataRow]],
-                     priority: int = 5,
-                     consensus_settings: Optional[Dict[str, float]] = None):
-        """Create a new batch for a project. Batches is in Beta and subject to change
+    def create_batch(
+        self,
+        name: str,
+        data_rows: Optional[List[Union[str, DataRow]]] = None,
+        priority: int = 5,
+        consensus_settings: Optional[Dict[str, float]] = None,
+        global_keys: Optional[List[str]] = None,
+    ):
+        """Create a new batch for a project. One of `global_keys` or `data_rows` must be provided, but not both.
 
         Args:
             name: a name for the batch, must be unique within a project
-            data_rows: Either a list of `DataRows` or Data Row ids
+            data_rows: Either a list of `DataRows` or Data Row ids.
+            global_keys: global keys for data rows to add to the batch.
             priority: An optional priority for the Data Rows in the Batch. 1 highest -> 5 lowest
             consensus_settings: An optional dictionary with consensus settings: {'number_of_labels': 3, 'coverage_percentage': 0.1}
         """
@@ -688,35 +692,45 @@ def create_batch(self,
             raise ValueError("Project must be in batch mode")
 
         dr_ids = []
-        for dr in data_rows:
-            if isinstance(dr, Entity.DataRow):
-                dr_ids.append(dr.uid)
-            elif isinstance(dr, str):
-                dr_ids.append(dr)
-            else:
-                raise ValueError("You can DataRow ids or DataRow objects")
+        if data_rows is not None:
+            for dr in data_rows:
+                if isinstance(dr, Entity.DataRow):
+                    dr_ids.append(dr.uid)
+                elif isinstance(dr, str):
+                    dr_ids.append(dr)
+                else:
+                    raise ValueError(
+                        "`data_rows` must be DataRow ids or DataRow objects")
+
+        if data_rows is not None:
+            row_count = len(data_rows)
+        elif global_keys is not None:
+            row_count = len(global_keys)
+        else:
+            row_count = 0
 
-        if len(dr_ids) > 100_000:
+        if row_count > 100_000:
             raise ValueError(
                 f"Batch exceeds max size, break into smaller batches")
-        if not len(dr_ids):
+        if not row_count:
             raise ValueError("You need at least one data row in a batch")
 
         self._wait_until_data_rows_are_processed(
-            dr_ids, self._wait_processing_max_seconds)
+            dr_ids, global_keys, self._wait_processing_max_seconds)
 
         if consensus_settings:
             consensus_settings = ConsensusSettings(**consensus_settings).dict(
                 by_alias=True)
 
         if len(dr_ids) >= 10_000:
-            return self._create_batch_async(name, dr_ids, priority,
+            return self._create_batch_async(name, dr_ids, global_keys, priority,
                                             consensus_settings)
         else:
-            return self._create_batch_sync(name, dr_ids, priority,
+            return self._create_batch_sync(name, dr_ids, global_keys, priority,
                                            consensus_settings)
 
-    def _create_batch_sync(self, name, dr_ids, priority, consensus_settings):
+    def _create_batch_sync(self, name, dr_ids, global_keys, priority,
+                           consensus_settings):
         method = 'createBatchV2'
         query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) {
                   project(where: {id: $projectId}) {
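
A quick sketch of how the validation in the hunk above behaves when neither argument is supplied (illustrative only, same hypothetical `project` handle as before):

    # Sketch only: with neither data_rows nor global_keys, row_count stays 0
    # and the guard raises before any request is made.
    try:
        project.create_batch("empty-batch")
    except ValueError as err:
        print(err)  # -> "You need at least one data row in a batch"
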
@@ -734,6 +748,7 @@ def _create_batch_sync(self, name, dr_ids, priority, consensus_settings):
             "batchInput": {
                 "name": name,
                 "dataRowIds": dr_ids,
+                "globalKeys": global_keys,
                 "priority": priority,
                 "consensusSettings": consensus_settings
             }
@@ -751,7 +766,8 @@ def _create_batch_sync(self, name, dr_ids, priority, consensus_settings):
 
     def _create_batch_async(self,
                             name: str,
-                            dr_ids: List[str],
+                            dr_ids: Optional[List[str]] = None,
+                            global_keys: Optional[List[str]] = None,
                             priority: int = 5,
                             consensus_settings: Optional[Dict[str,
                                                               float]] = None):
@@ -794,6 +810,7 @@ def _create_batch_async(self,
             "input": {
                 "batchId": batch_id,
                 "dataRowIds": dr_ids,
+                "globalKeys": global_keys,
                 "priority": priority,
             }
         }
@@ -1260,38 +1277,50 @@ def _is_url_valid(url: Union[str, Path]) -> bool:
             raise ValueError(
                 f'Invalid annotations given of type: {type(annotations)}')
 
-    def _wait_until_data_rows_are_processed(self,
-                                            data_row_ids: List[str],
-                                            wait_processing_max_seconds: int,
-                                            sleep_interval=30):
+    def _wait_until_data_rows_are_processed(
+            self,
+            data_row_ids: Optional[List[str]] = None,
+            global_keys: Optional[List[str]] = None,
+            wait_processing_max_seconds: int = _wait_processing_max_seconds,
+            sleep_interval=30):
         """ Wait until all the specified data rows are processed"""
         start_time = datetime.now()
+
         while True:
             if (datetime.now() -
                     start_time).total_seconds() >= wait_processing_max_seconds:
                 raise ProcessingWaitTimeout(
                     "Maximum wait time exceeded while waiting for data rows to be processed. Try creating a batch a bit later"
                 )
 
-            all_good = self.__check_data_rows_have_been_processed(data_row_ids)
+            all_good = self.__check_data_rows_have_been_processed(
+                data_row_ids, global_keys)
             if all_good:
                 return
 
             logger.debug(
                 'Some of the data rows are still being processed, waiting...')
             time.sleep(sleep_interval)
 
-    def __check_data_rows_have_been_processed(self, data_row_ids: List[str]):
-        data_row_ids_param = "data_row_ids"
+    def __check_data_rows_have_been_processed(
+            self,
+            data_row_ids: Optional[List[str]] = None,
+            global_keys: Optional[List[str]] = None):
+
+        if data_row_ids is not None and len(data_row_ids) > 0:
+            param_name = "dataRowIds"
+            params = {param_name: data_row_ids}
+        else:
+            param_name = "globalKeys"
+            global_keys = global_keys if global_keys is not None else []
+            params = {param_name: global_keys}
 
-        query_str = """query CheckAllDataRowsHaveBeenProcessedPyApi($%s: [ID!]!) {
-            queryAllDataRowsHaveBeenProcessed(dataRowIds:$%s) {
+        query_str = """query CheckAllDataRowsHaveBeenProcessedPyApi($%s: [ID!]) {
+            queryAllDataRowsHaveBeenProcessed(%s:$%s) {
                 allDataRowsHaveBeenProcessed
            }
-        }""" % (data_row_ids_param, data_row_ids_param)
+        }""" % (param_name, param_name, param_name)
 
-        params = {}
-        params[data_row_ids_param] = data_row_ids
         response = self.client.execute(query_str, params)
         return response["queryAllDataRowsHaveBeenProcessed"][
            "allDataRowsHaveBeenProcessed"]

tests/integration/conftest.py

Lines changed: 10 additions & 7 deletions

@@ -214,13 +214,16 @@ def datarow(dataset, image_url):
 
 @pytest.fixture()
 def data_rows(dataset, image_url):
-    dr1 = dataset.create_data_row(row_data=image_url,
-                                  global_key=f"global-key-{uuid.uuid4()}")
-    dr2 = dataset.create_data_row(row_data=image_url,
-                                  global_key=f"global-key-{uuid.uuid4()}")
-    yield [dr1, dr2]
-    dr1.delete()
-    dr2.delete()
+    dr1 = dict(row_data=image_url, global_key=f"global-key-{uuid.uuid4()}")
+    dr2 = dict(row_data=image_url, global_key=f"global-key-{uuid.uuid4()}")
+    task = dataset.create_data_rows([dr1, dr2])
+    task.wait_till_done()
+
+    drs = list(dataset.export_data_rows())
+    yield drs
+
+    for dr in drs:
+        dr.delete()
 
 
 @pytest.fixture

tests/integration/test_batch.py

Lines changed: 3 additions & 1 deletion

@@ -59,7 +59,9 @@ def test_create_batch_async(batch_project: Project, big_dataset: Dataset):
     data_rows = [dr.uid for dr in list(big_dataset.export_data_rows())]
     batch_project._wait_until_data_rows_are_processed(
         data_rows, batch_project._wait_processing_max_seconds)
-    batch = batch_project._create_batch_async("big-batch", data_rows, 3)
+    batch = batch_project._create_batch_async("big-batch",
+                                              data_rows,
+                                              priority=3)
     assert batch.name == "big-batch"
     assert batch.size == len(data_rows)
tests/integration/test_project.py

Lines changed: 20 additions & 3 deletions

@@ -1,6 +1,6 @@
-import json
 import time
 import os
+import uuid
 
 import pytest
 import requests
@@ -244,15 +244,32 @@ def test_batches(batch_project: Project, dataset: Dataset, image_url):
     ] * 2)
     task.wait_till_done()
     data_rows = [dr.uid for dr in list(dataset.export_data_rows())]
-    batch_one = 'batch one'
-    batch_two = 'batch two'
+    batch_one = f'batch one {uuid.uuid4()}'
+    batch_two = f'batch two {uuid.uuid4()}'
     batch_project.create_batch(batch_one, [data_rows[0]])
     batch_project.create_batch(batch_two, [data_rows[1]])
 
     names = set([batch.name for batch in list(batch_project.batches())])
     assert names == {batch_one, batch_two}
 
 
+def test_create_batch_with_global_keys_sync(batch_project: Project, data_rows):
+    global_keys = [dr.global_key for dr in data_rows]
+    batch_name = f'batch {uuid.uuid4()}'
+    batch = batch_project.create_batch(batch_name, global_keys=global_keys)
+    batch_data_rows = set(batch.export_data_rows())
+    assert batch_data_rows == set(data_rows)
+
+
+def test_create_batch_with_global_keys_async(batch_project: Project, data_rows):
+    global_keys = [dr.global_key for dr in data_rows]
+    batch_name = f'batch {uuid.uuid4()}'
+    batch = batch_project._create_batch_async(batch_name,
+                                              global_keys=global_keys)
+    batch_data_rows = set(batch.export_data_rows())
+    assert batch_data_rows == set(data_rows)
+
+
 def test_media_type(client, configured_project: Project, rand_gen):
     # Existing project with no media_type
     assert isinstance(configured_project.media_type, MediaType)
