|
1 | 1 | import pytest |
2 | 2 | from labelbox import DataRow |
| 3 | +from labelbox.schema.identifiable import GlobalKey, UniqueId |
3 | 4 | from labelbox.schema.identifiables import GlobalKeys, UniqueIds |
4 | 5 |
|
5 | 6 |
|
@@ -27,17 +28,36 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): |
27 | 28 | for override in updated_overrides: |
28 | 29 | assert isinstance(override.data_row(), DataRow) |
29 | 30 |
|
| 31 | + data = [(UniqueId(data_rows[0].uid), 1, 2), (UniqueId(data_rows[1].uid), 2), |
| 32 | + (UniqueId(data_rows[2].uid), 3)] |
| 33 | + success = project.set_labeling_parameter_overrides(data) |
| 34 | + assert success |
| 35 | + updated_overrides = list(project.labeling_parameter_overrides()) |
| 36 | + assert len(updated_overrides) == 3 |
| 37 | + assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1} |
| 38 | + assert {o.priority for o in updated_overrides} == {1, 2, 3} |
| 39 | + |
| 40 | + data = [(GlobalKey(data_rows[0].global_key), 2, 2), |
| 41 | + (GlobalKey(data_rows[1].global_key), 3, 3), |
| 42 | + (GlobalKey(data_rows[2].global_key), 4)] |
| 43 | + success = project.set_labeling_parameter_overrides(data) |
| 44 | + assert success |
| 45 | + updated_overrides = list(project.labeling_parameter_overrides()) |
| 46 | + assert len(updated_overrides) == 3 |
| 47 | + assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1} |
| 48 | + assert {o.priority for o in updated_overrides} == {2, 3, 4} |
| 49 | + |
30 | 50 | with pytest.raises(TypeError) as exc_info: |
31 | 51 | data = [(data_rows[2], "a_string", 3)] |
32 | 52 | project.set_labeling_parameter_overrides(data) |
33 | 53 | assert str(exc_info.value) == \ |
34 | | - f"Priority must be an int. Found <class 'str'> for data_row {data_rows[2]}. Index: 0" |
| 54 | + f"Priority must be an int. Found <class 'str'> for data_row_identifier {data_rows[2].uid}" |
35 | 55 |
|
36 | 56 | with pytest.raises(TypeError) as exc_info: |
37 | 57 | data = [(data_rows[2].uid, 1)] |
38 | 58 | project.set_labeling_parameter_overrides(data) |
39 | 59 | assert str(exc_info.value) == \ |
40 | | - "data_row should be be of type DataRow. Found <class 'str'>. Index: 0" |
| 60 | + f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found <class 'str'> for data_row_identifier {data_rows[2].uid}" |
41 | 61 |
|
42 | 62 |
|
43 | 63 | def test_set_labeling_priority(consensus_project_with_batch): |
|
0 commit comments