Skip to content

Commit eb7cff0

Browse files
committed
[AL-2075] Batch list and export
1 parent 55b96ba commit eb7cff0

File tree

3 files changed

+93
-3
lines changed

3 files changed

+93
-3
lines changed

labelbox/schema/batch.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
1-
from labelbox.orm.db_object import DbObject
2-
from labelbox.orm.model import Field, Relationship
1+
from typing import Generator
2+
from labelbox.orm.db_object import DbObject, experimental
3+
from labelbox.orm.model import Entity, Field, Relationship
4+
from labelbox.exceptions import LabelboxError
5+
from io import StringIO
6+
import ndjson
7+
import requests
8+
import logging
9+
import time
10+
11+
logger = logging.getLogger(__name__)
312

413

514
class Batch(DbObject):
@@ -19,7 +28,50 @@ class Batch(DbObject):
1928
created_at = Field.DateTime("created_at")
2029
updated_at = Field.DateTime("updated_at")
2130
size = Field.Int("size")
31+
archived_at = Field.DateTime("archived_at")
2232

2333
# Relationships
2434
project = Relationship.ToOne("Project")
2535
created_by = Relationship.ToOne("User")
36+
37+
def export_data_rows(self, timeout_seconds=120) -> Generator:
    """ Returns a generator that produces all data rows that are currently
    in this batch.

    Note: For efficiency, the data are cached for 30 minutes. Newly created data rows will not appear
    until the end of the cache period.

    Args:
        timeout_seconds (float): Max waiting time, in seconds.
    Returns:
        Generator that yields DataRow objects belonging to this batch.
    Raises:
        LabelboxError: if the export fails or is unable to download within the specified time.
    """
    id_param = "batchId"
    query_str = """mutation GetBatchDataRowsExportUrlPyApi($%s: ID!)
        {exportBatchDataRows(data:{batchId: $%s }) {downloadUrl createdAt status}}
    """ % (id_param, id_param)
    sleep_time = 2
    # Keep the caller-supplied timeout for the error message below;
    # `timeout_seconds` itself is decremented as we poll.
    original_timeout = timeout_seconds
    while True:
        # Ask the server to (re)build the export; poll until it is ready.
        res = self.client.execute(query_str, {id_param: self.uid})
        res = res["exportBatchDataRows"]
        if res["status"] == "COMPLETE":
            download_url = res["downloadUrl"]
            response = requests.get(download_url)
            response.raise_for_status()
            # The export file is newline-delimited JSON, one data row per line.
            reader = ndjson.reader(StringIO(response.text))
            return (
                Entity.DataRow(self.client, result) for result in reader)
        elif res["status"] == "FAILED":
            raise LabelboxError("Data row export failed.")

        timeout_seconds -= sleep_time
        if timeout_seconds <= 0:
            # Bug fix: the message previously interpolated the decremented
            # counter (always <= 0 here) instead of the original timeout.
            raise LabelboxError(
                f"Unable to export data rows within {original_timeout} seconds."
            )

        logger.debug("Batch '%s' data row export, waiting for server...",
                     self.uid)
        time.sleep(sleep_time)

labelbox/schema/project.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def export_labels(self,
298298

299299
def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str:
300300
"""Returns a concatenated string of the dictionary's keys and values
301-
301+
302302
The string will be formatted as {key}: 'value' for each key. Value will be inclusive of
303303
quotations while key will not. This can be toggled with `value_with_quotes`"""
304304

@@ -840,6 +840,25 @@ def bulk_import_requests(self) -> PaginatedCollection:
840840
["bulkImportRequests"],
841841
Entity.BulkImportRequest)
842842

843+
def batches(self) -> PaginatedCollection:
    """ Fetch all batches that belong to this project.

    Returns:
        A `PaginatedCollection` of `Batch`es.
    """
    id_param = "projectId"
    query_str = """query GetProjectBatchesPyApi($from: String, $first: PageSize, $%s: ID!) {
        project(where: {id: $%s}) {id
        batches(after: $from, first: $first) { nodes { %s } pageInfo { endCursor }}}}
    """ % (id_param, id_param, query.results_query_part(Entity.Batch))
    # Page through project -> batches -> nodes, using pageInfo.endCursor as
    # the pagination cursor. Marked experimental=True per the backend API.
    return PaginatedCollection(self.client,
                               query_str, {id_param: self.uid},
                               ['project', 'batches', 'nodes'],
                               Entity.Batch,
                               cursor_path=['project', 'batches',
                                            'pageInfo', 'endCursor'],
                               experimental=True)
861+
843862
def upload_annotations(
844863
self,
845864
name: str,

tests/integration/test_batch.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,22 @@ def test_create_batch(configured_project: Project, big_dataset: Dataset):
2525
batch = configured_project.create_batch("test-batch", data_rows, 3)
2626
assert batch.name == 'test-batch'
2727
assert batch.size == len(data_rows)
28+
29+
def test_export_data_rows(configured_project: Project, dataset: Dataset):
    """Create a batch from freshly uploaded data rows and verify that the
    batch export returns exactly those rows."""
    n_data_rows = 5
    task = dataset.create_data_rows([
        {
            "row_data": IMAGE_URL,
            "external_id": "my-image"
        },
    ] * n_data_rows)
    task.wait_till_done()

    # Bug fix: the original referenced the undefined name `big_dataset`
    # (NameError at runtime); the fixture injected into this test is `dataset`.
    data_rows = [dr.uid for dr in list(dataset.export_data_rows())]
    batch = configured_project.create_batch("batch test", data_rows)

    result = list(batch.export_data_rows())
    exported_data_rows = [dr.uid for dr in result]

    assert len(result) == n_data_rows
    assert set(data_rows) == set(exported_data_rows)

0 commit comments

Comments
 (0)