
Commit 43260f0

Merge pull request #734 from Labelbox/ms/al-3814
data row content under row_data
2 parents 3b2f417 + c1812cd commit 43260f0
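
In practice, this change routes object payloads (conversational text, tiled imagery) through `row_data` itself rather than the legacy top-level keys. A rough before/after sketch, assuming an authenticated `Client` and an existing dataset — the API key and dataset ID below are placeholders, not values from this PR:

    from labelbox import Client

    client = Client(api_key="<API_KEY>")          # placeholder credentials
    dataset = client.get_dataset("<DATASET_ID>")  # placeholder dataset id

    messages = [{
        "messageId": "message-0",
        "timestampUsec": 1530718491,
        "content": "hello",
        "user": {"userId": "Bot 002", "name": "Bot"},
        "align": "left",
        "canLabel": False,
    }]

    # Legacy shape, still accepted for backwards compatibility:
    dataset.create_data_rows_sync([{"conversationalData": messages}])

    # New shape from this PR: the object travels under row_data.
    dataset.create_data_rows_sync([{
        "row_data": {
            "type": "application/vnd.labelbox.conversational",
            "version": 1,
            "messages": messages,
        }
    }])

The legacy form is rewritten into the new one by `formatLegacyConversationalData` in `dataset.py` below, so both calls produce the same `row_data` object.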

Showing 6 changed files with 234 additions and 40 deletions.

labelbox/schema/data_row.py

Lines changed: 9 additions & 0 deletions
@@ -1,5 +1,6 @@
 import logging
 from typing import TYPE_CHECKING
+import json
 
 from labelbox.orm import query
 from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable
@@ -64,6 +65,14 @@ def __init__(self, *args, **kwargs):
         self.attachments.supports_filtering = False
         self.attachments.supports_sorting = False
 
+    def update(self, **kwargs):
+        # Convert row data to string if it is an object
+        # All other updates pass through
+        row_data = kwargs.get("row_data")
+        if isinstance(row_data, dict):
+            kwargs['row_data'] = json.dumps(kwargs['row_data'])
+        super().update(**kwargs)
+
     @staticmethod
     def bulk_delete(data_rows) -> None:
         """ Deletes all the given DataRows.

labelbox/schema/dataset.py

Lines changed: 59 additions & 35 deletions
@@ -15,6 +15,7 @@
 from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError
 from labelbox.orm.db_object import DbObject, Updateable, Deletable
 from labelbox.orm.model import Entity, Field, Relationship
+from labelbox.orm import query
 from labelbox.exceptions import MalformedQueryException
 
 if TYPE_CHECKING:
@@ -95,18 +96,46 @@ def convert_field_keys(items):
             raise InvalidQueryError(
                 "DataRow.row_data missing when creating DataRow.")
 
-        # If row data is a local file path, upload it to server.
         row_data = args[DataRow.row_data.name]
-        if os.path.exists(row_data):
+        if not isinstance(row_data, str):
+            # If the row data is an object, upload as a string
+            args[DataRow.row_data.name] = json.dumps(row_data)
+        elif os.path.exists(row_data):
+            # If row data is a local file path, upload it to server.
             args[DataRow.row_data.name] = self.client.upload_file(row_data)
-        args[DataRow.dataset.name] = self
 
         # Parse metadata fields, if they are provided
        if DataRow.metadata_fields.name in args:
             mdo = self.client.get_data_row_metadata_ontology()
             args[DataRow.metadata_fields.name] = mdo.parse_upsert_metadata(
                 args[DataRow.metadata_fields.name])
-        return self.client._create(DataRow, args)
+
+        query_str = """mutation CreateDataRowPyApi(
+            $row_data: String!,
+            $metadata_fields: [DataRowCustomMetadataUpsertInput!],
+            $attachments: [DataRowAttachmentInput!],
+            $media_type : MediaType,
+            $external_id : String,
+            $global_key : String,
+            $dataset: ID!
+            ){
+                createDataRow(
+                    data:
+                      {
+                        rowData: $row_data
+                        mediaType: $media_type
+                        metadataFields: $metadata_fields
+                        externalId: $external_id
+                        globalKey: $global_key
+                        attachments: $attachments
+                        dataset: {connect: {id: $dataset}}
+                      }
+                   )
+                {%s}
+            }
+        """ % query.results_query_part(Entity.DataRow)
+        res = self.client.execute(query_str, {**args, 'dataset': self.uid})
+        return DataRow(self.client, res['createDataRow'])
 
     def create_data_rows_sync(self, items) -> None:
         """ Synchronously bulk upload data rows.
@@ -229,8 +258,8 @@ def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
         >>>     {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"},
         >>>     {DataRow.row_data:"/path/to/file1.jpg"},
         >>>     "path/to/file2.jpg",
-        >>>     {"tileLayerUrl" : "http://", ...}
-        >>>     {"conversationalData" : [...], ...}
+        >>>     {DataRow.row_data: {"tileLayerUrl" : "http://", ...}}
+        >>>     {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}}
         >>>     ])
 
     For an example showing how to upload tiled data_rows see the following notebook:
@@ -258,7 +287,7 @@ def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
 
         def upload_if_necessary(item):
             row_data = item['row_data']
-            if os.path.exists(row_data):
+            if isinstance(row_data, str) and os.path.exists(row_data):
                 item_url = self.client.upload_file(row_data)
                 item['row_data'] = item_url
                 if 'external_id' not in item:
@@ -341,40 +370,39 @@ def validate_keys(item):
                     "`row_data` missing when creating DataRow.")
 
             invalid_keys = set(item) - {
-                *{f.name for f in DataRow.fields()}, 'attachments'
+                *{f.name for f in DataRow.fields()}, 'attachments', 'media_type'
             }
             if invalid_keys:
                 raise InvalidAttributeError(DataRow, invalid_keys)
             return item
 
+        def formatLegacyConversationalData(item):
+            messages = item.pop("conversationalData")
+            version = item.pop("version", 1)
+            type = item.pop("type", "application/vnd.labelbox.conversational")
+            if "externalId" in item:
+                external_id = item.pop("externalId")
+                item["external_id"] = external_id
+            if "globalKey" in item:
+                global_key = item.pop("globalKey")
+                item["globalKey"] = global_key
+            validate_conversational_data(messages)
+            one_conversation = \
+                {
+                    "type": type,
+                    "version": version,
+                    "messages": messages
+                }
+            item["row_data"] = one_conversation
+            return item
+
         def convert_item(item):
-            # Don't make any changes to tms data
             if "tileLayerUrl" in item:
                 validate_attachments(item)
                 return item
 
             if "conversationalData" in item:
-                messages = item.pop("conversationalData")
-                version = item.pop("version")
-                type = item.pop("type")
-                if "externalId" in item:
-                    external_id = item.pop("externalId")
-                    item["external_id"] = external_id
-                if "globalKey" in item:
-                    global_key = item.pop("globalKey")
-                    item["globalKey"] = global_key
-                validate_conversational_data(messages)
-                one_conversation = \
-                    {
-                        "type": type,
-                        "version": version,
-                        "messages": messages
-                    }
-                conversationUrl = self.client.upload_data(
-                    json.dumps(one_conversation),
-                    content_type="application/json",
-                    filename="conversational_data.json")
-                item["row_data"] = conversationUrl
+                formatLegacyConversationalData(item)
 
             # Convert all payload variations into the same dict format
             item = format_row(item)
@@ -386,11 +414,7 @@ def convert_item(item):
             parse_metadata_fields(item)
             # Upload any local file paths
             item = upload_if_necessary(item)
-
-            return {
-                "data" if key == "row_data" else utils.camel_case(key): value
-                for key, value in item.items()
-            }
+            return item
 
         if not isinstance(items, Iterable):
             raise ValueError(
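
Because `create_data_row` now issues the `CreateDataRowPyApi` mutation directly, a non-string `row_data` is `json.dumps`'d and sent as the `$row_data` variable. A hedged sketch using the tiled-imagery shape from the `tile_content` fixture in the tests below (`dataset` handle assumed; the tests only exercise `media_type='TMS_SIMPLE'` on the bulk path, so passing it here is an assumption):

    # Sketch: the dict is serialized client-side and sent under rowData.
    data_row = dataset.create_data_row(
        row_data={
            "tileLayerUrl":
                "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png",
            "bounds": [[19.405662413477728, -99.21052827588443],
                       [19.400498983095076, -99.20534818927473]],
            "minZoom": 12,
            "maxZoom": 20,
            "epsg": "EPSG4326",
        },
        media_type="TMS_SIMPLE",  # assumed: mirrors the bulk-path tests
    )

Note that `media_type` is also newly accepted by `validate_keys`, so the same key works in the bulk `create_data_rows` / `create_data_rows_sync` payloads.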

labelbox/schema/media_type.py

Lines changed: 2 additions & 2 deletions
@@ -21,9 +21,9 @@ class MediaType(Enum):
 
     @classmethod
     def _missing_(cls, name):
-        """Handle missing null data types for projects 
+        """Handle missing null data types for projects
             created without setting allowedMediaType
-            Handle upper case names for compatibility with 
+            Handle upper case names for compatibility with
             the GraphQL"""
 
         if name is None:

tests/integration/test_data_row_media_attributes.py

Lines changed: 1 addition & 1 deletion
@@ -7,4 +7,4 @@ def test_export_empty_media_attributes(configured_project_with_label):
     sleep(10)
     labels = project.label_generator()
     label = next(labels)
-    assert label.data.media_attributes == {}
+    assert label.data.media_attributes == {}

tests/integration/test_data_rows.py

Lines changed: 134 additions & 0 deletions
@@ -1,6 +1,7 @@
 from tempfile import NamedTemporaryFile
 import uuid
 from datetime import datetime
+import json
 
 import pytest
 import requests
@@ -28,6 +29,56 @@ def mdo(client):
     yield mdo
 
 
+@pytest.fixture
+def conversational_content():
+    return {
+        'row_data': {
+            "messages": [{
+                "messageId": "message-0",
+                "timestampUsec": 1530718491,
+                "content": "I love iphone! i just bought new iphone! 🥰 📲",
+                "user": {
+                    "userId": "Bot 002",
+                    "name": "Bot"
+                },
+                "align": "left",
+                "canLabel": False
+            }],
+            "version": 1,
+            "type": "application/vnd.labelbox.conversational"
+        }
+    }
+
+
+@pytest.fixture
+def tile_content():
+    return {
+        "row_data": {
+            "tileLayerUrl":
+                "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png",
+            "bounds": [[19.405662413477728, -99.21052827588443],
+                       [19.400498983095076, -99.20534818927473]],
+            "minZoom": 12,
+            "maxZoom": 20,
+            "epsg": "EPSG4326",
+            "alternativeLayers": [{
+                "tileLayerUrl":
+                    "https://api.mapbox.com/styles/v1/mapbox/satellite-streets-v11/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw",
+                "name": "Satellite"
+            }, {
+                "tileLayerUrl":
+                    "https://api.mapbox.com/styles/v1/mapbox/navigation-guidance-night-v4/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw",
+                "name": "Guidance"
+            }]
+        }
+    }
+
+
 def make_metadata_fields():
     embeddings = [0.0] * 128
     msg = "A message"
@@ -408,6 +459,18 @@ def test_data_row_update(dataset, rand_gen, image_url):
     data_row.update(external_id=external_id_2)
     assert data_row.external_id == external_id_2
 
+    in_line_content = "123"
+    data_row.update(row_data=in_line_content)
+    assert requests.get(data_row.row_data).text == in_line_content
+
+    data_row.update(row_data=image_url)
+    assert data_row.row_data == image_url
+
+    # tileLayer becomes a media attribute
+    pdf_url = "http://somepdfurl"
+    data_row.update(row_data={'pdfUrl': pdf_url, "tileLayerUrl": "123"})
+    assert data_row.row_data == pdf_url
+
 
 def test_data_row_filtering_sorting(dataset, image_url):
     task = dataset.create_data_rows([
@@ -696,3 +759,74 @@ def test_data_row_rulk_creation_sync_with_same_global_keys(
 
     assert len(list(dataset.data_rows())) == 1
     assert list(dataset.data_rows())[0].global_key == global_key_1
+
+
+def test_create_conversational_text(dataset, conversational_content):
+    examples = [
+        {
+            **conversational_content, 'media_type': 'CONVERSATIONAL'
+        },
+        conversational_content,
+        {
+            "conversationalData": conversational_content['row_data']['messages']
+        }  # Old way to check for backwards compatibility
+    ]
+    dataset.create_data_rows_sync(examples)
+    data_rows = list(dataset.data_rows())
+    assert len(data_rows) == len(examples)
+    for data_row in data_rows:
+        assert requests.get(
+            data_row.row_data).json() == conversational_content['row_data']
+
+
+def test_invalid_media_type(dataset, conversational_content):
+    for error_message, invalid_media_type in [[
+            "Found invalid contents for media type: 'IMAGE'", 'IMAGE'
+    ], ["Found invalid media type: 'totallyinvalid'", 'totallyinvalid']]:
+        # TODO: What error kind should this be? It looks like for global key we are
+        # using malformed query. But for invalid contents in FileUploads we use InvalidQueryError
+        with pytest.raises(labelbox.exceptions.InvalidQueryError):
+            dataset.create_data_rows_sync([{
+                **conversational_content, 'media_type': invalid_media_type
+            }])
+
+        task = dataset.create_data_rows([{
+            **conversational_content, 'media_type': invalid_media_type
+        }])
+        task.wait_till_done()
+        assert task.errors == {'message': error_message}
+
+
+def test_create_tiled_layer(dataset, tile_content):
+    examples = [
+        {
+            **tile_content, 'media_type': 'TMS_SIMPLE'
+        },
+        tile_content,
+        tile_content['row_data']  # Old way to check for backwards compatibility
+    ]
+    dataset.create_data_rows_sync(examples)
+    data_rows = list(dataset.data_rows())
+    assert len(data_rows) == len(examples)
+    for data_row in data_rows:
+        assert json.loads(data_row.row_data) == tile_content['row_data']
+
+
+def test_create_data_row_with_attachments(dataset):
+    attachment_value = 'attachment value'
+    dr = dataset.create_data_row(row_data="123",
+                                 attachments=[{
+                                     'type': 'TEXT',
+                                     'value': attachment_value
+                                 }])
+    attachments = list(dr.attachments())
+    assert len(attachments) == 1
+
+
+def test_create_data_row_with_media_type(dataset, image_url):
+    with pytest.raises(labelbox.exceptions.InvalidQueryError) as exc:
+        dr = dataset.create_data_row(
+            row_data={'invalid_object': 'invalid_value'}, media_type="IMAGE")
+    assert "Found invalid contents for media type: \'IMAGE\'" in str(exc.value)
+
+    dataset.create_data_row(row_data=image_url, media_type="IMAGE")
