2424from labelbox .pagination import PaginatedCollection
2525from labelbox .pydantic_compat import BaseModel
2626from labelbox .schema .data_row import DataRow
27+ from labelbox .schema .embeddings import EmbeddingVector
2728from labelbox .schema .export_filters import DatasetExportFilters , build_filters
2829from labelbox .schema .export_params import CatalogExportParams , validate_catalog_export_params
2930from labelbox .schema .export_task import ExportTask
@@ -179,14 +180,20 @@ def convert_field_keys(items):
179180 args [DataRow .metadata_fields .name ] = mdo .parse_upsert_metadata (
180181 args [DataRow .metadata_fields .name ])
181182
183+ if "embeddings" in args :
184+ args ["embeddings" ] = [
185+ EmbeddingVector (** e ).to_gql () for e in args ["embeddings" ]
186+ ]
187+
182188 query_str = """mutation CreateDataRowPyApi(
183189 $row_data: String!,
184190 $metadata_fields: [DataRowCustomMetadataUpsertInput!],
185191 $attachments: [DataRowAttachmentInput!],
186192 $media_type : MediaType,
187193 $external_id : String,
188194 $global_key : String,
189- $dataset: ID!
195+ $dataset: ID!,
196+ $embeddings: [DataRowEmbeddingVectorInput!]
190197 ){
191198 createDataRow(
192199 data:
@@ -198,6 +205,7 @@ def convert_field_keys(items):
198205 globalKey: $global_key
199206 attachments: $attachments
200207 dataset: {connect: {id: $dataset}}
208+ embeddings: $embeddings
201209 }
202210 )
203211 {%s}
@@ -388,6 +396,13 @@ def validate_attachments(item):
388396 )
389397 return attachments
390398
399+ def validate_embeddings (item ):
400+ embeddings = item .get ("embeddings" )
401+ if embeddings :
402+ item ["embeddings" ] = [
403+ EmbeddingVector (** e ).to_gql () for e in embeddings
404+ ]
405+
391406 def validate_conversational_data (conversational_data : list ) -> None :
392407 """
393408 Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json
@@ -448,7 +463,9 @@ def validate_keys(item):
448463 str ) and item .get ('row_data' ).startswith ("s3:/" ):
449464 raise InvalidQueryError (
450465 "row_data: s3 assets must start with 'https'." )
451- allowed_extra_fields = {'attachments' , 'media_type' , 'dataset_id' }
466+ allowed_extra_fields = {
467+ 'attachments' , 'media_type' , 'dataset_id' , 'embeddings'
468+ }
452469 invalid_keys = set (item ) - {f .name for f in DataRow .fields ()
453470 } - allowed_extra_fields
454471 if invalid_keys :
@@ -494,6 +511,8 @@ def convert_item(data_row_item):
494511 validate_keys (item )
495512 # Make sure attachments are valid
496513 validate_attachments (item )
514+ # Make sure embeddings are valid
515+ validate_embeddings (item )
497516 # Parse metadata fields if they exist
498517 parse_metadata_fields (item )
499518 # Upload any local file paths
0 commit comments