Skip to content

Commit b3b8e3a

Browse files
committed
[AL-4870] Split model run data rows using global keys
1 parent 00049d9 commit b3b8e3a

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
# Version 3.40.0 (YYYY-MM-DD)
44

5-
## Added
6-
* Insert newest changelogs here
5+
## Added
6+
* Support Global keys to reference data rows in `Project.create_batch()`, `ModelRun.assign_data_rows_to_split()`.
7+
78

89
# Version 3.39.0 (2023-02-28)
910
## Added

labelbox/schema/model_run.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,14 +335,15 @@ def delete_model_run_data_rows(self, data_row_ids: List[str]):
335335

336336
@experimental
337337
def assign_data_rows_to_split(self,
338-
data_row_ids: List[str],
339-
split: Union[DataSplit, str],
338+
data_row_ids: List[str] = None,
339+
split: Union[DataSplit, str] = None,
340+
global_keys: List[str] = None,
340341
timeout_seconds=120):
341342

342343
split_value = split.value if isinstance(split, DataSplit) else split
343344
valid_splits = DataSplit._member_names_
344345

345-
if split_value not in valid_splits:
346+
if split_value is None or split_value not in valid_splits:
346347
raise ValueError(
347348
f"`split` must be one of : `{valid_splits}`. Found : `{split}`")
348349

@@ -354,7 +355,8 @@ def assign_data_rows_to_split(self,
354355
'data': {
355356
'assignments': [{
356357
'split': split_value,
357-
'dataRowIds': data_row_ids
358+
'dataRowIds': data_row_ids,
359+
'globalKeys': global_keys,
358360
}]
359361
}
360362
},

tests/integration/annotation_import/test_model_run.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import time
22
import os
3+
import uuid
34
import pytest
45

56
from collections import Counter
@@ -208,7 +209,8 @@ def test_model_run_export_v2(model_run_with_model_run_data_rows,
208209
assert prediction_id in label_ids_set
209210

210211

211-
def test_model_run_split_assignment(model_run, dataset, image_url):
212+
def test_model_run_split_assignment_by_data_row_ids(model_run, dataset,
213+
image_url):
212214
n_data_rows = 10
213215
data_rows = dataset.create_data_rows([{
214216
"row_data": image_url
@@ -227,3 +229,18 @@ def test_model_run_split_assignment(model_run, dataset, image_url):
227229
counts[data_row.data_split.value] += 1
228230
split = split.value if isinstance(split, DataSplit) else split
229231
assert counts[split] == n_data_rows
232+
233+
234+
def test_model_run_split_assignment_by_global_keys(model_run, data_rows):
235+
global_keys = [data_row.global_key for data_row in data_rows]
236+
237+
model_run.upsert_data_rows(global_keys=global_keys)
238+
239+
for split in ["TRAINING", "TEST", "VALIDATION", "UNASSIGNED", *DataSplit]:
240+
model_run.assign_data_rows_to_split(split=split,
241+
global_keys=global_keys)
242+
splits = [
243+
data_row.data_split.value
244+
for data_row in model_run.model_run_data_rows()
245+
]
246+
assert len(set(splits)) == 1

0 commit comments

Comments
 (0)