Skip to content

Commit 28e9f7e

Browse files
committed
Add model run config on create [AL-3060]
1 parent 1a4965d commit 28e9f7e

File tree

4 files changed

+13
-6
lines changed

4 files changed

+13
-6
lines changed

examples/model_diagnostics/model_diagnostics_guide.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,8 @@
350350
"source": [
351351
"lb_model = client.create_model(name=f\"{project.name}-model\",\n",
352352
" ontology_id=project.ontology().uid)\n",
353-
"lb_model_run = lb_model.create_model_run(\"0.0.0\")\n",
353+
"lb_model_run_hyperparameters = {\"batch_size\": 1000}\n",
354+
"lb_model_run = lb_model.create_model_run(\"0.0.0\", lb_model_run_hyperparameters)\n",
354355
"lb_model_run.upsert_labels([label.uid for label in labels])"
355356
]
356357
},

labelbox/schema/model.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,26 @@ class Model(DbObject):
1818
name = Field.String("name")
1919
model_runs = Relationship.ToMany("ModelRun", False)
2020

21-
def create_model_run(self, name) -> "ModelRun":
21+
def create_model_run(self, name, config) -> "ModelRun":
2222
""" Creates a model run belonging to this model.
2323
2424
Args:
2525
name (string): The name for the model run.
26+
config (json): Model run's training metadata config
2627
Returns:
2728
ModelRun, the created model run.
2829
"""
2930
name_param = "name"
31+
config_param = "config"
3032
model_id_param = "modelId"
3133
ModelRun = Entity.ModelRun
3234
query_str = """mutation CreateModelRunPyApi($%s: String!, $%s: ID!) {
33-
createModelRun(data: {name: $%s, modelId: $%s}) {%s}}""" % (
34-
name_param, model_id_param, name_param, model_id_param,
35-
query.results_query_part(ModelRun))
35+
createModelRun(data: {name: $%s, trainingMetadata: $%s, modelId: $%s}) {%s}}""" % (
36+
name_param, model_id_param, name_param, config_param,
37+
model_id_param, query.results_query_part(ModelRun))
3638
res = self.client.execute(query_str, {
3739
name_param: name,
40+
config_param: config,
3841
model_id_param: self.uid
3942
})
4043
return ModelRun(self.client, res["createModelRun"])

labelbox/schema/model_run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class ModelRun(DbObject):
3232
created_at = Field.DateTime("created_at")
3333
created_by_id = Field.String("created_by_id", "createdBy")
3434
model_id = Field.String("model_id")
35+
training_metadata = Field.Json("training_metadata")
3536

3637
class Status(Enum):
3738
EXPORTING_DATA = "EXPORTING_DATA"

tests/integration/annotation_import/test_model_run.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ def test_model_run(client, configured_project_with_label, rand_gen):
1414
model = client.create_model(data["name"], data["ontology_id"])
1515

1616
name = rand_gen(str)
17-
model_run = model.create_model_run(name)
17+
config = {"batch_size": 100}
18+
model_run = model.create_model_run(name, config)
1819
assert model_run.name == name
20+
assert model_run.training_metadata["batch_size"] == config["batch_size"]
1921
assert model_run.model_id == model.uid
2022
assert model_run.created_by_id == client.get_user().uid
2123

0 commit comments

Comments
 (0)