 # type: ignore
 import datetime
 import warnings
+from copy import deepcopy
 from enum import Enum
 from itertools import chain
 from typing import List, Optional, Dict, Union, Callable, Type, Any, Generator
@@ -46,7 +47,9 @@ def id(self):
 OptionId: Type[SchemaId] = SchemaId  # enum option
 Number: Type[float] = float
 
-DataRowMetadataValue = Union[Embedding, DateTime, String, OptionId, Number]
+DataRowMetadataValue = Union[Embedding, Number, DateTime, String, OptionId]
+# primitives used in uploads
+_DataRowMetadataValuePrimitives = Union[str, List, dict, float]
 
 
 class _CamelCaseMixin(BaseModel):
@@ -59,7 +62,7 @@ class Config:
 # Metadata base class
 class DataRowMetadataField(_CamelCaseMixin):
     schema_id: SchemaId
-    value: DataRowMetadataValue
+    value: Any
 
 
 class DataRowMetadata(_CamelCaseMixin):
@@ -85,7 +88,7 @@ class DataRowMetadataBatchResponse(_CamelCaseMixin):
 # Bulk upsert values
 class _UpsertDataRowMetadataInput(_CamelCaseMixin):
     schema_id: str
-    value: Union[str, List, dict]
+    value: Any
 
 
 # Batch of upsert values for a datarow
@@ -121,28 +124,30 @@ def __init__(self, client):
         self._batch_size = 50  # used for uploads and deletes
 
         self._raw_ontology = self._get_ontology()
+        self._build_ontology()
 
+    def _build_ontology(self):
         # all fields
-        self.fields = self._parse_ontology()
+        self.fields = self._parse_ontology(self._raw_ontology)
         self.fields_by_id = self._make_id_index(self.fields)
 
         # reserved fields
         self.reserved_fields: List[DataRowMetadataSchema] = [
             f for f in self.fields if f.reserved
         ]
         self.reserved_by_id = self._make_id_index(self.reserved_fields)
-        self.reserved_by_name: Dict[str, DataRowMetadataSchema] = {
-            f.name: f for f in self.reserved_fields
-        }
+        self.reserved_by_name: Dict[
+            str,
+            DataRowMetadataSchema] = self._make_name_index(self.reserved_fields)
 
         # custom fields
         self.custom_fields: List[DataRowMetadataSchema] = [
             f for f in self.fields if not f.reserved
         ]
         self.custom_by_id = self._make_id_index(self.custom_fields)
-        self.custom_by_name: Dict[str, DataRowMetadataSchema] = {
-            f.name: f for f in self.custom_fields
-        }
+        self.custom_by_name: Dict[
+            str,
+            DataRowMetadataSchema] = self._make_name_index(self.custom_fields)
 
     @staticmethod
     def _make_name_index(fields: List[DataRowMetadataSchema]):
@@ -151,7 +156,7 @@ def _make_name_index(fields: List[DataRowMetadataSchema]):
             if f.options:
                 index[f.name] = {}
                 for o in f.options:
-                    index[o.name] = o
+                    index[f.name][o.name] = o
             else:
                 index[f.name] = f
         return index
@@ -185,15 +190,17 @@ def _get_ontology(self) -> List[Dict[str, Any]]:
185190 """
186191 return self ._client .execute (query )["customMetadataOntology" ]
187192
188- def _parse_ontology (self ) -> List [DataRowMetadataSchema ]:
193+ @staticmethod
194+ def _parse_ontology (raw_ontology ) -> List [DataRowMetadataSchema ]:
189195 fields = []
190- for schema in self ._raw_ontology :
191- schema ["uid" ] = schema .pop ("id" )
196+ copy = deepcopy (raw_ontology )
197+ for schema in copy :
198+ schema ["uid" ] = schema ["id" ]
192199 options = None
193200 if schema .get ("options" ):
194201 options = []
195202 for option in schema ["options" ]:
196- option ["uid" ] = option . pop ( "id" )
203+ option ["uid" ] = option [ "id" ]
197204 options .append (
198205 DataRowMetadataSchema (** {
199206 ** option ,
@@ -415,6 +422,8 @@ def _parse_upsert(
         parsed = _validate_parse_datetime(metadatum)
     elif schema.kind == DataRowMetadataKind.string:
         parsed = _validate_parse_text(metadatum)
+    elif schema.kind == DataRowMetadataKind.number:
+        parsed = _validate_parse_number(metadatum)
     elif schema.kind == DataRowMetadataKind.embedding:
         parsed = _validate_parse_embedding(metadatum)
     elif schema.kind == DataRowMetadataKind.enum:
@@ -472,6 +481,12 @@ def _validate_parse_embedding(
     return [field.dict(by_alias=True)]
 
 
+def _validate_parse_number(
+        field: DataRowMetadataField
+) -> List[Dict[str, Union[SchemaId, Number]]]:
+    return [field.dict(by_alias=True)]
+
+
 def _validate_parse_datetime(
         field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]:
     # TODO: better validate tzinfo