diff --git a/src/hierarchical_att_model.py b/src/hierarchical_att_model.py
index 363797e..da927df 100644
--- a/src/hierarchical_att_model.py
+++ b/src/hierarchical_att_model.py
@@ -20,13 +20,13 @@ def __init__(self, word_hidden_size, sent_hidden_size, batch_size, num_classes,
         self.sent_att_net = SentAttNet(sent_hidden_size, word_hidden_size, num_classes)
         self._init_hidden_state()
 
-    def _init_hidden_state(self, last_batch_size=None):
-        if last_batch_size:
-            batch_size = last_batch_size
-        else:
-            batch_size = self.batch_size
-        self.word_hidden_state = torch.zeros(2, batch_size, self.word_hidden_size)
-        self.sent_hidden_state = torch.zeros(2, batch_size, self.sent_hidden_size)
+    def _init_hidden_state(self, current_batch_size=None):
+        # Size the hidden states to the batch actually being processed; fall back to
+        # the configured batch size so the no-argument call in __init__ keeps working.
+        if current_batch_size is None:
+            current_batch_size = self.batch_size
+        self.word_hidden_state = torch.zeros(2, current_batch_size, self.word_hidden_size)
+        self.sent_hidden_state = torch.zeros(2, current_batch_size, self.sent_hidden_size)
         if torch.cuda.is_available():
             self.word_hidden_state = self.word_hidden_state.cuda()
             self.sent_hidden_state = self.sent_hidden_state.cuda()
diff --git a/src/sent_att_model.py b/src/sent_att_model.py
index 75bddf6..84df539 100644
--- a/src/sent_att_model.py
+++ b/src/sent_att_model.py
@@ -28,7 +28,7 @@ def forward(self, input, hidden_state):
 
         f_output, h_output = self.gru(input, hidden_state)
         output = matrix_mul(f_output, self.sent_weight, self.sent_bias)
-        output = matrix_mul(output, self.context_weight).permute(1, 0)
+        output = matrix_mul(output, self.context_weight, apply_tanh=False).permute(1, 0)
         output = F.softmax(output)
         output = element_wise_mul(f_output, output.permute(1, 0)).squeeze(0)
         output = self.fc(output)
diff --git a/src/utils.py b/src/utils.py
index 20c14ea..1c52954 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -23,13 +23,15 @@ def get_evaluation(y_true, y_prob, list_metrics):
         output['confusion_matrix'] = str(metrics.confusion_matrix(y_true, y_pred))
     return output
 
-def matrix_mul(input, weight, bias=False):
+def matrix_mul(input, weight, bias=False, apply_tanh=True):
     feature_list = []
     for feature in input:
         feature = torch.mm(feature, weight)
         if isinstance(bias, torch.nn.parameter.Parameter):
             feature = feature + bias.expand(feature.size()[0], bias.size()[1])
-        feature = torch.tanh(feature).unsqueeze(0)
+        if apply_tanh:
+            feature = torch.tanh(feature)
+        feature = feature.unsqueeze(0)
         feature_list.append(feature)
 
     return torch.cat(feature_list, 0).squeeze()
diff --git a/src/word_att_model.py b/src/word_att_model.py
index 399f846..852b8c1 100644
--- a/src/word_att_model.py
+++ b/src/word_att_model.py
@@ -36,7 +36,7 @@ def forward(self, input, hidden_state):
         output = self.lookup(input)
         f_output, h_output = self.gru(output.float(), hidden_state) # feature output and hidden state output
         output = matrix_mul(f_output, self.word_weight, self.word_bias)
-        output = matrix_mul(output, self.context_weight).permute(1,0)
+        output = matrix_mul(output, self.context_weight, apply_tanh=False).permute(1,0)
         output = F.softmax(output)
         output = element_wise_mul(f_output,output.permute(1,0))
 
diff --git a/train.py b/train.py
index e68f80c..29fb028 100644
--- a/train.py
+++ b/train.py
@@ -82,7 +82,10 @@ def train(opt):
                 feature = feature.cuda()
                 label = label.cuda()
             optimizer.zero_grad()
-            model._init_hidden_state()
+            # Size the hidden state to this batch: a batch (e.g. the final one of an
+            # epoch) can hold fewer samples than the configured batch size.
+            train_num_sample = len(label)
+            model._init_hidden_state(train_num_sample)
             predictions = model(feature)
             loss = criterion(predictions, label)
             loss.backward()