Commit 2b686ee

[AL-4256] Create datarow metadata by name
2 parents 66c2f9f + 5838ea6

5 files changed: +220 -34 lines

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,8 @@
 # Version 3.34.0 (...)
 ### Added
 * Added `get_by_name()` method to MetadataOntology object to access both custom and reserved metadata by name.
+* Added support for adding metadata by name when creating datarows using `DataRowMetadataOntology.bulk_upsert()`.
+* Added support for adding metadata by name when creating datarows using `Dataset.create_data_rows()`, `Dataset.create_data_rows_sync()`, and `Dataset.create_data_row()`.
 
 ### Changed
 * `Dataset.create_data_rows()` max limit of DataRows increased to 150,000
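
The name-based flow described above might look like the following when creating a single data row. A minimal sketch, not taken from this commit: the client setup, the placeholder ids, and the `metadata_fields` argument are assumptions.

    from labelbox import Client
    from labelbox.schema.data_row_metadata import DataRowMetadataField

    client = Client(api_key="<API_KEY>")          # placeholder credentials
    dataset = client.get_dataset("<DATASET_ID>")  # placeholder dataset id

    # Previously each field needed a schema_id looked up from the metadata
    # ontology; after this commit the schema name alone is enough.
    dataset.create_data_row(
        row_data="https://example.com/image.jpg",
        metadata_fields=[DataRowMetadataField(name="tag", value="created-by-name")],
    )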

examples/basics/data_row_metadata.ipynb

Lines changed: 18 additions & 20 deletions
@@ -39,7 +39,7 @@
 "source": [
 "# Data Row Metadata\n",
 "\n",
-"Metadata is useful to be better understand data on the platform to help with labeling review, model diagnostics, and data selection. This **should not be confused with attachments**. Attachments provide additional context for labelers but is not searchable within Catalog."
+"Metadata is useful to better understand data on the platform to help with labeling review, model diagnostics, and data selection. This **should not be confused with attachments**. Attachments provide additional context for labelers but is not searchable within Catalog."
 ]
 },
 {
@@ -261,21 +261,20 @@
 "source": [
 "# Construct a metadata field of string kind\n",
 "tag_metadata_field = DataRowMetadataField(\n",
-"    schema_id=mdo.reserved_by_name[\"tag\"].uid,  # specify the schema id\n",
+"    name=\"tag\",  # specify the schema name\n",
 "    value=\"tag_string\",  # typed inputs\n",
 ")\n",
 "\n",
 "# Construct an metadata field of datetime kind\n",
 "capture_datetime_field = DataRowMetadataField(\n",
-"    schema_id=mdo.reserved_by_name[\"captureDateTime\"].uid,  # specify the schema id\n",
+"    name=\"captureDateTime\",  # specify the schema name\n",
 "    value=datetime.utcnow(),  # typed inputs\n",
 ")\n",
 "\n",
 "# Construct a metadata field of Enums options\n",
-"train_schema = mdo.reserved_by_name[\"split\"][\"train\"]\n",
 "split_metadta_field = DataRowMetadataField(\n",
-"    schema_id=train_schema.parent,  # specify the schema id\n",
-"    value=train_schema.uid,  # typed inputs\n",
+"    name=\"split\",  # specify the schema name\n",
+"    value=\"train\",  # typed inputs\n",
 ")"
 ]
 },
@@ -300,20 +299,20 @@
 "source": [
 "# Construct a dictionary of string metadata\n",
 "tag_metadata_field_dict = {\n",
-"    \"schema_id\": mdo.reserved_by_name[\"tag\"].uid,\n",
+"    \"name\": \"tag\",\n",
 "    \"value\": \"tag_string\",\n",
 "}\n",
 "\n",
 "# Construct a dictionary of datetime metadata\n",
 "capture_datetime_field_dict = {\n",
-"    \"schema_id\": mdo.reserved_by_name[\"captureDateTime\"].uid,\n",
+"    \"name\": \"captureDateTime\",\n",
 "    \"value\": datetime.utcnow(),\n",
 "}\n",
 "\n",
 "# Construct a dictionary of Enums options metadata\n",
 "split_metadta_field_dict = {\n",
-"    \"schema_id\": mdo.reserved_by_name[\"split\"][\"train\"].parent,\n",
-"    \"value\": mdo.reserved_by_name[\"split\"][\"train\"].uid,\n",
+"    \"name\": \"split\",\n",
+"    \"value\": \"train\",\n",
 "}"
 ]
 },
@@ -491,7 +490,7 @@
 "outputs": [],
 "source": [
 "# Select a dataset to use, or you can just use the 1-image dataset created above. \n",
-"dataset_id = \"cl3ntfr7j7cmh07bmeqz3gfjt\"\n",
+"dataset_id = dataset.uid\n",
 "dataset = client.get_dataset(dataset_id)"
 ]
 },
@@ -541,11 +540,11 @@
 "    # assign datarows a split\n",
 "    rnd = random.random()\n",
 "    if rnd < test:\n",
-"        split = mdo.reserved_by_name[\"split\"][\"test\"]\n",
+"        split = \"test\"\n",
 "    elif rnd < valid:\n",
-"        split = mdo.reserved_by_name[\"split\"][\"valid\"]\n",
+"        split = \"valid\"\n",
 "    else:\n",
-"        split = mdo.reserved_by_name[\"split\"][\"train\"]\n",
+"        split = \"train\"\n",
 "\n",
 "    embeddings.append(\n",
 "        list(model(processor(response.content), training=False)[0].numpy()))\n",
@@ -557,12 +556,11 @@
 "            data_row_id=datarow.uid,\n",
 "            fields=[\n",
 "                DataRowMetadataField(\n",
-"                    schema_id=mdo.reserved_by_name[\"captureDateTime\"].uid,\n",
+"                    name=\"captureDateTime\",\n",
 "                    value=dt,\n",
 "                ),\n",
-"                DataRowMetadataField(schema_id=split.parent, value=split.uid),\n",
-"                DataRowMetadataField(schema_id=mdo.reserved_by_name[\"tag\"].uid,\n",
-"                                     value=message),\n",
+"                DataRowMetadataField(name=\"split\", value=split),\n",
+"                DataRowMetadataField(name=\"tag\", value=message),\n",
 "            ]))"
 ]
 },
@@ -620,7 +618,7 @@
 "for md, embd in zip(uploads, projected):\n",
 "    md.fields.append(\n",
 "        DataRowMetadataField(\n",
-"            schema_id=mdo.reserved_by_name[\"embedding\"].uid,\n",
+"            name=\"embedding\",\n",
 "            value=embd.tolist(),  # convert from numpy to list\n",
 "        ),)"
 ]
@@ -801,4 +799,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 5
-}
+}
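
The changelog also covers the bulk creation path (`Dataset.create_data_rows()`), whose dataset.py diff is not shown on this page. A hedged sketch of what name-based dict fields might look like there; the `metadata_fields` key and the asset URL are assumptions:

    from datetime import datetime

    task = dataset.create_data_rows([{
        "row_data": "https://example.com/image.jpg",  # placeholder asset
        "metadata_fields": [
            {"name": "captureDateTime", "value": datetime.utcnow()},
            {"name": "tag", "value": "bulk-created"},
        ],
    }])
    task.wait_till_done()  # create_data_rows() returns an asynchronous Task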

labelbox/schema/data_row_metadata.py

Lines changed: 55 additions & 12 deletions
@@ -38,7 +38,10 @@ class DataRowMetadataSchema(BaseModel):
 
 # Metadata base class
 class DataRowMetadataField(_CamelCaseMixin):
-    schema_id: SchemaId
+    # One of `schema_id` or `name` must be provided. If `schema_id` is not provided, it is
+    # inferred from `name`
+    schema_id: Optional[SchemaId] = None
+    name: Optional[str] = None
     # value is of type `Any` so that we do not improperly coerce the value to the wrong tpye
     # Additional validation is performed before upload using the schema information
     value: Any
@@ -147,6 +150,16 @@ def _build_ontology(self):
             str, DataRowMetadataSchema] = self._make_normalized_name_index(
                 self.custom_fields)
 
+    @staticmethod
+    def _lookup_in_index_by_name(reserved_index, custom_index, name):
+        # search through reserved names first
+        if name in reserved_index:
+            return reserved_index[name]
+        elif name in custom_index:
+            return custom_index[name]
+        else:
+            raise KeyError(f"There is no metadata with name {name}")
+
     def get_by_name(
             self, name: str
     ) -> Union[DataRowMetadataSchema, Dict[str, DataRowMetadataSchema]]:
@@ -163,14 +176,17 @@ def get_by_name(
         Raises:
             KeyError: When provided name is not presented in neither reserved nor custom metadata list
         """
+        return self._lookup_in_index_by_name(self.reserved_by_name,
+                                             self.custom_by_name, name)
 
-        # search through reserved names first
-        if name in self.reserved_by_name:
-            return self.reserved_by_name[name]
-        elif name in self.custom_by_name:
-            return self.custom_by_name[name]
-        else:
-            raise KeyError(f"There is no metadata with name {name}")
+    def _get_by_name_normalized(self, name: str) -> DataRowMetadataSchema:
+        """ Get metadata by name. For options, it provides the option schema instead of list of
+        options
+        """
+        # using `normalized` indices to find options by name as well
+        return self._lookup_in_index_by_name(self.reserved_by_name_normalized,
+                                             self.custom_by_name_normalized,
+                                             name)
 
     @staticmethod
     def _make_name_index(
@@ -452,6 +468,8 @@ def parse_metadata_fields(
             else:
                 field = DataRowMetadataField(schema_id=schema.uid,
                                              value=f["value"])
+
+            field.name = schema.name
             parsed.append(field)
         return parsed
 
@@ -624,13 +642,17 @@ def _convert_metadata_field(metadata_field):
         if isinstance(metadata_field, DataRowMetadataField):
             return metadata_field
         elif isinstance(metadata_field, dict):
-            if not all(key in metadata_field
-                       for key in ("schema_id", "value")):
+            if not "value" in metadata_field:
+                raise ValueError(
+                    f"Custom metadata field '{metadata_field}' must have a 'value' key"
+                )
+            if not "schema_id" in metadata_field and not "name" in metadata_field:
                 raise ValueError(
-                    f"Custom metadata field '{metadata_field}' must have 'schema_id' and 'value' keys"
+                    f"Custom metadata field '{metadata_field}' must have either 'schema_id' or 'name' key"
                 )
             return DataRowMetadataField(
-                schema_id=metadata_field["schema_id"],
+                schema_id=metadata_field.get("schema_id"),
+                name=metadata_field.get("name"),
                 value=metadata_field["value"])
         else:
             raise ValueError(
@@ -664,11 +686,32 @@ def _upsert_schema(
         self.refresh_ontology()
         return _parse_metadata_schema(res)
 
+    def _load_option_by_name(self, metadatum: DataRowMetadataField):
+        is_value_a_valid_schema_id = metadatum.value in self.fields_by_id
+        if not is_value_a_valid_schema_id:
+            metadatum.value = self.get_by_name(
+                metadatum.name)[metadatum.value].uid
+
+    def _load_schema_id_by_name(self, metadatum: DataRowMetadataField):
+        """
+        Loads schema id by name for a metadata field including options schema id.
+        """
+        if metadatum.name is None:
+            return
+
+        if metadatum.schema_id is None:
+            schema = self._get_by_name_normalized(metadatum.name)
+            metadatum.schema_id = schema.uid
+            if schema.options:
+                self._load_option_by_name(metadatum)
+
     def _parse_upsert(
             self, metadatum: DataRowMetadataField
     ) -> List[_UpsertDataRowMetadataInput]:
         """Format for metadata upserts to GQL"""
 
+        self._load_schema_id_by_name(metadatum)
+
         if metadatum.schema_id not in self.fields_by_id:
             # Fetch latest metadata ontology if metadata can't be found
             self.refresh_ontology()
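
Taken together, these hunks make the two field styles below equivalent at upsert time. A minimal sketch, assuming `mdo` comes from `client.get_data_row_metadata_ontology()` and `datarow` is an existing data row:

    from labelbox.schema.data_row_metadata import (DataRowMetadata,
                                                   DataRowMetadataField)

    # Old style: resolve the enum schema uid and the option uid yourself.
    train_option = mdo.reserved_by_name["split"]["train"]
    by_id = DataRowMetadataField(schema_id=train_option.parent,
                                 value=train_option.uid)

    # New style: _parse_upsert() now calls _load_schema_id_by_name(), which
    # fills in schema_id from the normalized name index and, because "split"
    # has options, swaps the value "train" for the matching option uid.
    by_name = DataRowMetadataField(name="split", value="train")

    errors = mdo.bulk_upsert(
        [DataRowMetadata(data_row_id=datarow.uid, fields=[by_name])])
    assert not errors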

tests/integration/test_data_row_metadata.py

Lines changed: 64 additions & 1 deletion
@@ -18,6 +18,7 @@
 TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh"
 CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb"
 PRE_COMPUTED_EMBEDDINGS_ID = 'ckrzang79000008l6hb5s6za1'
+CUSTOM_TEXT_SCHEMA_NAME = 'custom_text'
 
 FAKE_NUMBER_FIELD = {
     "id": FAKE_SCHEMA_ID,
@@ -32,6 +33,7 @@ def mdo(client):
     mdo = client.get_data_row_metadata_ontology()
     for schema in mdo.custom_fields:
         mdo.delete_schema(schema.name)
+    mdo.create_schema(CUSTOM_TEXT_SCHEMA_NAME, DataRowMetadataKind.string)
     mdo._raw_ontology = mdo._get_ontology()
     mdo._raw_ontology.append(FAKE_NUMBER_FIELD)
     mdo._build_ontology()
@@ -69,6 +71,25 @@ def make_metadata(dr_id) -> DataRowMetadata:
     return metadata
 
 
+def make_named_metadata(dr_id) -> DataRowMetadata:
+    embeddings = [0.0] * 128
+    msg = "A message"
+    time = datetime.utcnow()
+
+    metadata = DataRowMetadata(data_row_id=dr_id,
+                               fields=[
+                                   DataRowMetadataField(name='split',
+                                                        value=TEST_SPLIT_ID),
+                                   DataRowMetadataField(name='captureDateTime',
+                                                        value=time),
+                                   DataRowMetadataField(
+                                       name=CUSTOM_TEXT_SCHEMA_NAME, value=msg),
+                                   DataRowMetadataField(name='embedding',
+                                                        value=embeddings),
+                               ])
+    return metadata
+
+
 def test_export_empty_metadata(configured_project_with_label):
     project, _, _, _ = configured_project_with_label
     # Wait for exporter to retrieve latest labels
@@ -81,7 +102,7 @@ def test_export_empty_metadata(configured_project_with_label):
 def test_get_datarow_metadata_ontology(mdo):
     assert len(mdo.fields)
     assert len(mdo.reserved_fields)
-    assert len(mdo.custom_fields) == 1
+    assert len(mdo.custom_fields) == 2
 
     split = mdo.reserved_by_name["split"]["train"]
 
@@ -129,6 +150,48 @@ def test_large_bulk_upsert_datarow_metadata(big_dataset, mdo):
         ]), metadata_lookup.get(data_row_id).fields
 
 
+def test_upsert_datarow_metadata_by_name(datarow, mdo):
+    metadata = [make_named_metadata(datarow.uid)]
+    errors = mdo.bulk_upsert(metadata)
+    assert len(errors) == 0
+
+    metadata_lookup = {
+        metadata.data_row_id: metadata
+        for metadata in mdo.bulk_export([datarow.uid])
+    }
+    assert len([
+        f for f in metadata_lookup.get(datarow.uid).fields
+        if f.schema_id != PRE_COMPUTED_EMBEDDINGS_ID
+    ]), metadata_lookup.get(datarow.uid).fields
+
+
+def test_upsert_datarow_metadata_option_by_name(datarow, mdo):
+    metadata = DataRowMetadata(data_row_id=datarow.uid,
+                               fields=[
+                                   DataRowMetadataField(name='split',
+                                                        value='test'),
+                               ])
+    errors = mdo.bulk_upsert([metadata])
+    assert len(errors) == 0
+
+    datarows = mdo.bulk_export([datarow.uid])
+    assert len(datarows[0].fields) == 1
+    metadata = datarows[0].fields[0]
+    assert metadata.schema_id == SPLIT_SCHEMA_ID
+    assert metadata.name == 'test'
+    assert metadata.value == TEST_SPLIT_ID
+
+
+def test_upsert_datarow_metadata_option_by_incorrect_name(datarow, mdo):
+    metadata = DataRowMetadata(data_row_id=datarow.uid,
+                               fields=[
+                                   DataRowMetadataField(name='split',
+                                                        value='test1'),
+                               ])
+    with pytest.raises(KeyError):
+        mdo.bulk_upsert([metadata])
+
+
 def test_bulk_delete_datarow_metadata(datarow, mdo):
     """test bulk deletes for all fields"""
     metadata = make_metadata(datarow.uid)
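
The last test pins down the failure mode: a name or option value that does not resolve raises KeyError on the client, inside `_parse_upsert()`, before any request is sent, so it never appears in the returned error list. A caller-side sketch, assuming `mdo` and a placeholder data row id:

    from labelbox.schema.data_row_metadata import (DataRowMetadata,
                                                   DataRowMetadataField)

    payload = DataRowMetadata(
        data_row_id="<DATA_ROW_ID>",  # placeholder id
        fields=[DataRowMetadataField(name="split", value="test1")])  # no such option

    try:
        mdo.bulk_upsert([payload])
    except KeyError:
        # _load_option_by_name() could not find "test1" among the "split" options.
        raise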
