From cfc4f9fc75c0689fee2c5d8ee095525c2725b6ed Mon Sep 17 00:00:00 2001 From: doheekim Date: Mon, 27 May 2024 17:55:31 -0700 Subject: [PATCH 1/3] predict simplekt with custom seq_len --- examples/wandb_predict.py | 88 +++++++++++++++++++++------------------ 1 file changed, 48 insertions(+), 40 deletions(-) diff --git a/examples/wandb_predict.py b/examples/wandb_predict.py index e2a435e6..b3a7e685 100644 --- a/examples/wandb_predict.py +++ b/examples/wandb_predict.py @@ -11,7 +11,7 @@ device = "cpu" if not torch.cuda.is_available() else "cuda" os.environ['CUBLAS_WORKSPACE_CONFIG']=':4096:2' -with open("../configs/wandb.json") as fin: +with open("configs/wandb.json") as fin: wandb_config = json.load(fin) def main(params): @@ -20,7 +20,7 @@ def main(params): os.environ['WANDB_API_KEY'] = wandb_config["api_key"] wandb.init(project="wandb_predict") - save_dir, batch_size, fusion_type = params["save_dir"], params["bz"], params["fusion_type"].split(",") + save_dir, batch_size, fusion_type, fold = params["save_dir"], params["bz"], params["fusion_type"].split(","), params["fold"] with open(os.path.join(save_dir, "config.json")) as fin: config = json.load(fin) @@ -31,13 +31,20 @@ def main(params): trained_params = config["params"] fold = trained_params["fold"] model_name, dataset_name, emb_type = trained_params["model_name"], trained_params["dataset_name"], trained_params["emb_type"] - if model_name in ["saint", "sakt", "atdkt"]: + if model_name in ["saint", "sakt", "atdkt", "simplekt"]: train_config = config["train_config"] seq_len = train_config["seq_len"] model_config["seq_len"] = seq_len - with open("../configs/data_config.json") as fin: + with open("configs/data_config.json") as fin: curconfig = copy.deepcopy(json.load(fin)) + curconfig[dataset_name]["test_original_file"] = f"test_fold{fold}.csv" + curconfig[dataset_name]["test_question_file"] = f"test_question_sequences_fold{fold}.csv" + curconfig[dataset_name]["test_question_window_file"] = f"test_question_window_sequences_fold{fold}.csv" + curconfig[dataset_name]["test_file"] = f"test_sequences_fold{fold}.csv" + curconfig[dataset_name]["test_window_file"] = f"test_window_sequences_fold{fold}.csv" + + data_config = curconfig[dataset_name] data_config["dataset_name"] = dataset_name if model_name in ["dkt_forget", "bakt_time"]: @@ -80,49 +87,49 @@ def main(params): testauc, testacc = evaluate(model, test_loader, model_name, save_test_path) print(f"testauc: {testauc}, testacc: {testacc}") - window_testauc, window_testacc = -1, -1 - save_test_window_path = os.path.join(save_dir, model.emb_type+"_test_window_predictions.txt") - if model.model_name == "rkt": - window_testauc, window_testacc = evaluate(model, test_window_loader, model_name, rel, save_test_window_path) - else: - window_testauc, window_testacc = evaluate(model, test_window_loader, model_name, save_test_window_path) - print(f"testauc: {testauc}, testacc: {testacc}, window_testauc: {window_testauc}, window_testacc: {window_testacc}") + # window_testauc, window_testacc = -1, -1 + # save_test_window_path = os.path.join(save_dir, model.emb_type+"_test_window_predictions.txt") + # if model.model_name == "rkt": + # window_testauc, window_testacc = evaluate(model, test_window_loader, model_name, rel, save_test_window_path) + # else: + # window_testauc, window_testacc = evaluate(model, test_window_loader, model_name, save_test_window_path) + # print(f"testauc: {testauc}, testacc: {testacc}, window_testauc: {window_testauc}, window_testacc: {window_testacc}") - # question_testauc, question_testacc = -1, -1 - # question_window_testauc, question_window_testacc = -1, -1 + # # question_testauc, question_testacc = -1, -1 + # # question_window_testauc, question_window_testacc = -1, -1 - dres = { - "testauc": testauc, "testacc": testacc, "window_testauc": window_testauc, "window_testacc": window_testacc, - } - - q_testaucs, q_testaccs = -1,-1 - qw_testaucs, qw_testaccs = -1,-1 - if "test_question_file" in data_config and not test_question_loader is None: - save_test_question_path = os.path.join(save_dir, model.emb_type+"_test_question_predictions.txt") - q_testaucs, q_testaccs = evaluate_question(model, test_question_loader, model_name, fusion_type, save_test_question_path) - for key in q_testaucs: - dres["oriauc"+key] = q_testaucs[key] - for key in q_testaccs: - dres["oriacc"+key] = q_testaccs[key] + # dres = { + # "testauc": testauc, "testacc": testacc, "window_testauc": window_testauc, "window_testacc": window_testacc, + # } + + # q_testaucs, q_testaccs = -1,-1 + # qw_testaucs, qw_testaccs = -1,-1 + # if "test_question_file" in data_config and not test_question_loader is None: + # save_test_question_path = os.path.join(save_dir, model.emb_type+"_test_question_predictions.txt") + # q_testaucs, q_testaccs = evaluate_question(model, test_question_loader, model_name, fusion_type, save_test_question_path) + # for key in q_testaucs: + # dres["oriauc"+key] = q_testaucs[key] + # for key in q_testaccs: + # dres["oriacc"+key] = q_testaccs[key] - if "test_question_window_file" in data_config and not test_question_window_loader is None: - save_test_question_window_path = os.path.join(save_dir, model.emb_type+"_test_question_window_predictions.txt") - qw_testaucs, qw_testaccs = evaluate_question(model, test_question_window_loader, model_name, fusion_type, save_test_question_window_path) - for key in qw_testaucs: - dres["windowauc"+key] = qw_testaucs[key] - for key in qw_testaccs: - dres["windowacc"+key] = qw_testaccs[key] + # if "test_question_window_file" in data_config and not test_question_window_loader is None: + # save_test_question_window_path = os.path.join(save_dir, model.emb_type+"_test_question_window_predictions.txt") + # qw_testaucs, qw_testaccs = evaluate_question(model, test_question_window_loader, model_name, fusion_type, save_test_question_window_path) + # for key in qw_testaucs: + # dres["windowauc"+key] = qw_testaucs[key] + # for key in qw_testaccs: + # dres["windowacc"+key] = qw_testaccs[key] - # print(f"testauc: {testauc}, testacc: {testacc}, window_testauc: {window_testauc}, window_testacc: {window_testacc}") - # print(f"question_testauc: {question_testauc}, question_testacc: {question_testacc}, question_window_testauc: {question_window_testauc}, question_window_testacc: {question_window_testacc}") + # # print(f"testauc: {testauc}, testacc: {testacc}, window_testauc: {window_testauc}, window_testacc: {window_testacc}") + # # print(f"question_testauc: {question_testauc}, question_testacc: {question_testacc}, question_window_testauc: {question_window_testauc}, question_window_testacc: {question_window_testacc}") - print(dres) - raw_config = json.load(open(os.path.join(save_dir,"config.json"))) - dres.update(raw_config['params']) + # print(dres) + # raw_config = json.load(open(os.path.join(save_dir,"config.json"))) + # dres.update(raw_config['params']) - if params['use_wandb'] ==1: - wandb.log(dres) + # if params['use_wandb'] ==1: + # wandb.log(dres) if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -130,6 +137,7 @@ def main(params): parser.add_argument("--save_dir", type=str, default="saved_model") parser.add_argument("--fusion_type", type=str, default="early_fusion,late_fusion") parser.add_argument("--use_wandb", type=int, default=1) + parser.add_argument("--fold", type=int, default=0) args = parser.parse_args() print(args) From e94a523e3f51cd1740b47a6b1260122387682b68 Mon Sep 17 00:00:00 2001 From: KimDohee Date: Tue, 28 May 2024 10:24:11 +0900 Subject: [PATCH 2/3] Update wandb_predict.py --- examples/wandb_predict.py | 86 ++++++++++++++++++--------------------- 1 file changed, 40 insertions(+), 46 deletions(-) diff --git a/examples/wandb_predict.py b/examples/wandb_predict.py index b3a7e685..f5f08c5d 100644 --- a/examples/wandb_predict.py +++ b/examples/wandb_predict.py @@ -11,7 +11,7 @@ device = "cpu" if not torch.cuda.is_available() else "cuda" os.environ['CUBLAS_WORKSPACE_CONFIG']=':4096:2' -with open("configs/wandb.json") as fin: +with open("../configs/wandb.json") as fin: wandb_config = json.load(fin) def main(params): @@ -20,7 +20,7 @@ def main(params): os.environ['WANDB_API_KEY'] = wandb_config["api_key"] wandb.init(project="wandb_predict") - save_dir, batch_size, fusion_type, fold = params["save_dir"], params["bz"], params["fusion_type"].split(","), params["fold"] + save_dir, batch_size, fusion_type = params["save_dir"], params["bz"], params["fusion_type"].split(",") with open(os.path.join(save_dir, "config.json")) as fin: config = json.load(fin) @@ -36,15 +36,9 @@ def main(params): seq_len = train_config["seq_len"] model_config["seq_len"] = seq_len - with open("configs/data_config.json") as fin: + with open("../configs/data_config.json") as fin: curconfig = copy.deepcopy(json.load(fin)) - curconfig[dataset_name]["test_original_file"] = f"test_fold{fold}.csv" - curconfig[dataset_name]["test_question_file"] = f"test_question_sequences_fold{fold}.csv" - curconfig[dataset_name]["test_question_window_file"] = f"test_question_window_sequences_fold{fold}.csv" - curconfig[dataset_name]["test_file"] = f"test_sequences_fold{fold}.csv" - curconfig[dataset_name]["test_window_file"] = f"test_window_sequences_fold{fold}.csv" - - + data_config = curconfig[dataset_name] data_config["dataset_name"] = dataset_name if model_name in ["dkt_forget", "bakt_time"]: @@ -87,49 +81,49 @@ def main(params): testauc, testacc = evaluate(model, test_loader, model_name, save_test_path) print(f"testauc: {testauc}, testacc: {testacc}") - # window_testauc, window_testacc = -1, -1 - # save_test_window_path = os.path.join(save_dir, model.emb_type+"_test_window_predictions.txt") - # if model.model_name == "rkt": - # window_testauc, window_testacc = evaluate(model, test_window_loader, model_name, rel, save_test_window_path) - # else: - # window_testauc, window_testacc = evaluate(model, test_window_loader, model_name, save_test_window_path) - # print(f"testauc: {testauc}, testacc: {testacc}, window_testauc: {window_testauc}, window_testacc: {window_testacc}") + window_testauc, window_testacc = -1, -1 + save_test_window_path = os.path.join(save_dir, model.emb_type+"_test_window_predictions.txt") + if model.model_name == "rkt": + window_testauc, window_testacc = evaluate(model, test_window_loader, model_name, rel, save_test_window_path) + else: + window_testauc, window_testacc = evaluate(model, test_window_loader, model_name, save_test_window_path) + print(f"testauc: {testauc}, testacc: {testacc}, window_testauc: {window_testauc}, window_testacc: {window_testacc}") - # # question_testauc, question_testacc = -1, -1 - # # question_window_testauc, question_window_testacc = -1, -1 + # question_testauc, question_testacc = -1, -1 + # question_window_testauc, question_window_testacc = -1, -1 - # dres = { - # "testauc": testauc, "testacc": testacc, "window_testauc": window_testauc, "window_testacc": window_testacc, - # } - - # q_testaucs, q_testaccs = -1,-1 - # qw_testaucs, qw_testaccs = -1,-1 - # if "test_question_file" in data_config and not test_question_loader is None: - # save_test_question_path = os.path.join(save_dir, model.emb_type+"_test_question_predictions.txt") - # q_testaucs, q_testaccs = evaluate_question(model, test_question_loader, model_name, fusion_type, save_test_question_path) - # for key in q_testaucs: - # dres["oriauc"+key] = q_testaucs[key] - # for key in q_testaccs: - # dres["oriacc"+key] = q_testaccs[key] + dres = { + "testauc": testauc, "testacc": testacc, "window_testauc": window_testauc, "window_testacc": window_testacc, + } + + q_testaucs, q_testaccs = -1,-1 + qw_testaucs, qw_testaccs = -1,-1 + if "test_question_file" in data_config and not test_question_loader is None: + save_test_question_path = os.path.join(save_dir, model.emb_type+"_test_question_predictions.txt") + q_testaucs, q_testaccs = evaluate_question(model, test_question_loader, model_name, fusion_type, save_test_question_path) + for key in q_testaucs: + dres["oriauc"+key] = q_testaucs[key] + for key in q_testaccs: + dres["oriacc"+key] = q_testaccs[key] - # if "test_question_window_file" in data_config and not test_question_window_loader is None: - # save_test_question_window_path = os.path.join(save_dir, model.emb_type+"_test_question_window_predictions.txt") - # qw_testaucs, qw_testaccs = evaluate_question(model, test_question_window_loader, model_name, fusion_type, save_test_question_window_path) - # for key in qw_testaucs: - # dres["windowauc"+key] = qw_testaucs[key] - # for key in qw_testaccs: - # dres["windowacc"+key] = qw_testaccs[key] + if "test_question_window_file" in data_config and not test_question_window_loader is None: + save_test_question_window_path = os.path.join(save_dir, model.emb_type+"_test_question_window_predictions.txt") + qw_testaucs, qw_testaccs = evaluate_question(model, test_question_window_loader, model_name, fusion_type, save_test_question_window_path) + for key in qw_testaucs: + dres["windowauc"+key] = qw_testaucs[key] + for key in qw_testaccs: + dres["windowacc"+key] = qw_testaccs[key] - # # print(f"testauc: {testauc}, testacc: {testacc}, window_testauc: {window_testauc}, window_testacc: {window_testacc}") - # # print(f"question_testauc: {question_testauc}, question_testacc: {question_testacc}, question_window_testauc: {question_window_testauc}, question_window_testacc: {question_window_testacc}") + # print(f"testauc: {testauc}, testacc: {testacc}, window_testauc: {window_testauc}, window_testacc: {window_testacc}") + # print(f"question_testauc: {question_testauc}, question_testacc: {question_testacc}, question_window_testauc: {question_window_testauc}, question_window_testacc: {question_window_testacc}") - # print(dres) - # raw_config = json.load(open(os.path.join(save_dir,"config.json"))) - # dres.update(raw_config['params']) + print(dres) + raw_config = json.load(open(os.path.join(save_dir,"config.json"))) + dres.update(raw_config['params']) - # if params['use_wandb'] ==1: - # wandb.log(dres) + if params['use_wandb'] ==1: + wandb.log(dres) if __name__ == "__main__": parser = argparse.ArgumentParser() From 45e3b86e429c950871cbcbb1a8900ce8bb9443f3 Mon Sep 17 00:00:00 2001 From: KimDohee Date: Tue, 28 May 2024 10:24:32 +0900 Subject: [PATCH 3/3] Update wandb_predict.py --- examples/wandb_predict.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/wandb_predict.py b/examples/wandb_predict.py index f5f08c5d..ecc6fbee 100644 --- a/examples/wandb_predict.py +++ b/examples/wandb_predict.py @@ -131,7 +131,6 @@ def main(params): parser.add_argument("--save_dir", type=str, default="saved_model") parser.add_argument("--fusion_type", type=str, default="early_fusion,late_fusion") parser.add_argument("--use_wandb", type=int, default=1) - parser.add_argument("--fold", type=int, default=0) args = parser.parse_args() print(args)