diff --git a/examples/wandb_predict.py b/examples/wandb_predict.py index e2a435e6..ecc6fbee 100644 --- a/examples/wandb_predict.py +++ b/examples/wandb_predict.py @@ -31,13 +31,14 @@ def main(params): trained_params = config["params"] fold = trained_params["fold"] model_name, dataset_name, emb_type = trained_params["model_name"], trained_params["dataset_name"], trained_params["emb_type"] - if model_name in ["saint", "sakt", "atdkt"]: + if model_name in ["saint", "sakt", "atdkt", "simplekt"]: train_config = config["train_config"] seq_len = train_config["seq_len"] model_config["seq_len"] = seq_len with open("../configs/data_config.json") as fin: curconfig = copy.deepcopy(json.load(fin)) + data_config = curconfig[dataset_name] data_config["dataset_name"] = dataset_name if model_name in ["dkt_forget", "bakt_time"]: