Skip to content

Commit edb9e18

Browse files
author
Matt Sokoloff
committed
assign data row split
1 parent 0197e0a commit edb9e18

File tree

4 files changed

+69 additions, -8 deletions (lines changed)

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ client = Client( endpoint = "<local deployment>")
8585
client = Client(api_key=os.environ['LABELBOX_TEST_API_KEY_LOCAL'], endpoint="http://localhost:8080/graphql")
8686
8787
# Staging
88-
client = Client(api_key=os.environ['LABELBOX_TEST_API_KEY_LOCAL'], endpoint="https://staging-api.labelbox.com/graphql")
88+
client = Client(api_key=os.environ['LABELBOX_TEST_API_KEY_LOCAL'], endpoint="https://api.lb-stage.xyz/graphql")
8989
```
9090

9191
## Contribution
@@ -122,5 +122,5 @@ make test-prod # with an optional flag: PATH_TO_TEST=tests/integration/...etc LA
122122
make -B {build|test-staging|test-prod}
123123
```
124124

125-
6. Testing against Delegated Access will be skipped unless the local env contains the key:
126-
DA_GCP_LABELBOX_API_KEY. These tests will be included when run against a PR. If you would like to test it manually, please reach out to the Devops team for information on the key.
125+
6. Testing against Delegated Access will be skipped unless the local env contains the key:
126+
DA_GCP_LABELBOX_API_KEY. These tests will be included when run against a PR. If you would like to test it manually, please reach out to the Devops team for information on the key.

labelbox/schema/model_run.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def upsert_data_rows(self, data_row_ids, timeout_seconds=60):
9292

9393
def _wait_until_done(self, status_fn, timeout_seconds=60, sleep_time=5):
9494
# Do not use this function outside of the scope of upsert_data_rows or upsert_labels. It could change.
95+
original_timeout = timeout_seconds
9596
while True:
9697
res = status_fn()
9798
if res['status'] == 'COMPLETE':
@@ -102,7 +103,7 @@ def _wait_until_done(self, status_fn, timeout_seconds=60, sleep_time=5):
102103
timeout_seconds -= sleep_time
103104
if timeout_seconds <= 0:
104105
raise TimeoutError(
105-
f"Unable to complete import within {timeout_seconds} seconds."
106+
f"Unable to complete import within {original_timeout} seconds."
106107
)
107108

108109
time.sleep(sleep_time)
@@ -180,6 +181,39 @@ def delete_model_run_data_rows(self, data_row_ids):
180181
data_row_ids_param: data_row_ids
181182
})
182183

184+
@experimental
185+
def assign_data_rows_to_split(self,
186+
data_row_ids,
187+
split,
188+
timeout_seconds=60):
189+
valid_splits = ["TRAINING", "TEST", "VALIDATION"]
190+
if split not in valid_splits:
191+
raise ValueError(
192+
f"split must be one of : `{valid_splits}`. Found : `{split}`")
193+
194+
task_id = self.client.execute(
195+
"""mutation assignDataSplitPyApi($modelRunId: ID!, $data: CreateAssignDataRowsToDataSplitTaskInput!){
196+
createAssignDataRowsToDataSplitTask(modelRun : {id: $modelRunId}, data: $data)}
197+
""", {
198+
'modelRunId': self.uid,
199+
'data': {
200+
'assignments': [{
201+
'split': split,
202+
'dataRowIds': data_row_ids
203+
}]
204+
}
205+
},
206+
experimental=True)['createAssignDataRowsToDataSplitTask']
207+
208+
status_query_str = """query assignDataRowsToDataSplitTaskStatusPyApi($id: ID!){
209+
assignDataRowsToDataSplitTaskStatus(where: {id : $id}){status errorMessage}}
210+
"""
211+
212+
return self._wait_until_done(lambda: self.client.execute(
213+
status_query_str, {'id': task_id}, experimental=True)[
214+
'assignDataRowsToDataSplitTaskStatus'],
215+
timeout_seconds=timeout_seconds)
216+
183217
@experimental
184218
def update_status(self,
185219
status: str,
@@ -264,6 +298,7 @@ def export_labels(
264298
class ModelRunDataRow(DbObject):
265299
label_id = Field.String("label_id")
266300
model_run_id = Field.String("model_run_id")
301+
data_split = Field.String("data_split")
267302
data_row = Relationship.ToOne("DataRow", False, cache=True)
268303

269304
def __init__(self, client, model_id, *args, **kwargs):

tests/integration/annotation_import/test_model_run.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import os
33
import pytest
44

5+
from collections import Counter
6+
57

68
def test_model_run(client, configured_project_with_label, rand_gen):
79
project, _, _, label = configured_project_with_label
@@ -119,3 +121,24 @@ def get_model_run_status():
119121
assert model_run_status['status'] == status
120122
assert model_run_status['metadata'] == {**metadata, **extra_metadata}
121123
assert model_run_status['errorMessage'] == errorMessage
124+
125+
126+
def test_model_run_split_assignment(model_run, dataset, image_url):
127+
n_data_rows = 10
128+
data_rows = dataset.create_data_rows([{
129+
"row_data": image_url
130+
} for _ in range(n_data_rows)])
131+
data_row_ids = [data_row['id'] for data_row in data_rows.result]
132+
133+
model_run.upsert_data_rows(data_row_ids)
134+
135+
for split in ["TRAINING", "TEST", "VALIDATION"]:
136+
model_run.assign_data_rows_to_split(data_row_ids[:(n_data_rows // 2)],
137+
split)
138+
counts = Counter()
139+
for data_row in model_run.model_run_data_rows():
140+
counts[data_row.data_split] += 1
141+
assert counts[split] == n_data_rows // 2
142+
143+
with pytest.raises(ValueError):
144+
model_run.assign_data_rows_to_split(data_row_ids, "INVALID SPLIT")

tests/integration/conftest.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def graphql_url(environ: str) -> str:
4747
if environ == Environ.PROD:
4848
return 'https://api.labelbox.com/graphql'
4949
elif environ == Environ.STAGING:
50-
return 'https://staging-api.labelbox.com/graphql'
50+
return 'https://api.lb-stage.xyz/graphql'
5151
elif environ == Environ.ONPREM:
5252
hostname = os.environ.get('LABELBOX_TEST_ONPREM_HOSTNAME', None)
5353
if hostname is None:
@@ -145,7 +145,10 @@ def client(environ: str):
145145

146146
@pytest.fixture(scope="session")
147147
def image_url(client):
148-
return client.upload_data(requests.get(IMG_URL).content, sign=True)
148+
return client.upload_data(requests.get(IMG_URL).content,
149+
content_type="application/json",
150+
filename="json_import.json",
151+
sign=True)
149152

150153

151154
@pytest.fixture
@@ -181,7 +184,7 @@ def iframe_url(environ) -> str:
181184
if environ in [Environ.PROD, Environ.LOCAL]:
182185
return 'https://editor.labelbox.com'
183186
elif environ == Environ.STAGING:
184-
return 'https://staging.labelbox.dev/editor'
187+
return 'https://editor.lb-stage.xyz'
185188

186189

187190
@pytest.fixture
@@ -290,7 +293,7 @@ def configured_project_with_label(client, rand_gen, image_url, project, dataset,
290293

291294
def create_label():
292295
""" Ad-hoc function to create a LabelImport
293-
296+
294297
Creates a LabelImport task which will create a label
295298
"""
296299
upload_task = LabelImport.create_from_objects(

0 commit comments

Comments
 (0)