diff --git a/model.py b/model.py index afdaee5..92892e9 100644 --- a/model.py +++ b/model.py @@ -15,9 +15,9 @@ def forward(self, x): # Keep the dimension for broadcasting mean = x.mean(dim = -1, keepdim = True) # (batch, seq_len, 1) # Keep the dimension for broadcasting - std = x.std(dim = -1, keepdim = True) # (batch, seq_len, 1) + var = x.var(dim = -1, keepdim = True) # (batch, seq_len, 1) # eps is to prevent dividing by zero or when std is very small - return self.alpha * (x - mean) / (std + self.eps) + self.bias + return self.alpha * (x - mean) / torch.sqrt(var + self.eps) + self.bias class FeedForwardBlock(nn.Module):