Skip to content

Commit 141cfdb

Browse files
committed
Report metrics from training stage
1 parent 39f0aa0 commit 141cfdb

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

examples/kge-distmult-nations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def inspect_graph(G):
192192

193193
model_name = "dummyModelName_" + str(time.time())
194194

195-
gds.kge.model.train(
195+
res = gds.kge.model.train(
196196
G_train,
197197
model_name=model_name,
198198
scoring_function="distmult",
@@ -202,6 +202,7 @@ def inspect_graph(G):
202202
epochs_per_val=5,
203203
split_ratios={"TRAIN": 0.8, "VALID": 0.1, "TEST": 0.1},
204204
)
205+
print(res["metrics"])
205206

206207
predict_result = gds.kge.model.predict(
207208
model_name=model_name,
@@ -216,7 +217,6 @@ def inspect_graph(G):
216217

217218
print(predict_result.to_string())
218219

219-
print(predict_result.to_string())
220220
for index, row in predict_result.iterrows():
221221
h = row["head"]
222222
r = row["rel"]

graphdatascience/model/kge_runner.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import logging
23
import os
34
import time
@@ -114,7 +115,12 @@ def train(
114115

115116
self._wait_for_job(job_id)
116117

117-
return Series({"status": "finished"})
118+
return Series(
119+
{
120+
"status": "finished",
121+
"metrics": self._get_metrics(config["user_name"], config["task_config"]["modelname"], job_id),
122+
}
123+
)
118124

119125
@client_only_endpoint("gds.kge.model")
120126
def predict(
@@ -205,6 +211,25 @@ def _stream_results(self, user_name: str, model_name: str, job_id: str) -> DataF
205211
os.remove(res_file_name)
206212
return df
207213

214+
def _get_metrics(self, user_name: str, model_name: str, job_id: str) -> DataFrame:
215+
res = requests.get(
216+
f"{self._compute_cluster_web_uri}/internal/fetch-model-metadata",
217+
params={"user_name": user_name, "modelname": model_name},
218+
)
219+
res.raise_for_status()
220+
221+
res_file_name = f"metadata_{job_id}.json"
222+
223+
with open(res_file_name, mode="wb+") as f:
224+
f.write(res.content)
225+
226+
with open(res_file_name, mode="r") as f:
227+
metadata = json.load(f)
228+
229+
os.remove(res_file_name)
230+
231+
return metadata["metrics"]
232+
208233
def _start_job(self, config: Dict[str, Any]) -> str:
209234
url = f"{self._compute_cluster_web_uri}/api/machine-learning/start"
210235
res = requests.post(url, json=config)

0 commit comments

Comments
 (0)