diff --git a/EduKTM/AKT/AKTNet.py b/EduKTM/AKT/AKTNet.py index 1155a48..1764b0a 100644 --- a/EduKTM/AKT/AKTNet.py +++ b/EduKTM/AKT/AKTNet.py @@ -270,7 +270,7 @@ def attention(q, k, v, d_k, mask, dropout, zero_pad, gamma=None): total_effect = torch.clamp(torch.clamp((dist_scores * gamma).exp(), min=1e-5), max=1e5) scores = scores * total_effect - scores.masked_fill(mask == 0, -1e23) + scores.masked_fill_(mask == 0, -1e23) scores = F.softmax(scores, dim=-1) if zero_pad: pad_zero = torch.zeros(bs, head, 1, seqlen).to(device)