Skip to content

Commit f95a13a

Browse files
authored
Merge pull request #669 from Labelbox/mmw/AL-3061
Update config [AL-3061]
2 parents d44e091 + cc2462c commit f95a13a

File tree

4 files changed

+99
-9
lines changed

4 files changed

+99
-9
lines changed

CHANGELOG.md

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
# Changelog
22

3+
# Version 0.0.0 (YYYY-MM-DD)
4+
### Added
5+
* `ModelRun.update_config()`
6+
* Updates model run training metadata
7+
* `ModelRun.reset_config()`
8+
* Resets model run training metadata
9+
* `ModelRun.get_config()`
10+
* Fetches model run training metadata
11+
12+
### Changed
13+
* `Model.create_model_run()`
14+
* Add training metadata config as a model run creation param
15+
316
# Version 3.26.0 (2022-08-15)
417
## Added
518
* `Batch.delete()` which will delete an existing `Batch`
@@ -663,7 +676,3 @@ a `Label`. Default value is 0.0.
663676

664677
## Version 2.2 (2019-10-18)
665678
Changelog not maintained before version 2.2.
666-
667-
### Changed
668-
* `Model.create_model_run()`
669-
* Add training metadata config as a model run creation param

labelbox/schema/model_run.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class Status(Enum):
4242
FAILED = "FAILED"
4343

4444
def upsert_labels(self, label_ids, timeout_seconds=60):
45-
""" Adds data rows and labels to a model run
45+
""" Adds data rows and labels to a Model Run
4646
Args:
4747
label_ids (list): label ids to insert
4848
timeout_seconds (float): Max waiting time, in seconds.
@@ -75,7 +75,7 @@ def upsert_labels(self, label_ids, timeout_seconds=60):
7575
timeout_seconds=timeout_seconds)
7676

7777
def upsert_data_rows(self, data_row_ids, timeout_seconds=60):
78-
""" Adds data rows to a model run without any associated labels
78+
""" Adds data rows to a Model Run without any associated labels
7979
Args:
8080
data_row_ids (list): data row ids to add to mea
8181
timeout_seconds (float): Max waiting time, in seconds.
@@ -167,7 +167,7 @@ def model_run_data_rows(self):
167167
['annotationGroups', 'pageInfo', 'endCursor'])
168168

169169
def delete(self):
170-
""" Deletes specified model run.
170+
""" Deletes specified Model Run.
171171
172172
Returns:
173173
Query execution success.
@@ -178,10 +178,10 @@ def delete(self):
178178
self.client.execute(query_str, {ids_param: str(self.uid)})
179179

180180
def delete_model_run_data_rows(self, data_row_ids: List[str]):
181-
""" Deletes data rows from model runs.
181+
""" Deletes data rows from Model Runs.
182182
183183
Args:
184-
data_row_ids (list): List of data row ids to delete from the model run.
184+
data_row_ids (list): List of data row ids to delete from the Model Run.
185185
Returns:
186186
Query execution success.
187187
"""
@@ -262,6 +262,56 @@ def update_status(self,
262262
},
263263
experimental=True)
264264

265+
@experimental
266+
def update_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
267+
"""
268+
Updates the Model Run's training metadata config
269+
Args:
270+
config (dict): A dictionary of keys and values
271+
Returns:
272+
Model Run id and updated training metadata
273+
"""
274+
data: Dict[str, Any] = {'config': config}
275+
res = self.client.execute(
276+
"""mutation updateModelRunConfigPyApi($modelRunId: ID!, $data: UpdateModelRunConfigInput!){
277+
updateModelRunConfig(modelRun: {id : $modelRunId}, data: $data){trainingMetadata}
278+
}
279+
""", {
280+
'modelRunId': self.uid,
281+
'data': data
282+
},
283+
experimental=True)
284+
return res["updateModelRunConfig"]
285+
286+
@experimental
287+
def reset_config(self) -> Dict[str, Any]:
288+
"""
289+
Resets Model Run's training metadata config
290+
Returns:
291+
Model Run id and reset training metadata
292+
"""
293+
res = self.client.execute(
294+
"""mutation resetModelRunConfigPyApi($modelRunId: ID!){
295+
resetModelRunConfig(modelRun: {id : $modelRunId}){trainingMetadata}
296+
}
297+
""", {'modelRunId': self.uid},
298+
experimental=True)
299+
return res["resetModelRunConfig"]
300+
301+
@experimental
302+
def get_config(self) -> Dict[str, Any]:
303+
"""
304+
Gets Model Run's training metadata
305+
Returns:
306+
training metadata as a dictionary
307+
"""
308+
res = self.client.execute("""query ModelRunPyApi($modelRunId: ID!){
309+
modelRun(where: {id : $modelRunId}){trainingMetadata}
310+
}
311+
""", {'modelRunId': self.uid},
312+
experimental=True)
313+
return res["modelRun"]
314+
265315
@experimental
266316
def export_labels(
267317
self,

tests/integration/annotation_import/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,19 @@ def model_run(rand_gen, model):
360360
pass
361361

362362

363+
@pytest.fixture
364+
def model_run_with_training_metadata(rand_gen, model):
365+
name = rand_gen(str)
366+
training_metadata = {"batch_size": 1000}
367+
model_run = model.create_model_run(name, training_metadata)
368+
yield model_run
369+
try:
370+
model_run.delete()
371+
except:
372+
# Already was deleted by the test
373+
pass
374+
375+
363376
@pytest.fixture
364377
def model_run_with_model_run_data_rows(client, configured_project,
365378
model_run_predictions, model_run):

tests/integration/annotation_import/test_model_run.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,24 @@ def test_model_run_delete(client, model_run):
5656
assert len(before) == len(after) + 1
5757

5858

59+
def test_model_run_update_config(model_run_with_training_metadata):
60+
new_config = {"batch_size": 2000}
61+
res = model_run_with_training_metadata.update_config(new_config)
62+
assert res["trainingMetadata"]["batch_size"] == new_config["batch_size"]
63+
64+
65+
def test_model_run_reset_config(model_run_with_training_metadata):
66+
res = model_run_with_training_metadata.reset_config()
67+
assert res["trainingMetadata"] is None
68+
69+
70+
def test_model_run_get_config(model_run_with_training_metadata):
71+
new_config = {"batch_size": 2000}
72+
model_run_with_training_metadata.update_config(new_config)
73+
res = model_run_with_training_metadata.get_config()
74+
assert res["trainingMetadata"]["batch_size"] == new_config["batch_size"]
75+
76+
5977
def test_model_run_data_rows_delete(client, model_run_with_model_run_data_rows):
6078
models = list(client.get_models())
6179
model = models[0]

0 commit comments

Comments
 (0)