Skip to content

Commit 66afe4e

Browse files
authored
Merge pull request #629 from Labelbox/jtso/al-2659
[AL-2659] Inclusion of optional metadata parameter
2 parents e5770fd + daded19 commit 66afe4e

File tree

3 files changed

+38
-26
lines changed

3 files changed

+38
-26
lines changed

labelbox/schema/batch.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ def remove_queued_data_rows(self) -> None:
7777
},
7878
experimental=True)
7979

80-
def export_data_rows(self, timeout_seconds=120) -> Generator:
80+
def export_data_rows(self,
81+
timeout_seconds=120,
82+
include_metadata: bool = False) -> Generator:
8183
""" Returns a generator that produces all data rows that are currently
8284
in this batch.
8385
@@ -92,23 +94,24 @@ def export_data_rows(self, timeout_seconds=120) -> Generator:
9294
LabelboxError: if the export fails or is unable to download within the specified time.
9395
"""
9496
id_param = "batchId"
95-
query_str = """mutation GetBatchDataRowsExportUrlPyApi($%s: ID!)
96-
{exportBatchDataRows(data:{batchId: $%s }) {downloadUrl createdAt status}}
97-
""" % (id_param, id_param)
97+
metadata_param = "includeMetadataInput"
98+
query_str = """mutation GetBatchDataRowsExportUrlPyApi($%s: ID!, $%s: Boolean!)
99+
{exportBatchDataRows(data:{batchId: $%s , includeMetadataInput: $%s}) {downloadUrl createdAt status}}
100+
""" % (id_param, metadata_param, id_param, metadata_param)
98101
sleep_time = 2
99102
while True:
100-
res = self.client.execute(query_str, {id_param: self.uid})
103+
res = self.client.execute(query_str, {
104+
id_param: self.uid,
105+
metadata_param: include_metadata
106+
})
101107
res = res["exportBatchDataRows"]
102108
if res["status"] == "COMPLETE":
103109
download_url = res["downloadUrl"]
104110
response = requests.get(download_url)
105111
response.raise_for_status()
106112
reader = ndjson.reader(StringIO(response.text))
107-
# TODO: Update result to parse metadataFields when resolver returns
108-
return (Entity.DataRow(self.client, {
109-
**result, 'metadataFields': [],
110-
'customMetadata': []
111-
}) for result in reader)
113+
return (
114+
Entity.DataRow(self.client, result) for result in reader)
112115
elif res["status"] == "FAILED":
113116
raise LabelboxError("Data row export failed.")
114117

labelbox/schema/dataset.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,9 @@ def data_row_for_external_id(self, external_id) -> "DataRow":
462462
external_id)
463463
return data_rows[0]
464464

465-
def export_data_rows(self, timeout_seconds=120) -> Generator:
465+
def export_data_rows(self,
466+
timeout_seconds=120,
467+
include_metadata: bool = False) -> Generator:
466468
""" Returns a generator that produces all data rows that are currently
467469
attached to this dataset.
468470
@@ -477,23 +479,24 @@ def export_data_rows(self, timeout_seconds=120) -> Generator:
477479
LabelboxError: if the export fails or is unable to download within the specified time.
478480
"""
479481
id_param = "datasetId"
480-
query_str = """mutation GetDatasetDataRowsExportUrlPyApi($%s: ID!)
481-
{exportDatasetDataRows(data:{datasetId: $%s }) {downloadUrl createdAt status}}
482-
""" % (id_param, id_param)
482+
metadata_param = "includeMetadataInput"
483+
query_str = """mutation GetDatasetDataRowsExportUrlPyApi($%s: ID!, $%s: Boolean!)
484+
{exportDatasetDataRows(data:{datasetId: $%s , includeMetadataInput: $%s}) {downloadUrl createdAt status}}
485+
""" % (id_param, metadata_param, id_param, metadata_param)
483486
sleep_time = 2
484487
while True:
485-
res = self.client.execute(query_str, {id_param: self.uid})
488+
res = self.client.execute(query_str, {
489+
id_param: self.uid,
490+
metadata_param: include_metadata
491+
})
486492
res = res["exportDatasetDataRows"]
487493
if res["status"] == "COMPLETE":
488494
download_url = res["downloadUrl"]
489495
response = requests.get(download_url)
490496
response.raise_for_status()
491497
reader = ndjson.reader(StringIO(response.text))
492-
# TODO: Update result to parse metadataFields when resolver returns
493-
return (Entity.DataRow(self.client, {
494-
**result, 'metadataFields': [],
495-
'customMetadata': []
496-
}) for result in reader)
498+
return (
499+
Entity.DataRow(self.client, result) for result in reader)
497500
elif res["status"] == "FAILED":
498501
raise LabelboxError("Data row export failed.")
499502

labelbox/schema/project.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,10 @@ def labels(self, datasets=None, order_by=None) -> PaginatedCollection:
185185
return PaginatedCollection(self.client, query_str, {id_param: self.uid},
186186
["project", "labels"], Label)
187187

188-
def export_queued_data_rows(self,
189-
timeout_seconds=120) -> List[Dict[str, str]]:
188+
def export_queued_data_rows(
189+
self,
190+
timeout_seconds=120,
191+
include_metadata: bool = False) -> List[Dict[str, str]]:
190192
""" Returns all data rows that are currently enqueued for this project.
191193
192194
Args:
@@ -197,12 +199,16 @@ def export_queued_data_rows(self,
197199
LabelboxError: if the export fails or is unable to download within the specified time.
198200
"""
199201
id_param = "projectId"
200-
query_str = """mutation GetQueuedDataRowsExportUrlPyApi($%s: ID!)
201-
{exportQueuedDataRows(data:{projectId: $%s }) {downloadUrl createdAt status} }
202-
""" % (id_param, id_param)
202+
metadata_param = "includeMetadataInput"
203+
query_str = """mutation GetQueuedDataRowsExportUrlPyApi($%s: ID!, $%s: Boolean!)
204+
{exportQueuedDataRows(data:{projectId: $%s , includeMetadataInput: $%s}) {downloadUrl createdAt status} }
205+
""" % (id_param, metadata_param, id_param, metadata_param)
203206
sleep_time = 2
204207
while True:
205-
res = self.client.execute(query_str, {id_param: self.uid})
208+
res = self.client.execute(query_str, {
209+
id_param: self.uid,
210+
metadata_param: include_metadata
211+
})
206212
res = res["exportQueuedDataRows"]
207213
if res["status"] == "COMPLETE":
208214
download_url = res["downloadUrl"]

0 commit comments

Comments
 (0)