Skip to content

Commit d35b867

Browse files
committed
Update config [AL-3061]
1 parent d44e091 commit d35b867

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

labelbox/schema/model_run.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,45 @@ def update_status(self,
262262
},
263263
experimental=True)
264264

265+
@experimental
266+
def update_config(self,
267+
config: Dict[str, Any]):
268+
data: Dict[str, Any] = {'config': config}
269+
res = self.client.execute(
270+
"""mutation updateModelRunConfigPyApi($modelRunId: ID!, $data: UpdateModelRunConfigInput!){
271+
updateModelRunConfig(modelRun: {id : $modelRunId}, data: $data){trainingMetadata}
272+
}
273+
""", {
274+
'modelRunId': self.uid,
275+
'data': data
276+
},
277+
experimental=True)
278+
return res["updateModelRunConfig"]
279+
280+
@experimental
281+
def reset_config(self):
282+
res = self.client.execute(
283+
"""mutation resetModelRunConfigPyApi($modelRunId: ID!){
284+
resetModelRunConfig(modelRun: {id : $modelRunId}){trainingMetadata}
285+
}
286+
""", {
287+
'modelRunId': self.uid
288+
},
289+
experimental=True)
290+
return res["resetModelRunConfig"]
291+
292+
@experimental
293+
def config(self):
294+
res = self.client.execute(
295+
"""query ModelRunPyApi($modelRunId: ID!){
296+
modelRun(where: {id : $modelRunId}){trainingMetadata}
297+
}
298+
""", {
299+
'modelRunId': self.uid
300+
},
301+
experimental=True)
302+
return res["modelRun"]
303+
265304
@experimental
266305
def export_labels(
267306
self,

tests/integration/annotation_import/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,18 @@ def model_run(rand_gen, model):
359359
# Already was deleted by the test
360360
pass
361361

362+
@pytest.fixture
363+
def model_run_with_training_metadata(rand_gen, model):
364+
name = rand_gen(str)
365+
training_metadata = {"batch_size": 1000}
366+
model_run = model.create_model_run(name, training_metadata)
367+
yield model_run
368+
try:
369+
model_run.delete()
370+
except:
371+
# Already was deleted by the test
372+
pass
373+
362374

363375
@pytest.fixture
364376
def model_run_with_model_run_data_rows(client, configured_project,

tests/integration/annotation_import/test_model_run.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,24 @@ def test_model_run_delete(client, model_run):
5656
assert len(before) == len(after) + 1
5757

5858

59+
def test_model_run_update_config(model_run_with_training_metadata):
60+
new_config = {"batch_size": 2000}
61+
res = model_run_with_training_metadata.update_config(new_config)
62+
assert res["trainingMetadata"]["batchSize"] == new_config["batch_size"]
63+
64+
65+
def test_model_run_reset_config(model_run_with_training_metadata):
66+
res = model_run_with_training_metadata.reset_config()
67+
assert res["trainingMetadata"] == None
68+
69+
70+
def test_model_run_fetch_config(model_run_with_training_metadata):
71+
new_config = {"batch_size": 2000}
72+
model_run_with_training_metadata.update_config(new_config)
73+
res = model_run_with_training_metadata.config()
74+
assert res["trainingMetadata"]["batchSize"] == new_config["batch_size"]
75+
76+
5977
def test_model_run_data_rows_delete(client, model_run_with_model_run_data_rows):
6078
models = list(client.get_models())
6179
model = models[0]

0 commit comments

Comments
 (0)