diff --git a/metalearner.py b/metalearner.py index d906833..55ff675 100644 --- a/metalearner.py +++ b/metalearner.py @@ -20,7 +20,7 @@ def __init__(self, input_size, hidden_size, n_learner_params): self.n_learner_params = n_learner_params self.WF = nn.Parameter(torch.Tensor(input_size + 2, hidden_size)) self.WI = nn.Parameter(torch.Tensor(input_size + 2, hidden_size)) - self.cI = nn.Parameter(torch.Tensor(n_learner_params, 1)) + self.cI = nn.Parameter(torch.Tensor(n_learner_params, 1), requires_grad=False) # Freeze the parameters to avoide getting optimized by gradient descent (Adam) self.bI = nn.Parameter(torch.Tensor(1, hidden_size)) self.bF = nn.Parameter(torch.Tensor(1, hidden_size))