File tree Expand file tree Collapse file tree 3 files changed +69
-0
lines changed
tests/integration/annotation_import Expand file tree Collapse file tree 3 files changed +69
-0
lines changed Original file line number Diff line number Diff line change @@ -262,6 +262,45 @@ def update_status(self,
262262 },
263263 experimental = True )
264264
265+ @experimental
266+ def update_config (self ,
267+ config : Dict [str , Any ]):
268+ data : Dict [str , Any ] = {'config' : config }
269+ res = self .client .execute (
270+ """mutation updateModelRunConfigPyApi($modelRunId: ID!, $data: UpdateModelRunConfigInput!){
271+ updateModelRunConfig(modelRun: {id : $modelRunId}, data: $data){trainingMetadata}
272+ }
273+ """ , {
274+ 'modelRunId' : self .uid ,
275+ 'data' : data
276+ },
277+ experimental = True )
278+ return res ["updateModelRunConfig" ]
279+
280+ @experimental
281+ def reset_config (self ):
282+ res = self .client .execute (
283+ """mutation resetModelRunConfigPyApi($modelRunId: ID!){
284+ resetModelRunConfig(modelRun: {id : $modelRunId}){trainingMetadata}
285+ }
286+ """ , {
287+ 'modelRunId' : self .uid
288+ },
289+ experimental = True )
290+ return res ["resetModelRunConfig" ]
291+
292+ @experimental
293+ def config (self ):
294+ res = self .client .execute (
295+ """query ModelRunPyApi($modelRunId: ID!){
296+ modelRun(where: {id : $modelRunId}){trainingMetadata}
297+ }
298+ """ , {
299+ 'modelRunId' : self .uid
300+ },
301+ experimental = True )
302+ return res ["modelRun" ]
303+
265304 @experimental
266305 def export_labels (
267306 self ,
Original file line number Diff line number Diff line change @@ -359,6 +359,18 @@ def model_run(rand_gen, model):
359359 # Already was deleted by the test
360360 pass
361361
362+ @pytest .fixture
363+ def model_run_with_training_metadata (rand_gen , model ):
364+ name = rand_gen (str )
365+ training_metadata = {"batch_size" : 1000 }
366+ model_run = model .create_model_run (name , training_metadata )
367+ yield model_run
368+ try :
369+ model_run .delete ()
370+ except :
371+ # Already was deleted by the test
372+ pass
373+
362374
363375@pytest .fixture
364376def model_run_with_model_run_data_rows (client , configured_project ,
Original file line number Diff line number Diff line change @@ -56,6 +56,24 @@ def test_model_run_delete(client, model_run):
5656 assert len (before ) == len (after ) + 1
5757
5858
59+ def test_model_run_update_config (model_run_with_training_metadata ):
60+ new_config = {"batch_size" : 2000 }
61+ res = model_run_with_training_metadata .update_config (new_config )
62+ assert res ["trainingMetadata" ]["batchSize" ] == new_config ["batch_size" ]
63+
64+
65+ def test_model_run_reset_config (model_run_with_training_metadata ):
66+ res = model_run_with_training_metadata .reset_config ()
67+ assert res ["trainingMetadata" ] == None
68+
69+
70+ def test_model_run_fetch_config (model_run_with_training_metadata ):
71+ new_config = {"batch_size" : 2000 }
72+ model_run_with_training_metadata .update_config (new_config )
73+ res = model_run_with_training_metadata .config ()
74+ assert res ["trainingMetadata" ]["batchSize" ] == new_config ["batch_size" ]
75+
76+
5977def test_model_run_data_rows_delete (client , model_run_with_model_run_data_rows ):
6078 models = list (client .get_models ())
6179 model = models [0 ]
You can’t perform that action at this time.
0 commit comments