
Commit 8389ca2

Merge pull request #578 from Labelbox/kkim/AL-2496
[AL-2496] [AL-2498] Rename custom_metadata to metadata_fields for DataRow
2 parents 30edebd + ac02869 commit 8389ca2

File tree: 4 files changed (+106, −37 lines)

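In practice, the rename means callers pass `metadata_fields` where they previously passed `custom_metadata`. A minimal sketch of the change from a client's point of view, assuming placeholder API key, dataset id, and schema id (none of these values come from this diff):

from labelbox import Client
from labelbox.schema.data_row_metadata import DataRowMetadataField

client = Client(api_key="<API_KEY>")            # placeholder key
dataset = client.get_dataset("<DATASET_ID>")    # placeholder dataset id
field = DataRowMetadataField(schema_id="<SCHEMA_ID>", value="example")  # placeholder schema

# Before this commit the keyword was `custom_metadata`:
#   dataset.create_data_row(row_data="https://example.com/img.jpg",
#                           custom_metadata=[field])
# After this commit it is `metadata_fields`:
data_row = dataset.create_data_row(row_data="https://example.com/img.jpg",
                                   metadata_fields=[field])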

labelbox/schema/batch.py

Lines changed: 2 additions & 2 deletions
@@ -104,9 +104,9 @@ def export_data_rows(self, timeout_seconds=120) -> Generator:
                 response = requests.get(download_url)
                 response.raise_for_status()
                 reader = ndjson.reader(StringIO(response.text))
-                # TODO: Update result to parse customMetadata when resolver returns
+                # TODO: Update result to parse metadataFields when resolver returns
                 return (Entity.DataRow(self.client, {
-                    **result, 'customMetadata': []
+                    **result, 'metadataFields': []
                 }) for result in reader)
             elif res["status"] == "FAILED":
                 raise LabelboxError("Data row export failed.")
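Until the export resolver returns metadata, both `Batch.export_data_rows` and `Dataset.export_data_rows` stub the renamed key with an empty list, so exported rows report no metadata fields yet. A rough illustration of what a caller sees, assuming `batch` is an existing Batch instance obtained elsewhere:

# Assuming `batch` is a labelbox.schema.batch.Batch obtained elsewhere.
for data_row in batch.export_data_rows():
    # The TODO above means metadata_fields is always [] for now; it will carry
    # parsed metadata once the resolver returns metadataFields.
    print(data_row.uid, data_row.metadata_fields)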

labelbox/schema/data_row.py

Lines changed: 4 additions & 4 deletions
@@ -22,7 +22,7 @@ class DataRow(DbObject, Updateable, BulkDeletable):
         updated_at (datetime)
         created_at (datetime)
         media_attributes (dict): generated media attributes for the datarow
-        custom_metadata (list): metadata associated with the datarow
+        metadata_fields (list): metadata associated with the datarow
 
         dataset (Relationship): `ToOne` relationship to Dataset
         created_by (Relationship): `ToOne` relationship to User
@@ -35,11 +35,11 @@ class DataRow(DbObject, Updateable, BulkDeletable):
     updated_at = Field.DateTime("updated_at")
     created_at = Field.DateTime("created_at")
     media_attributes = Field.Json("media_attributes")
-    custom_metadata = Field.List(
+    metadata_fields = Field.List(
         DataRowMetadataField,
         graphql_type="DataRowCustomMetadataUpsertInput!",
-        name="custom_metadata",
-        result_subquery="customMetadata { value schemaId }")
+        name="metadata_fields",
+        result_subquery="metadataFields { schemaId name value kind }")
 
     # Relationships
     dataset = Relationship.ToOne("Dataset")
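The new `result_subquery` asks the API for `schemaId`, `name`, `value`, and `kind` per field, so each entry in `DataRow.metadata_fields` should come back as a dict with those keys. A hedged sketch of reading them back, assuming an existing `client` and a placeholder data row id:

# Assuming `client` is a labelbox.Client; the data row id is a placeholder.
data_row = client.get_data_row("<DATA_ROW_ID>")
for field in data_row.metadata_fields:
    # Keys mirror the result_subquery above: schemaId, name, value, kind.
    print(field["schemaId"], field["name"], field["kind"], field["value"])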

labelbox/schema/dataset.py

Lines changed: 31 additions & 14 deletions
@@ -52,40 +52,57 @@ class Dataset(DbObject, Updateable, Deletable):
     iam_integration = Relationship.ToOne("IAMIntegration", False,
                                          "iam_integration", "signer")
 
-    def create_data_row(self, **kwargs) -> "DataRow":
+    def create_data_row(self, items=None, **kwargs) -> "DataRow":
         """ Creates a single DataRow belonging to this dataset.
 
         >>> dataset.create_data_row(row_data="http://my_site.com/photos/img_01.jpg")
 
         Args:
+            items: Dictionary containing new `DataRow` data. At a minimum,
+                must contain `row_data` or `DataRow.row_data`.
             **kwargs: Key-value arguments containing new `DataRow` data. At a minimum,
                 must contain `row_data`.
 
         Raises:
+            InvalidQueryError: If both dictionary and `kwargs` are provided as inputs
             InvalidQueryError: If `DataRow.row_data` field value is not provided
                 in `kwargs`.
             InvalidAttributeError: in case the DB object type does not contain
                 any of the field names given in `kwargs`.
 
         """
+        invalid_argument_error = "Argument to create_data_row() must be either a dictionary, or kwargs containing `row_data` at minimum"
+
+        def convert_field_keys(items):
+            if not isinstance(items, dict):
+                raise InvalidQueryError(invalid_argument_error)
+            return {
+                key.name if isinstance(key, Field) else key: value
+                for key, value in items.items()
+            }
+
+        if items is not None and len(kwargs) > 0:
+            raise InvalidQueryError(invalid_argument_error)
+
         DataRow = Entity.DataRow
-        if DataRow.row_data.name not in kwargs:
+        args = convert_field_keys(items) if items is not None else kwargs
+
+        if DataRow.row_data.name not in args:
             raise InvalidQueryError(
                 "DataRow.row_data missing when creating DataRow.")
 
         # If row data is a local file path, upload it to server.
-        row_data = kwargs[DataRow.row_data.name]
+        row_data = args[DataRow.row_data.name]
         if os.path.exists(row_data):
-            kwargs[DataRow.row_data.name] = self.client.upload_file(row_data)
-        kwargs[DataRow.dataset.name] = self
+            args[DataRow.row_data.name] = self.client.upload_file(row_data)
+        args[DataRow.dataset.name] = self
 
         # Parse metadata fields, if they are provided
-        if DataRow.custom_metadata.name in kwargs:
+        if DataRow.metadata_fields.name in args:
             mdo = self.client.get_data_row_metadata_ontology()
-            kwargs[DataRow.custom_metadata.name] = mdo.parse_upsert_metadata(
-                kwargs[DataRow.custom_metadata.name])
-
-        return self.client._create(DataRow, kwargs)
+            args[DataRow.metadata_fields.name] = mdo.parse_upsert_metadata(
+                args[DataRow.metadata_fields.name])
+        return self.client._create(DataRow, args)
 
     def create_data_rows_sync(self, items) -> None:
         """ Synchronously bulk upload data rows.
@@ -264,10 +281,10 @@ def validate_attachments(item):
             return attachments
 
         def parse_metadata_fields(item):
-            metadata_fields = item.get('custom_metadata')
+            metadata_fields = item.get('metadata_fields')
             if metadata_fields:
                 mdo = self.client.get_data_row_metadata_ontology()
-                item['custom_metadata'] = mdo.parse_upsert_metadata(
+                item['metadata_fields'] = mdo.parse_upsert_metadata(
                     metadata_fields)
 
         def format_row(item):
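The same rename applies to bulk uploads: items passed to `create_data_rows` now carry metadata under `metadata_fields` (or `DataRow.metadata_fields`), which `parse_metadata_fields` converts into the upsert format. A minimal sketch with placeholder row data and schema id:

from labelbox.schema.data_row import DataRow
from labelbox.schema.data_row_metadata import DataRowMetadataField

task = dataset.create_data_rows([{
    DataRow.row_data: "https://example.com/img_01.jpg",   # placeholder
    DataRow.external_id: "row1",
    DataRow.metadata_fields: [
        DataRowMetadataField(schema_id="<SCHEMA_ID>", value="example"),  # placeholder
    ],
}])
task.wait_till_done()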
@@ -413,9 +430,9 @@ def export_data_rows(self, timeout_seconds=120) -> Generator:
                 response = requests.get(download_url)
                 response.raise_for_status()
                 reader = ndjson.reader(StringIO(response.text))
-                # TODO: Update result to parse customMetadata when resolver returns
+                # TODO: Update result to parse metadataFields when resolver returns
                 return (Entity.DataRow(self.client, {
-                    **result, 'customMetadata': []
+                    **result, 'metadataFields': []
                 }) for result in reader)
             elif res["status"] == "FAILED":
                 raise LabelboxError("Data row export failed.")

tests/integration/test_data_rows.py

Lines changed: 69 additions & 17 deletions
@@ -1,3 +1,4 @@
+import imghdr
 from tempfile import NamedTemporaryFile
 import uuid
 import time
@@ -177,12 +178,63 @@ def test_data_row_single_creation(dataset, rand_gen, image_url):
     assert requests.get(data_row_2.row_data).content == data
 
 
+def test_create_data_row_with_dict(dataset, image_url):
+    client = dataset.client
+    assert len(list(dataset.data_rows())) == 0
+    dr = {"row_data": image_url}
+    data_row = dataset.create_data_row(dr)
+    assert len(list(dataset.data_rows())) == 1
+    assert data_row.dataset() == dataset
+    assert data_row.created_by() == client.get_user()
+    assert data_row.organization() == client.get_organization()
+    assert requests.get(image_url).content == \
+        requests.get(data_row.row_data).content
+    assert data_row.media_attributes is not None
+
+
+def test_create_data_row_with_dict_containing_field(dataset, image_url):
+    client = dataset.client
+    assert len(list(dataset.data_rows())) == 0
+    dr = {DataRow.row_data: image_url}
+    data_row = dataset.create_data_row(dr)
+    assert len(list(dataset.data_rows())) == 1
+    assert data_row.dataset() == dataset
+    assert data_row.created_by() == client.get_user()
+    assert data_row.organization() == client.get_organization()
+    assert requests.get(image_url).content == \
+        requests.get(data_row.row_data).content
+    assert data_row.media_attributes is not None
+
+
+def test_create_data_row_with_dict_unpacked(dataset, image_url):
+    client = dataset.client
+    assert len(list(dataset.data_rows())) == 0
+    dr = {"row_data": image_url}
+    data_row = dataset.create_data_row(**dr)
+    assert len(list(dataset.data_rows())) == 1
+    assert data_row.dataset() == dataset
+    assert data_row.created_by() == client.get_user()
+    assert data_row.organization() == client.get_organization()
+    assert requests.get(image_url).content == \
+        requests.get(data_row.row_data).content
+    assert data_row.media_attributes is not None
+
+
+def test_create_data_row_with_invalid_input(dataset, image_url):
+    with pytest.raises(labelbox.exceptions.InvalidQueryError) as exc:
+        dataset.create_data_row("asdf")
+
+    dr = {"row_data": image_url}
+    with pytest.raises(labelbox.exceptions.InvalidQueryError) as exc:
+        dataset.create_data_row(dr, row_data=image_url)
+
+
 def test_create_data_row_with_metadata(dataset, image_url):
     client = dataset.client
     assert len(list(dataset.data_rows())) == 0
 
     data_row = dataset.create_data_row(row_data=image_url,
-                                       custom_metadata=make_metadata_fields())
+                                       metadata_fields=make_metadata_fields())
 
     assert len(list(dataset.data_rows())) == 1
     assert data_row.dataset() == dataset
@@ -191,8 +243,8 @@ def test_create_data_row_with_metadata(dataset, image_url):
     assert requests.get(image_url).content == \
         requests.get(data_row.row_data).content
     assert data_row.media_attributes is not None
-    assert len(data_row.custom_metadata) == 5
-    assert [m["schemaId"] for m in data_row.custom_metadata
+    assert len(data_row.metadata_fields) == 4
+    assert [m["schemaId"] for m in data_row.metadata_fields
            ].sort() == EXPECTED_METADATA_SCHEMA_IDS
 
 
@@ -201,7 +253,7 @@ def test_create_data_row_with_metadata_dict(dataset, image_url):
     assert len(list(dataset.data_rows())) == 0
 
     data_row = dataset.create_data_row(
-        row_data=image_url, custom_metadata=make_metadata_fields_dict())
+        row_data=image_url, metadata_fields=make_metadata_fields_dict())
 
     assert len(list(dataset.data_rows())) == 1
     assert data_row.dataset() == dataset
@@ -210,8 +262,8 @@ def test_create_data_row_with_metadata_dict(dataset, image_url):
     assert requests.get(image_url).content == \
         requests.get(data_row.row_data).content
     assert data_row.media_attributes is not None
-    assert len(data_row.custom_metadata) == 5
-    assert [m["schemaId"] for m in data_row.custom_metadata
+    assert len(data_row.metadata_fields) == 4
+    assert [m["schemaId"] for m in data_row.metadata_fields
            ].sort() == EXPECTED_METADATA_SCHEMA_IDS
 
 
@@ -221,7 +273,7 @@ def test_create_data_row_with_invalid_metadata(dataset, image_url):
         DataRowMetadataField(schema_id=EMBEDDING_SCHEMA_ID, value=[0.0] * 128))
 
     with pytest.raises(labelbox.exceptions.MalformedQueryException) as excinfo:
-        dataset.create_data_row(row_data=image_url, custom_metadata=fields)
+        dataset.create_data_row(row_data=image_url, metadata_fields=fields)
 
 
 def test_create_data_rows_with_metadata(dataset, image_url):
@@ -232,22 +284,22 @@ def test_create_data_rows_with_metadata(dataset, image_url):
         {
             DataRow.row_data: image_url,
             DataRow.external_id: "row1",
-            DataRow.custom_metadata: make_metadata_fields()
+            DataRow.metadata_fields: make_metadata_fields()
         },
         {
             DataRow.row_data: image_url,
             DataRow.external_id: "row2",
-            "custom_metadata": make_metadata_fields()
+            "metadata_fields": make_metadata_fields()
         },
         {
             DataRow.row_data: image_url,
             DataRow.external_id: "row3",
-            DataRow.custom_metadata: make_metadata_fields_dict()
+            DataRow.metadata_fields: make_metadata_fields_dict()
         },
         {
             DataRow.row_data: image_url,
             DataRow.external_id: "row4",
-            "custom_metadata": make_metadata_fields_dict()
+            "metadata_fields": make_metadata_fields_dict()
         },
     ])
     task.wait_till_done()
@@ -261,8 +313,8 @@ def test_create_data_rows_with_metadata(dataset, image_url):
         assert requests.get(image_url).content == \
             requests.get(row.row_data).content
         assert row.media_attributes is not None
-        assert len(row.custom_metadata) == 5
-        assert [m["schemaId"] for m in row.custom_metadata
+        assert len(row.metadata_fields) == 4
+        assert [m["schemaId"] for m in row.metadata_fields
                ].sort() == EXPECTED_METADATA_SCHEMA_IDS
 
 
@@ -273,7 +325,7 @@ def test_create_data_rows_with_invalid_metadata(dataset, image_url):
 
     task = dataset.create_data_rows([{
         DataRow.row_data: image_url,
-        DataRow.custom_metadata: fields
+        DataRow.metadata_fields: fields
     }])
     task.wait_till_done()
     assert task.status == "FAILED"
@@ -288,7 +340,7 @@ def test_create_data_rows_with_metadata_missing_value(dataset, image_url):
         {
             DataRow.row_data: image_url,
             DataRow.external_id: "row1",
-            DataRow.custom_metadata: fields
+            DataRow.metadata_fields: fields
         },
     ])
 
@@ -302,7 +354,7 @@ def test_create_data_rows_with_metadata_missing_schema_id(dataset, image_url):
         {
             DataRow.row_data: image_url,
             DataRow.external_id: "row1",
-            DataRow.custom_metadata: fields
+            DataRow.metadata_fields: fields
         },
     ])
 
@@ -316,7 +368,7 @@ def test_create_data_rows_with_metadata_wrong_type(dataset, image_url):
         {
             DataRow.row_data: image_url,
             DataRow.external_id: "row1",
-            DataRow.custom_metadata: fields
+            DataRow.metadata_fields: fields
         },
     ])
 