
Commit 15409d8

[AL-4867] Create batch using global keys
2 parents: 2b0cc80 + a40ab66 · commit 15409d8

5 files changed: +96 −45 lines changed


.github/workflows/python-package.yml

Lines changed: 1 addition & 1 deletion
@@ -89,4 +89,4 @@ jobs:
           DA_GCP_LABELBOX_API_KEY: ${{ secrets[matrix.da-test-key] }}
         run: |
-          tox -e py -- -svv --reruns 5 --reruns-delay 10
+          tox -e py -- -n 10 -svv --reruns 5 --reruns-delay 10
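The new `-n 10` flag asks pytest to split the suite across 10 parallel workers; this presumes pytest-xdist is available in the tox test environment, which is not shown in this diff.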

labelbox/schema/project.py

Lines changed: 62 additions & 33 deletions
@@ -666,16 +666,20 @@ def setup(self, labeling_frontend, labeling_frontend_options) -> None:
         timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
         self.update(setup_complete=timestamp)
 
-    def create_batch(self,
-                     name: str,
-                     data_rows: List[Union[str, DataRow]],
-                     priority: int = 5,
-                     consensus_settings: Optional[Dict[str, float]] = None):
-        """Create a new batch for a project. Batches is in Beta and subject to change
+    def create_batch(
+        self,
+        name: str,
+        data_rows: Optional[List[Union[str, DataRow]]] = None,
+        priority: int = 5,
+        consensus_settings: Optional[Dict[str, float]] = None,
+        global_keys: Optional[List[str]] = None,
+    ):
+        """Create a new batch for a project. One of `global_keys` or `data_rows` must be provided but not both.
 
         Args:
             name: a name for the batch, must be unique within a project
-            data_rows: Either a list of `DataRows` or Data Row ids
+            data_rows: Either a list of `DataRows` or Data Row ids.
+            global_keys: global keys for data rows to add to the batch.
             priority: An optional priority for the Data Rows in the Batch. 1 highest -> 5 lowest
             consensus_settings: An optional dictionary with consensus settings: {'number_of_labels': 3, 'coverage_percentage': 0.1}
         """
@@ -685,35 +689,45 @@ def create_batch(self,
             raise ValueError("Project must be in batch mode")
 
         dr_ids = []
-        for dr in data_rows:
-            if isinstance(dr, Entity.DataRow):
-                dr_ids.append(dr.uid)
-            elif isinstance(dr, str):
-                dr_ids.append(dr)
-            else:
-                raise ValueError("You can DataRow ids or DataRow objects")
+        if data_rows is not None:
+            for dr in data_rows:
+                if isinstance(dr, Entity.DataRow):
+                    dr_ids.append(dr.uid)
+                elif isinstance(dr, str):
+                    dr_ids.append(dr)
+                else:
+                    raise ValueError(
+                        "`data_rows` must be DataRow ids or DataRow objects")
+
+        if data_rows is not None:
+            row_count = len(data_rows)
+        elif global_keys is not None:
+            row_count = len(global_keys)
+        else:
+            row_count = 0
 
-        if len(dr_ids) > 100_000:
+        if row_count > 100_000:
             raise ValueError(
                 f"Batch exceeds max size, break into smaller batches")
-        if not len(dr_ids):
+        if not row_count:
             raise ValueError("You need at least one data row in a batch")
 
         self._wait_until_data_rows_are_processed(
-            dr_ids, self._wait_processing_max_seconds)
+            dr_ids, global_keys, self._wait_processing_max_seconds)
 
         if consensus_settings:
             consensus_settings = ConsensusSettings(**consensus_settings).dict(
                 by_alias=True)
 
         if len(dr_ids) >= 10_000:
-            return self._create_batch_async(name, dr_ids, priority,
+            return self._create_batch_async(name, dr_ids, global_keys, priority,
                                             consensus_settings)
         else:
-            return self._create_batch_sync(name, dr_ids, priority,
+            return self._create_batch_sync(name, dr_ids, global_keys, priority,
                                            consensus_settings)
 
-    def _create_batch_sync(self, name, dr_ids, priority, consensus_settings):
+    def _create_batch_sync(self, name, dr_ids, global_keys, priority,
+                           consensus_settings):
         method = 'createBatchV2'
         query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) {
             project(where: {id: $projectId}) {
@@ -731,6 +745,7 @@ def _create_batch_sync(self, name, dr_ids, priority, consensus_settings):
             "batchInput": {
                 "name": name,
                 "dataRowIds": dr_ids,
+                "globalKeys": global_keys,
                 "priority": priority,
                 "consensusSettings": consensus_settings
             }
@@ -748,7 +763,8 @@ def _create_batch_sync(self, name, dr_ids, priority, consensus_settings):
 
     def _create_batch_async(self,
                             name: str,
-                            dr_ids: List[str],
+                            dr_ids: Optional[List[str]] = None,
+                            global_keys: Optional[List[str]] = None,
                             priority: int = 5,
                             consensus_settings: Optional[Dict[str,
                                                               float]] = None):
@@ -791,6 +807,7 @@ def _create_batch_async(self,
             "input": {
                 "batchId": batch_id,
                 "dataRowIds": dr_ids,
+                "globalKeys": global_keys,
                 "priority": priority,
             }
         }
@@ -1257,38 +1274,50 @@ def _is_url_valid(url: Union[str, Path]) -> bool:
             raise ValueError(
                 f'Invalid annotations given of type: {type(annotations)}')
 
-    def _wait_until_data_rows_are_processed(self,
-                                            data_row_ids: List[str],
-                                            wait_processing_max_seconds: int,
-                                            sleep_interval=30):
+    def _wait_until_data_rows_are_processed(
+            self,
+            data_row_ids: Optional[List[str]] = None,
+            global_keys: Optional[List[str]] = None,
+            wait_processing_max_seconds: int = _wait_processing_max_seconds,
+            sleep_interval=30):
         """ Wait until all the specified data rows are processed"""
         start_time = datetime.now()
+
         while True:
             if (datetime.now() -
                     start_time).total_seconds() >= wait_processing_max_seconds:
                 raise ProcessingWaitTimeout(
                     "Maximum wait time exceeded while waiting for data rows to be processed. Try creating a batch a bit later"
                 )
 
-            all_good = self.__check_data_rows_have_been_processed(data_row_ids)
+            all_good = self.__check_data_rows_have_been_processed(
+                data_row_ids, global_keys)
             if all_good:
                 return
 
             logger.debug(
                 'Some of the data rows are still being processed, waiting...')
             time.sleep(sleep_interval)
 
-    def __check_data_rows_have_been_processed(self, data_row_ids: List[str]):
-        data_row_ids_param = "data_row_ids"
+    def __check_data_rows_have_been_processed(
+            self,
+            data_row_ids: Optional[List[str]] = None,
+            global_keys: Optional[List[str]] = None):
+
+        if data_row_ids is not None and len(data_row_ids) > 0:
+            param_name = "dataRowIds"
+            params = {param_name: data_row_ids}
+        else:
+            param_name = "globalKeys"
+            global_keys = global_keys if global_keys is not None else []
+            params = {param_name: global_keys}
 
-        query_str = """query CheckAllDataRowsHaveBeenProcessedPyApi($%s: [ID!]!) {
-            queryAllDataRowsHaveBeenProcessed(dataRowIds:$%s) {
+        query_str = """query CheckAllDataRowsHaveBeenProcessedPyApi($%s: [ID!]) {
+            queryAllDataRowsHaveBeenProcessed(%s:$%s) {
                 allDataRowsHaveBeenProcessed
             }
-        }""" % (data_row_ids_param, data_row_ids_param)
+        }""" % (param_name, param_name, param_name)
 
-        params = {}
-        params[data_row_ids_param] = data_row_ids
         response = self.client.execute(query_str, params)
         return response["queryAllDataRowsHaveBeenProcessed"][
             "allDataRowsHaveBeenProcessed"]

tests/integration/conftest.py

Lines changed: 10 additions & 7 deletions
@@ -231,13 +231,16 @@ def datarow(dataset, image_url):
 
 @pytest.fixture()
 def data_rows(dataset, image_url):
-    dr1 = dataset.create_data_row(row_data=image_url,
-                                  global_key=f"global-key-{uuid.uuid4()}")
-    dr2 = dataset.create_data_row(row_data=image_url,
-                                  global_key=f"global-key-{uuid.uuid4()}")
-    yield [dr1, dr2]
-    dr1.delete()
-    dr2.delete()
+    dr1 = dict(row_data=image_url, global_key=f"global-key-{uuid.uuid4()}")
+    dr2 = dict(row_data=image_url, global_key=f"global-key-{uuid.uuid4()}")
+    task = dataset.create_data_rows([dr1, dr2])
+    task.wait_till_done()
+
+    drs = list(dataset.export_data_rows())
+    yield drs
+
+    for dr in drs:
+        dr.delete()
 
 
 @pytest.fixture
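For reference, a sketch of the bulk-upload pattern the updated fixture relies on, pulled out of pytest; here `dataset` and `image_url` stand in for the fixture arguments:

    import uuid

    rows = [
        dict(row_data=image_url, global_key=f"global-key-{uuid.uuid4()}")
        for _ in range(2)
    ]
    task = dataset.create_data_rows(rows)           # bulk upload returns a Task
    task.wait_till_done()                           # block until the rows exist server-side

    data_rows = list(dataset.export_data_rows())    # re-fetch DataRow objects carrying uid and global_key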

tests/integration/test_batch.py

Lines changed: 3 additions & 1 deletion
@@ -59,7 +59,9 @@ def test_create_batch_async(batch_project: Project, big_dataset: Dataset):
     data_rows = [dr.uid for dr in list(big_dataset.export_data_rows())]
     batch_project._wait_until_data_rows_are_processed(
         data_rows, batch_project._wait_processing_max_seconds)
-    batch = batch_project._create_batch_async("big-batch", data_rows, 3)
+    batch = batch_project._create_batch_async("big-batch",
+                                              data_rows,
+                                              priority=3)
     assert batch.name == "big-batch"
     assert batch.size == len(data_rows)

tests/integration/test_project.py

Lines changed: 20 additions & 3 deletions
@@ -1,6 +1,6 @@
-import json
 import time
 import os
+import uuid
 
 import pytest
 import requests
@@ -244,15 +244,32 @@ def test_batches(batch_project: Project, dataset: Dataset, image_url):
     ] * 2)
     task.wait_till_done()
     data_rows = [dr.uid for dr in list(dataset.export_data_rows())]
-    batch_one = 'batch one'
-    batch_two = 'batch two'
+    batch_one = f'batch one {uuid.uuid4()}'
+    batch_two = f'batch two {uuid.uuid4()}'
     batch_project.create_batch(batch_one, [data_rows[0]])
     batch_project.create_batch(batch_two, [data_rows[1]])
 
     names = set([batch.name for batch in list(batch_project.batches())])
     assert names == {batch_one, batch_two}
 
 
+def test_create_batch_with_global_keys_sync(batch_project: Project, data_rows):
+    global_keys = [dr.global_key for dr in data_rows]
+    batch_name = f'batch {uuid.uuid4()}'
+    batch = batch_project.create_batch(batch_name, global_keys=global_keys)
+    batch_data_rows = set(batch.export_data_rows())
+    assert batch_data_rows == set(data_rows)
+
+
+def test_create_batch_with_global_keys_async(batch_project: Project, data_rows):
+    global_keys = [dr.global_key for dr in data_rows]
+    batch_name = f'batch {uuid.uuid4()}'
+    batch = batch_project._create_batch_async(batch_name,
+                                              global_keys=global_keys)
+    batch_data_rows = set(batch.export_data_rows())
+    assert batch_data_rows == set(data_rows)
+
+
 def test_media_type(client, configured_project: Project, rand_gen):
     # Existing project with no media_type
     assert isinstance(configured_project.media_type, MediaType)
