Skip to content

Commit a5c8844

Browse files
author
gdj0nes
committed
CHG: id -> uid, add hierarchical name index for fields, index fields to be more intuitive
1 parent 00e32d4 commit a5c8844

File tree

3 files changed

+70
-58
lines changed

3 files changed

+70
-58
lines changed

examples/basics/data_row_metadata.ipynb

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,14 @@
8787
" DataRowMetadata,\n",
8888
" DataRowMetadataField,\n",
8989
" DeleteDataRowMetadata,\n",
90-
" DataRowMetadataKind\n",
9190
")\n",
9291
"from sklearn.random_projection import GaussianRandomProjection\n",
92+
"import tensorflow as tf\n",
9393
"import seaborn as sns\n",
94-
"from datetime import datetime\n",
95-
"from pprint import pprint\n",
9694
"import tensorflow_hub as hub\n",
95+
"from datetime import datetime\n",
9796
"from tqdm.notebook import tqdm\n",
9897
"import requests\n",
99-
"import tensorflow as tf\n",
10098
"from pprint import pprint"
10199
]
102100
},
@@ -154,7 +152,7 @@
154152
"outputs": [],
155153
"source": [
156154
"# dictionary access with id\n",
157-
"pprint(mdo.all_fields_id_index, indent=2)"
155+
"pprint(mdo.fields_by_id, indent=2)"
158156
]
159157
},
160158
{
@@ -167,7 +165,8 @@
167165
"outputs": [],
168166
"source": [
169167
"# access by name\n",
170-
"split_field = mdo.reserved_name_index[\"split\"]"
168+
"split_field = mdo.reserved_by_name[\"split\"]\n",
169+
"train_field = mdo.reserved_by_name[\"split\"][\"train\"]"
171170
]
172171
},
173172
{
@@ -191,7 +190,7 @@
191190
},
192191
"outputs": [],
193192
"source": [
194-
"tag_field = mdo.reserved_name_index[\"tag\"]"
193+
"tag_field = mdo.reserved_by_name[\"tag\"]"
195194
]
196195
},
197196
{
@@ -286,7 +285,7 @@
286285
"outputs": [],
287286
"source": [
288287
"field = DataRowMetadataField(\n",
289-
" schema_id=mdo.reserved_name_index[\"captureDateTime\"].id, # specify the schema id\n",
288+
" schema_id=mdo.reserved_by_name[\"captureDateTime\"].id, # specify the schema id\n",
290289
" value=datetime.now(), # typed inputs\n",
291290
")\n",
292291
"# Completed object ready for upload\n",
@@ -356,11 +355,11 @@
356355
" # assign datarows a split\n",
357356
" rnd = random.random()\n",
358357
" if rnd < test:\n",
359-
" split = \"cko8scbz70005h2dkastwhgqt\"\n",
358+
" split = mdo.reserved_by_name[\"split\"][\"test\"]\n",
360359
" elif rnd < valid:\n",
361-
" split = \"cko8sc2yr0004h2dk69aj5x63\"\n",
360+
" split = mdo.reserved_by_name[\"split\"][\"valid\"]\n",
362361
" else:\n",
363-
" split = \"cko8sbscr0003h2dk04w86hof\"\n",
362+
" split = mdo.reserved_by_name[\"split\"][\"train\"]\n",
364363
" \n",
365364
" embeddings.append(list(model(processor(response.content), training=False)[0].numpy()))\n",
366365
" dt = datetime.utcnow() \n",
@@ -371,15 +370,15 @@
371370
" data_row_id=datarow.uid,\n",
372371
" fields=[\n",
373372
" DataRowMetadataField(\n",
374-
" schema_id=mdo.reserved_name_index[\"captureDateTime\"].id,\n",
373+
" schema_id=mdo.reserved_by_name[\"captureDateTime\"].uid,\n",
375374
" value=dt,\n",
376375
" ),\n",
377376
" DataRowMetadataField(\n",
378-
" schema_id=mdo.reserved_name_index[\"split\"].id,\n",
377+
" schema_id=mdo.reserved_by_name[\"split\"].uid,\n",
379378
" value=split\n",
380379
" ),\n",
381380
" DataRowMetadataField(\n",
382-
" schema_id=mdo.reserved_name_index[\"tag\"].id,\n",
381+
" schema_id=mdo.reserved_by_name[\"tag\"].uid,\n",
383382
" value=message\n",
384383
" ),\n",
385384
" ]\n",
@@ -438,7 +437,7 @@
438437
"for md, embd in zip(uploads, projected):\n",
439438
" md.fields.append(\n",
440439
" DataRowMetadataField(\n",
441-
" schema_id=mdo.reserved_name_index[\"embedding\"].id,\n",
440+
" schema_id=mdo.reserved_by_name[\"embedding\"].uid,\n",
442441
" value=embd.tolist(), # convert from numpy to list\n",
443442
" ),\n",
444443
" )"
@@ -568,7 +567,7 @@
568567
"fields = []\n",
569568
"# iterate through the fields you want to delete\n",
570569
"for field in md.fields:\n",
571-
" schema = mdo.all_fields_id_index[field.schema_id]\n",
570+
"    schema = mdo.fields_by_id[field.schema_id]\n",
572571
" fields.append(field.schema_id)\n",
573572
"\n",
574573
"deletes = DeleteDataRowMetadata(\n",
@@ -650,4 +649,4 @@
650649
},
651650
"nbformat": 4,
652651
"nbformat_minor": 5
653-
}
652+
}

labelbox/schema/data_row_metadata.py

Lines changed: 53 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# type: ignore
22
import datetime
3+
import warnings
34
from enum import Enum
45
from itertools import chain
56
from typing import List, Optional, Dict, Union, Callable, Type, Any, Generator
@@ -21,13 +22,18 @@ class DataRowMetadataKind(Enum):
2122

2223
# Metadata schema
2324
class DataRowMetadataSchema(BaseModel):
24-
id: SchemaId
25+
uid: SchemaId
2526
name: constr(strip_whitespace=True, min_length=1, max_length=100)
2627
reserved: bool
2728
kind: DataRowMetadataKind
2829
options: Optional[List["DataRowMetadataSchema"]]
2930
parent: Optional[SchemaId]
3031

32+
@property
33+
def id(self):
34+
warnings.warn("`id` is being deprecated in favor of `uid`")
35+
return self.uid
36+
3137

3238
DataRowMetadataSchema.update_forward_refs()
3339

@@ -36,7 +42,7 @@ class DataRowMetadataSchema(BaseModel):
3642
DateTime: Type[datetime.datetime] = datetime.datetime # must be in UTC
3743
String: Type[str] = constr(max_length=500)
3844
OptionId: Type[SchemaId] = SchemaId # enum option
39-
Number: Type[float]
45+
Number: Type[float] = float
4046

4147
DataRowMetadataValue = Union[Embedding, DateTime, String, OptionId, Number]
4248

@@ -107,28 +113,31 @@ class DataRowMetadataOntology:
107113
"""
108114

109115
def __init__(self, client):
110-
self.client = client
111-
self._batch_size = 50
112116

113-
# TODO: consider making these properties to stay in sync with server
117+
self._client = client
118+
self._batch_size = 50 # used for uploads and deletes
119+
114120
self._raw_ontology = self._get_ontology()
121+
115122
# all fields
116-
self.all_fields = self._parse_ontology()
117-
self.all_fields_id_index = self._make_id_index(self.all_fields)
123+
self.fields = self._parse_ontology()
124+
self.fields_by_id = self._make_id_index(self.fields)
125+
118126
# reserved fields
119127
self.reserved_fields: List[DataRowMetadataSchema] = [
120-
f for f in self.all_fields if f.reserved
128+
f for f in self.fields if f.reserved
121129
]
122-
self.reserved_id_index = self._make_id_index(self.reserved_fields)
123-
self.reserved_name_index: Dict[str, DataRowMetadataSchema] = {
130+
self.reserved_by_id = self._make_id_index(self.reserved_fields)
131+
self.reserved_by_name: Dict[str, DataRowMetadataSchema] = {
124132
f.name: f for f in self.reserved_fields
125133
}
134+
126135
# custom fields
127136
self.custom_fields: List[DataRowMetadataSchema] = [
128-
f for f in self.all_fields if not f.reserved
137+
f for f in self.fields if not f.reserved
129138
]
130-
self.custom_id_index = self._make_id_index(self.custom_fields)
131-
self.custom_name_index: Dict[str, DataRowMetadataSchema] = {
139+
self.custom_by_id = self._make_id_index(self.custom_fields)
140+
self.custom_by_name: Dict[str, DataRowMetadataSchema] = {
132141
f.name: f for f in self.custom_fields
133142
}
134143

@@ -150,13 +159,13 @@ def _make_id_index(
150159
) -> Dict[SchemaId, DataRowMetadataSchema]:
151160
index = {}
152161
for f in fields:
153-
index[f.id] = f
162+
index[f.uid] = f
154163
if f.options:
155164
for o in f.options:
156-
index[o.id] = o
165+
index[o.uid] = o
157166
return index
158167

159-
def _get_ontology(self) -> Dict[str, Any]:
168+
def _get_ontology(self) -> List[Dict[str, Any]]:
160169
query = """query GetMetadataOntologyBetaPyApi {
161170
customMetadataOntology {
162171
id
@@ -171,21 +180,26 @@ def _get_ontology(self) -> Dict[str, Any]:
171180
}
172181
}}
173182
"""
174-
return self.client.execute(query)["customMetadataOntology"]
183+
return self._client.execute(query)["customMetadataOntology"]
175184

176185
def _parse_ontology(self) -> List[DataRowMetadataSchema]:
177186
fields = []
178187
for schema in self._raw_ontology:
188+
schema["uid"] = schema.pop("id")
179189
options = None
180190
if schema.get("options"):
181-
options = [
182-
DataRowMetadataSchema(**{
183-
**option,
184-
**{
185-
"parent": schema["id"]
186-
}
187-
}) for option in schema["options"]
188-
]
191+
options = []
192+
for option in schema["options"]:
193+
option["uid"] = option.pop("id")
194+
options.append(
195+
DataRowMetadataSchema(
196+
**{
197+
**option,
198+
**{
199+
"parent": schema["uid"]
200+
}
201+
})
202+
)
189203
schema["options"] = options
190204
fields.append(DataRowMetadataSchema(**schema))
191205

@@ -197,7 +211,7 @@ def parse_metadata(
197211
Dict]]]]) -> List[DataRowMetadata]:
198212
""" Parse metadata responses
199213
200-
>>> mdo.parse_metadata([datarow.metadata])
214+
>>> mdo.parse_metadata([metadata])
201215
202216
Args:
203217
unparsed: An unparsed metadata export
@@ -213,14 +227,14 @@ def parse_metadata(
213227
for dr in unparsed:
214228
fields = []
215229
for f in dr["fields"]:
216-
schema = self.all_fields_id_index[f["schemaId"]]
230+
schema = self.fields_by_id[f["schemaId"]]
217231
if schema.kind == DataRowMetadataKind.enum:
218232
continue
219233
elif schema.kind == DataRowMetadataKind.option:
220234
field = DataRowMetadataField(schema_id=schema.parent,
221-
value=schema.id)
235+
value=schema.uid)
222236
else:
223-
field = DataRowMetadataField(schema_id=schema.id,
237+
field = DataRowMetadataField(schema_id=schema.uid,
224238
value=f["value"])
225239

226240
fields.append(field)
@@ -267,7 +281,7 @@ def _batch_upsert(
267281
}
268282
}
269283
}"""
270-
res = self.client.execute(
284+
res = self._client.execute(
271285
query, {"metadata": upserts})['upsertDataRowCustomMetadata']
272286
return [
273287
DataRowMetadataBatchResponse(data_row_id=r['dataRowId'],
@@ -330,7 +344,7 @@ def _batch_delete(
330344
}
331345
}
332346
"""
333-
res = self.client.execute(
347+
res = self._client.execute(
334348
query, {"deletes": deletes})['deleteDataRowCustomMetadata']
335349
failures = []
336350
for dr in res:
@@ -373,7 +387,7 @@ def _bulk_export(_data_row_ids: List[str]) -> List[DataRowMetadata]:
373387
}
374388
"""
375389
return self.parse_metadata(
376-
self.client.execute(
390+
self._client.execute(
377391
query,
378392
{"dataRowIds": _data_row_ids})['dataRowCustomMetadata'])
379393

@@ -386,11 +400,11 @@ def _parse_upsert(
386400
) -> List[_UpsertDataRowMetadataInput]:
387401
"""Format for metadata upserts to GQL"""
388402

389-
if metadatum.schema_id not in self.all_fields_id_index:
403+
if metadatum.schema_id not in self.fields_by_id:
390404
raise ValueError(
391405
f"Schema Id `{metadatum.schema_id}` not found in ontology")
392406

393-
schema = self.all_fields_id_index[metadatum.schema_id]
407+
schema = self.fields_by_id[metadatum.schema_id]
394408

395409
if schema.kind == DataRowMetadataKind.datetime:
396410
parsed = _validate_parse_datetime(metadatum)
@@ -413,16 +427,16 @@ def _validate_delete(self, delete: DeleteDataRowMetadata):
413427

414428
deletes = set()
415429
for schema_id in delete.fields:
416-
if schema_id not in self.all_fields_id_index:
430+
if schema_id not in self.fields_by_id:
417431
raise ValueError(
418432
f"Schema Id `{schema_id}` not found in ontology")
419433

420-
schema = self.all_fields_id_index[schema_id]
434+
schema = self.fields_by_id[schema_id]
421435
# handle users specifying enums by adding all option enums
422436
if schema.kind == DataRowMetadataKind.enum:
423-
[deletes.add(o.id) for o in schema.options]
437+
[deletes.add(o.uid) for o in schema.options]
424438

425-
deletes.add(schema.id)
439+
deletes.add(schema.uid)
426440

427441
return _DeleteBatchDataRowMetadata(
428442
data_row_id=delete.data_row_id,
@@ -471,7 +485,7 @@ def _validate_enum_parse(
471485
schema: DataRowMetadataSchema,
472486
field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, dict]]]:
473487
if schema.options:
474-
if field.value not in {o.id for o in schema.options}:
488+
if field.value not in {o.uid for o in schema.options}:
475489
raise ValueError(
476490
f"Option `{field.value}` not found for {field.schema_id}")
477491
else:

tests/integration/test_data_row_metadata.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def make_metadata(dr_id) -> DataRowMetadata:
6565

6666

6767
def test_get_datarow_metadata_ontology(mdo):
68-
assert len(mdo.all_fields)
68+
assert len(mdo.fields)
6969
assert len(mdo.reserved_fields)
7070
assert len(mdo.custom_fields) == 0
7171

@@ -81,7 +81,6 @@ def test_bulk_upsert_datarow_metadata(datarow, mdo: DataRowMetadataOntology):
8181
@pytest.mark.slow
8282
def test_large_bulk_upsert_datarow_metadata(big_dataset, mdo):
8383
metadata = []
84-
data_row_ids = []
8584
data_row_ids = [dr.uid for dr in big_dataset.data_rows()]
8685
wait_for_embeddings_svc(data_row_ids, mdo)
8786
for data_row_id in data_row_ids:

0 commit comments

Comments
 (0)