Skip to content

Commit 729b667

Browse files
authored
[PLT-34] add support for custom embeddings to data import (#1507)
1 parent 8b42d2b commit 729b667

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

labelbox/schema/dataset.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from labelbox.pagination import PaginatedCollection
2525
from labelbox.pydantic_compat import BaseModel
2626
from labelbox.schema.data_row import DataRow
27+
from labelbox.schema.embeddings import EmbeddingVector
2728
from labelbox.schema.export_filters import DatasetExportFilters, build_filters
2829
from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
2930
from labelbox.schema.export_task import ExportTask
@@ -179,14 +180,20 @@ def convert_field_keys(items):
179180
args[DataRow.metadata_fields.name] = mdo.parse_upsert_metadata(
180181
args[DataRow.metadata_fields.name])
181182

183+
if "embeddings" in args:
184+
args["embeddings"] = [
185+
EmbeddingVector(**e).to_gql() for e in args["embeddings"]
186+
]
187+
182188
query_str = """mutation CreateDataRowPyApi(
183189
$row_data: String!,
184190
$metadata_fields: [DataRowCustomMetadataUpsertInput!],
185191
$attachments: [DataRowAttachmentInput!],
186192
$media_type : MediaType,
187193
$external_id : String,
188194
$global_key : String,
189-
$dataset: ID!
195+
$dataset: ID!,
196+
$embeddings: [DataRowEmbeddingVectorInput!]
190197
){
191198
createDataRow(
192199
data:
@@ -198,6 +205,7 @@ def convert_field_keys(items):
198205
globalKey: $global_key
199206
attachments: $attachments
200207
dataset: {connect: {id: $dataset}}
208+
embeddings: $embeddings
201209
}
202210
)
203211
{%s}
@@ -388,6 +396,13 @@ def validate_attachments(item):
388396
)
389397
return attachments
390398

399+
def validate_embeddings(item):
400+
embeddings = item.get("embeddings")
401+
if embeddings:
402+
item["embeddings"] = [
403+
EmbeddingVector(**e).to_gql() for e in embeddings
404+
]
405+
391406
def validate_conversational_data(conversational_data: list) -> None:
392407
"""
393408
Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json
@@ -448,7 +463,9 @@ def validate_keys(item):
448463
str) and item.get('row_data').startswith("s3:/"):
449464
raise InvalidQueryError(
450465
"row_data: s3 assets must start with 'https'.")
451-
allowed_extra_fields = {'attachments', 'media_type', 'dataset_id'}
466+
allowed_extra_fields = {
467+
'attachments', 'media_type', 'dataset_id', 'embeddings'
468+
}
452469
invalid_keys = set(item) - {f.name for f in DataRow.fields()
453470
} - allowed_extra_fields
454471
if invalid_keys:
@@ -494,6 +511,8 @@ def convert_item(data_row_item):
494511
validate_keys(item)
495512
# Make sure attachments are valid
496513
validate_attachments(item)
514+
# Make sure embeddings are valid
515+
validate_embeddings(item)
497516
# Parse metadata fields if they exist
498517
parse_metadata_fields(item)
499518
# Upload any local file paths

labelbox/schema/embeddings.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from typing import List, Optional, Any, Dict
2+
3+
from labelbox.pydantic_compat import BaseModel
4+
5+
6+
class EmbeddingVector(BaseModel):
7+
embedding_id: str
8+
vector: List[float]
9+
clusters: Optional[List[int]]
10+
11+
def to_gql(self) -> Dict[str, Any]:
12+
result = {"embeddingId": self.embedding_id, "vector": self.vector}
13+
if self.clusters:
14+
result["clusters"] = self.clusters
15+
return result

0 commit comments

Comments
 (0)