Skip to content

Commit 8a44aaa

Browse files
committed
Merge remote-tracking branch 'origin/develop' into farkob/schema_id_name
2 parents 4ac94f3 + 00b5351 commit 8a44aaa

File tree

12 files changed

+350
-37
lines changed

12 files changed

+350
-37
lines changed

CHANGELOG.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
# Changelog
22

3+
# Version 3.24.1 (2022-07-07)
4+
## Updated
5+
* Added `refresh_ontology()` as part of create/update/delete metadata schema functions
6+
7+
# Version 3.24.0 (2022-07-06)
8+
## Added
9+
* `DataRowMetadataOntology` class now has functions to create/update/delete metadata schema
10+
* `create_schema` - Create custom metadata schema
11+
* `update_schema` - Update name of custom metadata schema
12+
* `update_enum_options` - Update name of an Enum option for an Enum custom metadata schema
13+
* `delete_schema` - Delete custom metadata schema
14+
* `ModelRun` class now has `assign_data_rows_to_split` function, which can assign a `DataSplit` to a list of `DataRow`s
15+
* `Dataset.create_data_rows()` can bulk import `conversationalData`
16+
317
# Version 3.23.3 (2022-06-23)
418

519
## Fix

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ client = Client( endpoint = "<local deployment>")
8585
client = Client(api_key=os.environ['LABELBOX_TEST_API_KEY_LOCAL'], endpoint="http://localhost:8080/graphql")
8686
8787
# Staging
88-
client = Client(api_key=os.environ['LABELBOX_TEST_API_KEY_LOCAL'], endpoint="https://staging-api.labelbox.com/graphql")
88+
client = Client(api_key=os.environ['LABELBOX_TEST_API_KEY_LOCAL'], endpoint="https://api.lb-stage.xyz/graphql")
8989
```
9090

9191
## Contribution
@@ -122,5 +122,5 @@ make test-prod # with an optional flag: PATH_TO_TEST=tests/integration/...etc LA
122122
make -B {build|test-staging|test-prod}
123123
```
124124

125-
6. Testing against Delegated Access will be skipped unless the local env contains the key:
126-
DA_GCP_LABELBOX_API_KEY. These tests will be included when run against a PR. If you would like to test it manually, please reach out to the Devops team for information on the key.
125+
6. Testing against Delegated Access will be skipped unless the local env contains the key:
126+
DA_GCP_LABELBOX_API_KEY. These tests will be included when run against a PR. If you would like to test it manually, please reach out to the Devops team for information on the key.

labelbox/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
name = "labelbox"
2-
__version__ = "3.23.3"
2+
__version__ = "3.24.1"
33

44
from labelbox.client import Client
55
from labelbox.schema.project import Project
@@ -21,7 +21,7 @@
2121
from labelbox.schema.role import Role, ProjectRole
2222
from labelbox.schema.invite import Invite, InviteLimit
2323
from labelbox.schema.data_row_metadata import DataRowMetadataOntology
24-
from labelbox.schema.model_run import ModelRun
24+
from labelbox.schema.model_run import ModelRun, DataSplit
2525
from labelbox.schema.benchmark import Benchmark
2626
from labelbox.schema.iam_integration import IAMIntegration
2727
from labelbox.schema.resource_tag import ResourceTag

labelbox/schema/data_row_metadata.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ def delete_schema(self, name: str) -> bool:
366366
res = self._client.execute(query, {'where': {
367367
'id': schema.uid
368368
}})['deleteCustomMetadataSchema']
369+
self.refresh_ontology()
369370

370371
return res['success']
371372

@@ -642,6 +643,7 @@ def _upsert_schema(
642643
res = self._client.execute(
643644
query, {"data": upsert_schema.dict(exclude_none=True)
644645
})['upsertCustomMetadataSchema']
646+
self.refresh_ontology()
645647
return _parse_metadata_schema(res)
646648

647649
def _parse_upsert(

labelbox/schema/dataset.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
226226
>>> {DataRow.row_data:"/path/to/file1.jpg"},
227227
>>> "path/to/file2.jpg",
228228
>>> {"tileLayerUrl" : "http://", ...}
229+
>>> {"conversationalData" : [...], ...}
229230
>>> ])
230231
231232
For an example showing how to upload tiled data_rows see the following notebook:
@@ -280,6 +281,33 @@ def validate_attachments(item):
280281
)
281282
return attachments
282283

284+
def validate_conversational_data(conversational_data: list) -> None:
285+
"""
286+
Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json
287+
288+
Args:
289+
conversational_data (list): list of dictionaries.
290+
"""
291+
292+
def check_message_keys(message):
293+
accepted_message_keys = set([
294+
"messageId", "timestampUsec", "content", "user", "align",
295+
"canLabel"
296+
])
297+
for key in message.keys():
298+
if not key in accepted_message_keys:
299+
raise KeyError(
300+
f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}"
301+
)
302+
303+
if conversational_data and not isinstance(conversational_data,
304+
list):
305+
raise ValueError(
306+
f"conversationalData must be a list. Found {type(conversational_data)}"
307+
)
308+
309+
[check_message_keys(message) for message in conversational_data]
310+
283311
def parse_metadata_fields(item):
284312
metadata_fields = item.get('metadata_fields')
285313
if metadata_fields:
@@ -321,6 +349,27 @@ def convert_item(item):
321349
if "tileLayerUrl" in item:
322350
validate_attachments(item)
323351
return item
352+
353+
if "conversationalData" in item:
354+
messages = item.pop("conversationalData")
355+
version = item.pop("version")
356+
type = item.pop("type")
357+
if "externalId" in item:
358+
external_id = item.pop("externalId")
359+
item["external_id"] = external_id
360+
validate_conversational_data(messages)
361+
one_conversation = \
362+
{
363+
"type": type,
364+
"version": version,
365+
"messages": messages
366+
}
367+
conversationUrl = self.client.upload_data(
368+
json.dumps(one_conversation),
369+
content_type="application/json",
370+
filename="conversational_data.json")
371+
item["row_data"] = conversationUrl
372+
324373
# Convert all payload variations into the same dict format
325374
item = format_row(item)
326375
# Make sure required keys exist (and there are no extra keys)

labelbox/schema/model_run.py

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
# type: ignore
12
from typing import TYPE_CHECKING, Dict, Iterable, Union, List, Optional, Any
23
from pathlib import Path
34
import os
45
import time
56
import logging
67
import requests
78
import ndjson
9+
from enum import Enum
810

911
from labelbox.pagination import PaginatedCollection
1012
from labelbox.orm.query import results_query_part
@@ -17,13 +19,27 @@
1719
logger = logging.getLogger(__name__)
1820

1921

22+
class DataSplit(Enum):
23+
TRAINING = "TRAINING"
24+
TEST = "TEST"
25+
VALIDATION = "VALIDATION"
26+
UNASSIGNED = "UNASSIGNED"
27+
28+
2029
class ModelRun(DbObject):
2130
name = Field.String("name")
2231
updated_at = Field.DateTime("updated_at")
2332
created_at = Field.DateTime("created_at")
2433
created_by_id = Field.String("created_by_id", "createdBy")
2534
model_id = Field.String("model_id")
2635

36+
class Status(Enum):
37+
EXPORTING_DATA = "EXPORTING_DATA"
38+
PREPARING_DATA = "PREPARING_DATA"
39+
TRAINING_MODEL = "TRAINING_MODEL"
40+
COMPLETE = "COMPLETE"
41+
FAILED = "FAILED"
42+
2743
def upsert_labels(self, label_ids, timeout_seconds=60):
2844
""" Adds data rows and labels to a model run
2945
Args:
@@ -90,8 +106,9 @@ def upsert_data_rows(self, data_row_ids, timeout_seconds=60):
90106
}})['MEADataRowRegistrationTaskStatus'],
91107
timeout_seconds=timeout_seconds)
92108

93-
def _wait_until_done(self, status_fn, timeout_seconds=60, sleep_time=5):
109+
def _wait_until_done(self, status_fn, timeout_seconds=120, sleep_time=5):
94110
# Do not use this function outside of the scope of upsert_data_rows or upsert_labels. It could change.
111+
original_timeout = timeout_seconds
95112
while True:
96113
res = status_fn()
97114
if res['status'] == 'COMPLETE':
@@ -102,9 +119,8 @@ def _wait_until_done(self, status_fn, timeout_seconds=60, sleep_time=5):
102119
timeout_seconds -= sleep_time
103120
if timeout_seconds <= 0:
104121
raise TimeoutError(
105-
f"Unable to complete import within {timeout_seconds} seconds."
122+
f"Unable to complete import within {original_timeout} seconds."
106123
)
107-
108124
time.sleep(sleep_time)
109125

110126
def add_predictions(
@@ -161,7 +177,7 @@ def delete(self):
161177
deleteModelRuns(where: {ids: [$%s]})}""" % (ids_param, ids_param)
162178
self.client.execute(query_str, {ids_param: str(self.uid)})
163179

164-
def delete_model_run_data_rows(self, data_row_ids):
180+
def delete_model_run_data_rows(self, data_row_ids: List[str]):
165181
""" Deletes data rows from model runs.
166182
167183
Args:
@@ -180,22 +196,62 @@ def delete_model_run_data_rows(self, data_row_ids):
180196
data_row_ids_param: data_row_ids
181197
})
182198

199+
@experimental
200+
def assign_data_rows_to_split(self,
201+
data_row_ids: List[str],
202+
split: Union[DataSplit, str],
203+
timeout_seconds=120):
204+
205+
split_value = split.value if isinstance(split, DataSplit) else split
206+
207+
if split_value == DataSplit.UNASSIGNED.value:
208+
raise ValueError(
209+
f"Cannot assign split value of `{DataSplit.UNASSIGNED.value}`.")
210+
211+
valid_splits = filter(lambda name: name != DataSplit.UNASSIGNED.value,
212+
DataSplit._member_names_)
213+
214+
if split_value not in valid_splits:
215+
raise ValueError(
216+
f"`split` must be one of : `{valid_splits}`. Found : `{split}`")
217+
218+
task_id = self.client.execute(
219+
"""mutation assignDataSplitPyApi($modelRunId: ID!, $data: CreateAssignDataRowsToDataSplitTaskInput!){
220+
createAssignDataRowsToDataSplitTask(modelRun : {id: $modelRunId}, data: $data)}
221+
""", {
222+
'modelRunId': self.uid,
223+
'data': {
224+
'assignments': [{
225+
'split': split_value,
226+
'dataRowIds': data_row_ids
227+
}]
228+
}
229+
},
230+
experimental=True)['createAssignDataRowsToDataSplitTask']
231+
232+
status_query_str = """query assignDataRowsToDataSplitTaskStatusPyApi($id: ID!){
233+
assignDataRowsToDataSplitTaskStatus(where: {id : $id}){status errorMessage}}
234+
"""
235+
236+
return self._wait_until_done(lambda: self.client.execute(
237+
status_query_str, {'id': task_id}, experimental=True)[
238+
'assignDataRowsToDataSplitTaskStatus'],
239+
timeout_seconds=timeout_seconds)
240+
183241
@experimental
184242
def update_status(self,
185-
status: str,
243+
status: Union[str, "ModelRun.Status"],
186244
metadata: Optional[Dict[str, str]] = None,
187245
error_message: Optional[str] = None):
188246

189-
valid_statuses = [
190-
"EXPORTING_DATA", "PREPARING_DATA", "TRAINING_MODEL", "COMPLETE",
191-
"FAILED"
192-
]
193-
if status not in valid_statuses:
247+
status_value = status.value if isinstance(status,
248+
ModelRun.Status) else status
249+
if status_value not in ModelRun.Status._member_names_:
194250
raise ValueError(
195-
f"Status must be one of : `{valid_statuses}`. Found : `{status}`"
251+
f"Status must be one of : `{ModelRun.Status._member_names_}`. Found : `{status_value}`"
196252
)
197253

198-
data: Dict[str, Any] = {'status': status}
254+
data: Dict[str, Any] = {'status': status_value}
199255
if error_message:
200256
data['errorMessage'] = error_message
201257

@@ -264,6 +320,7 @@ def export_labels(
264320
class ModelRunDataRow(DbObject):
265321
label_id = Field.String("label_id")
266322
model_run_id = Field.String("model_run_id")
323+
data_split = Field.Enum(DataSplit, "data_split")
267324
data_row = Relationship.ToOne("DataRow", False, cache=True)
268325

269326
def __init__(self, client, model_id, *args, **kwargs):

tests/integration/annotation_import/test_model_run.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import os
33
import pytest
44

5+
from collections import Counter
6+
from labelbox import DataSplit, ModelRun
7+
58

69
def test_model_run(client, configured_project_with_label, rand_gen):
710
project, _, _, label = configured_project_with_label
@@ -119,3 +122,40 @@ def get_model_run_status():
119122
assert model_run_status['status'] == status
120123
assert model_run_status['metadata'] == {**metadata, **extra_metadata}
121124
assert model_run_status['errorMessage'] == errorMessage
125+
126+
status = ModelRun.Status.FAILED
127+
model_run_with_model_run_data_rows.update_status(status, metadata,
128+
errorMessage)
129+
model_run_status = get_model_run_status()
130+
assert model_run_status['status'] == status.value
131+
132+
with pytest.raises(ValueError):
133+
model_run_with_model_run_data_rows.update_status(
134+
"INVALID", metadata, errorMessage)
135+
136+
137+
def test_model_run_split_assignment(model_run, dataset, image_url):
138+
n_data_rows = 10
139+
data_rows = dataset.create_data_rows([{
140+
"row_data": image_url
141+
} for _ in range(n_data_rows)])
142+
data_row_ids = [data_row['id'] for data_row in data_rows.result]
143+
144+
model_run.upsert_data_rows(data_row_ids)
145+
146+
with pytest.raises(ValueError):
147+
model_run.assign_data_rows_to_split(data_row_ids, "INVALID SPLIT")
148+
149+
with pytest.raises(ValueError):
150+
model_run.assign_data_rows_to_split(data_row_ids, DataSplit.UNASSIGNED)
151+
152+
for split in ["TRAINING", "TEST", "VALIDATION", *DataSplit]:
153+
if split == DataSplit.UNASSIGNED:
154+
continue
155+
156+
model_run.assign_data_rows_to_split(data_row_ids, split)
157+
counts = Counter()
158+
for data_row in model_run.model_run_data_rows():
159+
counts[data_row.data_split.value] += 1
160+
split = split.value if isinstance(split, DataSplit) else split
161+
assert counts[split] == n_data_rows

tests/integration/conftest.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import os
23
import re
34
import time
@@ -145,7 +146,10 @@ def client(environ: str):
145146

146147
@pytest.fixture(scope="session")
147148
def image_url(client):
148-
return client.upload_data(requests.get(IMG_URL).content, sign=True)
149+
return client.upload_data(requests.get(IMG_URL).content,
150+
content_type="application/json",
151+
filename="json_import.json",
152+
sign=True)
149153

150154

151155
@pytest.fixture
@@ -181,16 +185,23 @@ def iframe_url(environ) -> str:
181185
if environ in [Environ.PROD, Environ.LOCAL]:
182186
return 'https://editor.labelbox.com'
183187
elif environ == Environ.STAGING:
184-
return 'https://staging.labelbox.dev/editor'
188+
return 'https://editor.lb-stage.xyz'
185189

186190

187191
@pytest.fixture
188192
def sample_video() -> str:
189193
path_to_video = 'tests/integration/media/cat.mp4'
190-
assert os.path.exists(path_to_video)
191194
return path_to_video
192195

193196

197+
@pytest.fixture
198+
def sample_bulk_conversation() -> list:
199+
path_to_conversation = 'tests/integration/media/bulk_conversation.json'
200+
with open(path_to_conversation) as json_file:
201+
conversations = json.load(json_file)
202+
return conversations
203+
204+
194205
@pytest.fixture
195206
def organization(client):
196207
# Must have at least one seat open in your org to run these tests
@@ -290,7 +301,7 @@ def configured_project_with_label(client, rand_gen, image_url, project, dataset,
290301

291302
def create_label():
292303
""" Ad-hoc function to create a LabelImport
293-
304+
294305
Creates a LabelImport task which will create a label
295306
"""
296307
upload_task = LabelImport.create_from_objects(

0 commit comments

Comments
 (0)