Commit 9225f08

[PLT-344] Add dataset.upsert_data_rows method (#1460)

1 parent: 9844d5d
6 files changed: +439 / -41 lines

labelbox/schema/asset_attachment.py

Lines changed: 21 additions & 21 deletions
```diff
@@ -6,6 +6,26 @@
 from labelbox.orm.model import Field
 
 
+class AttachmentType(str, Enum):
+
+    @classmethod
+    def __missing__(cls, value: object):
+        if str(value) == "TEXT":
+            warnings.warn(
+                "The TEXT attachment type is deprecated. Use RAW_TEXT instead.")
+            return cls.RAW_TEXT
+        return value
+
+    VIDEO = "VIDEO"
+    IMAGE = "IMAGE"
+    IMAGE_OVERLAY = "IMAGE_OVERLAY"
+    HTML = "HTML"
+    RAW_TEXT = "RAW_TEXT"
+    TEXT_URL = "TEXT_URL"
+    PDF_URL = "PDF_URL"
+    CAMERA_IMAGE = "CAMERA_IMAGE"  # Used by experimental point-cloud editor
+
+
 class AssetAttachment(DbObject):
     """Asset attachment provides extra context about an asset while labeling.
 
@@ -15,26 +35,6 @@ class AssetAttachment(DbObject):
         attachment_name (str): The name of the attachment
     """
 
-    class AttachmentType(Enum):
-
-        @classmethod
-        def __missing__(cls, value: object):
-            if str(value) == "TEXT":
-                warnings.warn(
-                    "The TEXT attachment type is deprecated. Use RAW_TEXT instead."
-                )
-                return cls.RAW_TEXT
-            return value
-
-        VIDEO = "VIDEO"
-        IMAGE = "IMAGE"
-        IMAGE_OVERLAY = "IMAGE_OVERLAY"
-        HTML = "HTML"
-        RAW_TEXT = "RAW_TEXT"
-        TEXT_URL = "TEXT_URL"
-        PDF_URL = "PDF_URL"
-        CAMERA_IMAGE = "CAMERA_IMAGE"  # Used by experimental point-cloud editor
-
     for topic in AttachmentType:
         vars()[topic.name] = topic.value
 
@@ -61,7 +61,7 @@ def validate_attachment_value(cls, attachment_value: str) -> None:
 
     @classmethod
     def validate_attachment_type(cls, attachment_type: str) -> None:
-        valid_types = set(cls.AttachmentType.__members__)
+        valid_types = set(AttachmentType.__members__)
         if attachment_type not in valid_types:
             raise ValueError(
                 f"attachment_type must be one of {valid_types}. Found {attachment_type}"
```

labelbox/schema/data_row.py

Lines changed: 14 additions & 3 deletions
```diff
@@ -1,10 +1,12 @@
 import logging
-from typing import TYPE_CHECKING, List, Optional, Union
+from enum import Enum
+from typing import TYPE_CHECKING, List, Optional, Union, Any
 import json
 
 from labelbox.orm import query
 from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable, experimental
 from labelbox.orm.model import Entity, Field, Relationship
+from labelbox.schema.asset_attachment import AttachmentType
 from labelbox.schema.data_row_metadata import DataRowMetadataField  # type: ignore
 from labelbox.schema.export_filters import DatarowExportFilters, build_filters, validate_at_least_one_of_data_row_ids_or_global_keys
 from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
@@ -17,6 +19,15 @@
 logger = logging.getLogger(__name__)
 
 
+class KeyType(str, Enum):
+    ID = 'ID'
+    """An existing CUID"""
+    GKEY = 'GKEY'
+    """A Global key, could be existing or non-existing"""
+    AUTO = 'AUTO'
+    """The key will be auto-generated. Only usable for creates"""
+
+
 class DataRow(DbObject, Updateable, BulkDeletable):
     """ Internal Labelbox representation of a single piece of data (e.g. image, video, text).
 
@@ -62,7 +73,7 @@ class DataRow(DbObject, Updateable, BulkDeletable):
     attachments = Relationship.ToMany("AssetAttachment", False, "attachments")
 
     supported_meta_types = supported_attachment_types = set(
-        Entity.AssetAttachment.AttachmentType.__members__)
+        AttachmentType.__members__)
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -131,7 +142,7 @@ def create_attachment(self,
 
         Args:
             attachment_type (str): Asset attachment type, must be one of:
-                VIDEO, IMAGE, TEXT, IMAGE_OVERLAY (AssetAttachment.AttachmentType)
+                VIDEO, IMAGE, TEXT, IMAGE_OVERLAY (AttachmentType)
             attachment_value (str): Asset attachment value.
             attachment_name (str): (Optional) Asset attachment name.
         Returns:
```
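The new `KeyType` enum names the key discriminators used by the upsert wire format. A paraphrase of how the three variants map (the actual conversion lives in `Dataset._convert_items_to_upsert_format`, shown further below; `to_key_dict` is a hypothetical helper, not SDK API):

```python
# Hypothetical helper mirroring the SDK's key conversion for upsert items.
from labelbox.schema.data_row import KeyType
from labelbox.schema.identifiable import UniqueId, GlobalKey

def to_key_dict(key=None) -> dict:
    if key is None:
        # No key: the id is auto-generated server-side (create-only).
        return {'type': KeyType.AUTO.value, 'value': ''}
    if isinstance(key, UniqueId):
        return {'type': KeyType.ID.value, 'value': key.key}    # existing CUID
    if isinstance(key, GlobalKey):
        return {'type': KeyType.GKEY.value, 'value': key.key}  # global key
    raise ValueError("Key must be an instance of UniqueId or GlobalKey")

assert to_key_dict(GlobalKey("gk1")) == {'type': 'GKEY', 'value': 'gk1'}
```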

labelbox/schema/dataset.py

Lines changed: 135 additions & 14 deletions
```diff
@@ -1,4 +1,5 @@
-from typing import Dict, Generator, List, Optional, Union, Any
+from datetime import datetime
+from typing import Dict, Generator, List, Optional, Any, Final
 import os
 import json
 import logging
@@ -14,18 +15,19 @@
 from io import StringIO
 import requests
 
-from labelbox import pagination
 from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError
 from labelbox.orm.comparison import Comparison
 from labelbox.orm.db_object import DbObject, Updateable, Deletable, experimental
 from labelbox.orm.model import Entity, Field, Relationship
 from labelbox.orm import query
 from labelbox.exceptions import MalformedQueryException
 from labelbox.pagination import PaginatedCollection
+from labelbox.pydantic_compat import BaseModel
 from labelbox.schema.data_row import DataRow
 from labelbox.schema.export_filters import DatasetExportFilters, build_filters
 from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
 from labelbox.schema.export_task import ExportTask
+from labelbox.schema.identifiable import UniqueId, GlobalKey
 from labelbox.schema.task import Task
 from labelbox.schema.user import User
 
@@ -34,6 +36,11 @@
 MAX_DATAROW_PER_API_OPERATION = 150_000
 
 
+class DataRowUpsertItem(BaseModel):
+    id: dict
+    payload: dict
+
+
 class Dataset(DbObject, Updateable, Deletable):
     """ A Dataset is a collection of DataRows.
```
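`DataRowUpsertItem` simply pairs a key dict with the raw item payload; the descriptor-file machinery below serializes these pairs. A minimal construction sketch (values are illustrative):

```python
# Minimal sketch; both fields are plain dicts.
from labelbox.schema.dataset import DataRowUpsertItem

spec = DataRowUpsertItem(
    id={'type': 'AUTO', 'value': ''},  # no key supplied -> auto-generated id
    payload={'row_data': 'http://my_site.com/photos/img_01.jpg'})
```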
```diff
@@ -47,6 +54,8 @@ class Dataset(DbObject, Updateable, Deletable):
         created_by (Relationship): `ToOne` relationship to User
         organization (Relationship): `ToOne` relationship to Organization
     """
+    __upsert_chunk_size: Final = 10_000
+
     name = Field.String("name")
     description = Field.String("description")
     updated_at = Field.DateTime("updated_at")
@@ -64,16 +73,16 @@ def data_rows(
             from_cursor: Optional[str] = None,
             where: Optional[Comparison] = None,
     ) -> PaginatedCollection:
-        """
+        """
         Custom method to paginate data_rows via cursor.
 
         Args:
             from_cursor (str): Cursor (data row id) to start from, if none, will start from the beginning
-            where (dict(str,str)): Filter to apply to data rows. Where value is a data row column name and key is the value to filter on.
+            where (dict(str,str)): Filter to apply to data rows. Where value is a data row column name and key is the value to filter on.
                 example: {'external_id': 'my_external_id'} to get a data row with external_id = 'my_external_id'
 
-        NOTE:
+        NOTE:
             Order of retrieval is newest data row first.
             Deleted data rows are not retrieved.
             Failed data rows are not retrieved.
```

(In the second hunk the paired -/+ lines differ only in trailing whitespace.)
```diff
@@ -293,7 +302,10 @@ def create_data_rows(self, items) -> "Task":
         task._user = user
         return task
 
-    def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
+    def _create_descriptor_file(self,
+                                items,
+                                max_attachments_per_data_row=None,
+                                is_upsert=False):
         """
         This function is shared by both `Dataset.create_data_rows` and `Dataset.create_data_rows_sync`
         to prepare the input file. The user defined input is validated, processed, and json stringified.
@@ -346,6 +358,9 @@ def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
         AssetAttachment = Entity.AssetAttachment
 
         def upload_if_necessary(item):
+            if is_upsert and 'row_data' not in item:
+                # When upserting, row_data is not required
+                return item
             row_data = item['row_data']
             if isinstance(row_data, str) and os.path.exists(row_data):
                 item_url = self.client.upload_file(row_data)
```
```diff
@@ -425,17 +440,17 @@ def format_row(item):
             return item
 
         def validate_keys(item):
-            if 'row_data' not in item:
+            if not is_upsert and 'row_data' not in item:
                 raise InvalidQueryError(
                     "`row_data` missing when creating DataRow.")
 
             if isinstance(item.get('row_data'),
                           str) and item.get('row_data').startswith("s3:/"):
                 raise InvalidQueryError(
                     "row_data: s3 assets must start with 'https'.")
-            invalid_keys = set(item) - {
-                *{f.name for f in DataRow.fields()}, 'attachments', 'media_type'
-            }
+            allowed_extra_fields = {'attachments', 'media_type', 'dataset_id'}
+            invalid_keys = set(item) - {f.name for f in DataRow.fields()
+                                       } - allowed_extra_fields
             if invalid_keys:
                 raise InvalidAttributeError(DataRow, invalid_keys)
             return item
```
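The rewritten `validate_keys` is plain set arithmetic: anything outside `DataRow`'s field names plus the three allowed extras is rejected, and in upsert mode `row_data` may be omitted. A standalone sketch (`data_row_fields` here is an assumed, illustrative subset of `{f.name for f in DataRow.fields()}`):

```python
# Standalone sketch of the validate_keys set arithmetic.
data_row_fields = {'row_data', 'global_key', 'external_id'}  # illustrative subset
allowed_extra_fields = {'attachments', 'media_type', 'dataset_id'}

item = {'global_key': 'gk1', 'dataset_id': 'ds1', 'typo_field': 1}
invalid_keys = set(item) - data_row_fields - allowed_extra_fields
assert invalid_keys == {'typo_field'}  # would raise InvalidAttributeError
```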
```diff
@@ -460,7 +475,12 @@ def formatLegacyConversationalData(item):
             item["row_data"] = one_conversation
             return item
 
-        def convert_item(item):
+        def convert_item(data_row_item):
+            if isinstance(data_row_item, DataRowUpsertItem):
+                item = data_row_item.payload
+            else:
+                item = data_row_item
+
             if "tileLayerUrl" in item:
                 validate_attachments(item)
                 return item
@@ -478,7 +498,11 @@ def convert_item(item):
             parse_metadata_fields(item)
             # Upload any local file paths
             item = upload_if_necessary(item)
-            return item
+
+            if isinstance(data_row_item, DataRowUpsertItem):
+                return {'id': data_row_item.id, 'payload': item}
+            else:
+                return item
 
         if not isinstance(items, Iterable):
             raise ValueError(
```
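In upsert mode, `convert_item` unwraps the `DataRowUpsertItem`, validates and uploads the payload exactly as for creates, then re-wraps it, so each descriptor entry carries both the key and the payload. The assumed entry shape, per the return value above (values illustrative):

```python
# Assumed shape of one descriptor entry in upsert mode, per convert_item above.
entry = {
    'id': {'type': 'GKEY', 'value': 'global_key1'},    # which row to touch
    'payload': {'global_key': 'global_key1_updated'},  # fields to upsert
}
```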
```diff
@@ -638,13 +662,13 @@ def export_v2(
     ) -> Task:
         """
         Creates a dataset export task with the given params and returns the task.
-
+
         >>> dataset = client.get_dataset(DATASET_ID)
         >>> task = dataset.export_v2(
         >>>     filters={
         >>>         "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
         >>>         "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
-        >>>         "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] # or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...]
+        >>>         "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] # or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...]
         >>>     },
         >>>     params={
         >>>         "performance_details": False,
```

(Again, the paired -/+ lines differ only in trailing whitespace.)
```diff
@@ -749,3 +773,100 @@ def _export(
         res = res[mutation_name]
         task_id = res["taskId"]
         return Task.get_task(self.client, task_id)
+
+    def upsert_data_rows(self, items, file_upload_thread_count=20) -> "Task":
+        """
+        Upserts data rows in this dataset. When "key" is provided, and it references an existing data row,
+        an update will be performed. When "key" is not provided a new data row will be created.
+
+        >>> task = dataset.upsert_data_rows([
+        >>>     # create new data row
+        >>>     {
+        >>>         "row_data": "http://my_site.com/photos/img_01.jpg",
+        >>>         "global_key": "global_key1",
+        >>>         "external_id": "ex_id1",
+        >>>         "attachments": [
+        >>>             {"type": AttachmentType.RAW_TEXT, "name": "att1", "value": "test1"}
+        >>>         ],
+        >>>         "metadata": [
+        >>>             {"name": "tag", "value": "tag value"},
+        >>>         ]
+        >>>     },
+        >>>     # update global key of data row by existing global key
+        >>>     {
+        >>>         "key": GlobalKey("global_key1"),
+        >>>         "global_key": "global_key1_updated"
+        >>>     },
+        >>>     # update data row by ID
+        >>>     {
+        >>>         "key": UniqueId(dr.uid),
+        >>>         "external_id": "ex_id1_updated"
+        >>>     },
+        >>> ])
+        >>> task.wait_till_done()
+        """
+        if len(items) > MAX_DATAROW_PER_API_OPERATION:
+            raise MalformedQueryException(
+                f"Cannot upsert more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call."
+            )
+
+        specs = self._convert_items_to_upsert_format(items)
+        chunks = [
+            specs[i:i + self.__upsert_chunk_size]
+            for i in range(0, len(specs), self.__upsert_chunk_size)
+        ]
+
+        def _upload_chunk(_chunk):
+            return self._create_descriptor_file(_chunk, is_upsert=True)
+
+        with ThreadPoolExecutor(file_upload_thread_count) as executor:
+            futures = [
+                executor.submit(_upload_chunk, chunk) for chunk in chunks
+            ]
+            chunk_uris = [future.result() for future in as_completed(futures)]
+
+        manifest = {
+            "source": "SDK",
+            "item_count": len(specs),
+            "chunk_uris": chunk_uris
+        }
+        data = json.dumps(manifest).encode("utf-8")
+        manifest_uri = self.client.upload_data(data,
+                                               content_type="application/json",
+                                               filename="manifest.json")
+
+        query_str = """
+            mutation UpsertDataRowsPyApi($manifestUri: String!) {
+                upsertDataRows(data: { manifestUri: $manifestUri }) {
+                    id createdAt updatedAt name status completionPercentage result errors type metadata
+                }
+            }
+            """
+
+        res = self.client.execute(query_str, {"manifestUri": manifest_uri})
+        res = res["upsertDataRows"]
+        task = Task(self.client, res)
+        task._user = self.client.get_user()
+        return task
+
+    def _convert_items_to_upsert_format(self, _items):
+        _upsert_items: List[DataRowUpsertItem] = []
+        for item in _items:
+            # enforce current dataset's id for all specs
+            item['dataset_id'] = self.uid
+            key = item.pop('key', None)
+            if not key:
+                key = {'type': 'AUTO', 'value': ''}
+            elif isinstance(key, UniqueId):
+                key = {'type': 'ID', 'value': key.key}
+            elif isinstance(key, GlobalKey):
+                key = {'type': 'GKEY', 'value': key.key}
+            else:
+                raise ValueError(
+                    f"Key must be an instance of UniqueId or GlobalKey, got: {type(item['key']).__name__}"
+                )
+            item = {
+                k: v for k, v in item.items() if v is not None
+            }  # remove None values
+            _upsert_items.append(DataRowUpsertItem(payload=item, id=key))
+        return _upsert_items
```
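A single call is capped at `MAX_DATAROW_PER_API_OPERATION` (150,000) items; the converted specs are then split into chunks of `__upsert_chunk_size` (10,000), each chunk's descriptor file is uploaded on the thread pool, and the resulting URIs are registered in one manifest. The slicing itself is ordinary list arithmetic; a standalone sketch:

```python
# Standalone sketch of the chunking done by upsert_data_rows.
specs = list(range(25_000))  # stand-in for the converted upsert specs
chunk_size = 10_000          # Dataset.__upsert_chunk_size
chunks = [specs[i:i + chunk_size] for i in range(0, len(specs), chunk_size)]
assert [len(c) for c in chunks] == [10_000, 10_000, 5_000]
# Each chunk becomes one descriptor file; the manifest then lists
# {"source": "SDK", "item_count": 25000, "chunk_uris": [<3 uris>]}.
```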

labelbox/schema/task.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -66,7 +66,7 @@ def wait_till_done(self,
 
         Args:
             timeout_seconds (float): Maximum time this method can block, in seconds. Defaults to five minutes.
-            check_frequency (float): Frequency of queries to server to update the task attributes, in seconds. Defaults to two seconds. Minimal value is two seconds.
+            check_frequency (float): Frequency of queries to server to update the task attributes, in seconds. Defaults to two seconds. Minimal value is two seconds.
         """
         if check_frequency < 2.0:
             raise ValueError(
@@ -90,7 +90,7 @@ def wait_till_done(self,
     def errors(self) -> Optional[Dict[str, Any]]:
         """ Fetch the error associated with an import task.
         """
-        if self.name == 'JSON Import':
+        if self.name == 'JSON Import' or self.type == 'adv-upsert-data-rows':
            if self.status == "FAILED":
                result = self._fetch_remote_json()
                return result["error"]
@@ -168,7 +168,7 @@ def download_result(remote_json_field: Optional[str], format: str):
                 "Expected the result format to be either `ndjson` or `json`."
             )
 
-        if self.name == 'JSON Import':
+        if self.name == 'JSON Import' or self.type == 'adv-upsert-data-rows':
             format = 'json'
         elif self.type == 'export-data-rows':
             format = 'ndjson'
```

(In the first hunk the paired -/+ lines differ only in trailing whitespace.)
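With `'adv-upsert-data-rows'` added to both branches, tasks returned by `upsert_data_rows` fetch errors and results through the same remote-JSON path as `'JSON Import'` tasks, with the result downloaded in `json` format. A hedged end-to-end sketch (assumes `dataset` and `items` are already defined, and that `errors`/`result` are exposed as properties on `Task`, as elsewhere in the SDK):

```python
# Usage sketch; `dataset` and `items` are assumed to be defined.
task = dataset.upsert_data_rows(items)
task.wait_till_done(timeout_seconds=300)

if task.status == "FAILED":
    print(task.errors)  # pulled from the task's remote JSON result
else:
    print(task.result)
```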
