diff --git a/main.py b/main.py
index 0bf095c..8d77a2f 100644
--- a/main.py
+++ b/main.py
@@ -88,10 +88,11 @@ def main():
     # initialize model, criterion/loss_function, optimizer
     model = SimilarityTreeLSTM(
-                args.cuda, vocab.size(),
-                args.input_dim, args.mem_dim,
-                args.hidden_dim, args.num_classes,
-                args.sparse)
+                args.hidden_dim,
+                args.mem_dim,
+                vocab.size(),
+                args.input_dim,
+                args.num_classes)
     criterion = nn.KLDivLoss()
     if args.cuda:
         model.cuda(), criterion.cuda()
@@ -123,7 +124,7 @@ def main():
     # plug these into embedding matrix inside model
     if args.cuda:
         emb = emb.cuda()
-    model.childsumtreelstm.emb.state_dict()['weight'].copy_(emb)
+    model.embed.weight.data.copy_(emb)
     # create trainer object for training and testing
     trainer = Trainer(args, model, criterion, optimizer)
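Reordered positional arguments like the ones above are easy to mis-map (the new constructor takes the similarity hidden size first, the TreeLSTM memory dimension second). For reference, the same construction spelled out with keywords; this is a sketch, not part of the diff, with argument names taken from the new `SimilarityTreeLSTM.__init__` in model.py below:

```python
# keyword form of the call above, purely illustrative
model = SimilarityTreeLSTM(
    sim_hidden_size=args.hidden_dim,  # hidden layer of the Similarity MLP
    rnn_hidden_size=args.mem_dim,     # TreeLSTM memory dimension
    embed_in_size=vocab.size(),       # vocabulary size
    embed_dim=args.input_dim,         # word-embedding dimension
    num_classes=args.num_classes)
```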
diff --git a/model.py b/model.py
index c2a18b3..f57bd53 100644
--- a/model.py
+++ b/model.py
@@ -1,108 +1,144 @@
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Variable as Var
 import Constants
+from torch.nn import Parameter
 
-# module for childsumtreelstm
-class ChildSumTreeLSTM(nn.Module):
-    def __init__(self, cuda, vocab_size, in_dim, mem_dim, sparsity):
-        super(ChildSumTreeLSTM, self).__init__()
-        self.cudaFlag = cuda
-        self.in_dim = in_dim
-        self.mem_dim = mem_dim
-
-        self.emb = nn.Embedding(vocab_size,in_dim,
-                                padding_idx=Constants.PAD,
-                                sparse=sparsity)
-
-        self.ix = nn.Linear(self.in_dim,self.mem_dim)
-        self.ih = nn.Linear(self.mem_dim,self.mem_dim)
-
-        self.fx = nn.Linear(self.in_dim,self.mem_dim)
-        self.fh = nn.Linear(self.mem_dim,self.mem_dim)
-
-        self.ox = nn.Linear(self.in_dim,self.mem_dim)
-        self.oh = nn.Linear(self.mem_dim,self.mem_dim)
-
-        self.ux = nn.Linear(self.in_dim,self.mem_dim)
-        self.uh = nn.Linear(self.mem_dim,self.mem_dim)
-
-    def node_forward(self, inputs, child_c, child_h):
-        child_h_sum = F.torch.sum(torch.squeeze(child_h, 1), 0, keepdim=True)
-
-        i = F.sigmoid(self.ix(inputs) + self.ih(child_h_sum))
-        o = F.sigmoid(self.ox(inputs) + self.oh(child_h_sum))
-        u = F.tanh(self.ux(inputs) + self.uh(child_h_sum))
-
-        fx = self.fx(inputs)
-        f = F.torch.cat([self.fh(child_hi) + fx for child_hi in child_h], 0)
-        f = F.sigmoid(f)
-        # adding extra singleton dimension
-        f = F.torch.unsqueeze(f, 1)
-        fc = F.torch.squeeze(F.torch.mul(f, child_c), 1)
-
-        c = F.torch.mul(i, u) + F.torch.sum(fc, 0, keepdim=True)
-        h = F.torch.mul(o, F.tanh(c))
-
-        return c,h
-
-    def forward(self, tree, inputs):
-        # add singleton dimension for future call to node_forward
-        embs = F.torch.unsqueeze(self.emb(inputs),1)
-        for idx in range(tree.num_children):
-            _ = self.forward(tree.children[idx], inputs)
-        child_c, child_h = self.get_child_states(tree)
-        tree.state = self.node_forward(embs[tree.idx], child_c, child_h)
-        return tree.state
-
-    def get_child_states(self, tree):
-        # add extra singleton dimension in middle...
-        # because pytorch needs mini batches... :sad:
-        if tree.num_children==0:
-            child_c = Var(torch.zeros(1, 1, self.mem_dim))
-            child_h = Var(torch.zeros(1, 1, self.mem_dim))
-            if self.cudaFlag:
-                child_c, child_h = child_c.cuda(), child_h.cuda()
+class Tree(object):
+    def __init__(self, idx):
+        self.children = []
+        self.idx = idx
+
+    def __repr__(self):
+        if self.children:
+            return '{0}: {1}'.format(self.idx, str(self.children))
+        else:
+            return str(self.idx)
+
+
+if __name__ == '__main__':
+    # quick sanity check for Tree; prints "0: [1, 2: [4], 3]"
+    tree = Tree(0)
+    tree.children.append(Tree(1))
+    tree.children.append(Tree(2))
+    tree.children.append(Tree(3))
+    tree.children[1].children.append(Tree(4))
+    print(tree)
+
+
+class ChildSumLSTMCell(nn.Module):
+    def __init__(self, hidden_size,
+                 i2h_weight_initializer=None,
+                 hs2h_weight_initializer=None,
+                 hc2h_weight_initializer=None,
+                 i2h_bias_initializer='zeros',
+                 hs2h_bias_initializer='zeros',
+                 hc2h_bias_initializer='zeros',
+                 input_size=0):
+        super(ChildSumLSTMCell, self).__init__()
+        # note: the *_initializer arguments are accepted for API parity but
+        # are currently unused; all parameters are initialized uniformly
+        # below, and input_size must be > 0 since weights are allocated here.
+        self._hidden_size = hidden_size
+        self._input_size = input_size
+        stdv = 1. / math.sqrt(input_size)
+        self.i2h_weight = Parameter(torch.Tensor(4*hidden_size, input_size).uniform_(-stdv, stdv))
+        self.i2h_bias = Parameter(torch.Tensor(4*hidden_size).uniform_(-stdv, stdv))
+        stdv = 1. / math.sqrt(hidden_size)
+        self.hs2h_weight = Parameter(torch.Tensor(3*hidden_size, hidden_size).uniform_(-stdv, stdv))
+        self.hs2h_bias = Parameter(torch.Tensor(3*hidden_size).uniform_(-stdv, stdv))
+        self.hc2h_weight = Parameter(torch.Tensor(hidden_size, hidden_size).uniform_(-stdv, stdv))
+        self.hc2h_bias = Parameter(torch.Tensor(hidden_size).uniform_(-stdv, stdv))
+
+    def forward(self, inputs, tree):
+        # recursively evaluate all children before this node
+        children_outputs = [self(inputs, child) for child in tree.children]
+        if children_outputs:
+            _, children_states = zip(*children_outputs)  # unzip
+        else:
+            children_states = None
+
+        return self.node_forward(inputs[tree.idx].unsqueeze(0),
+                                 children_states,
+                                 self.i2h_weight, self.hs2h_weight,
+                                 self.hc2h_weight, self.i2h_bias,
+                                 self.hs2h_bias, self.hc2h_bias)
+
+    def node_forward(self, inputs, children_states,
+                     i2h_weight, hs2h_weight, hc2h_weight,
+                     i2h_bias, hs2h_bias, hc2h_bias):
+        # comment notation:
+        #   N for batch size
+        #   C for hidden state dimensions
+        #   K for number of children
+
+        # FC for i, f, u, o gates (N, 4*C), from input to hidden
+        i2h = F.linear(inputs, i2h_weight, i2h_bias)
+        i2h_slices = torch.split(i2h, i2h.size(1) // 4, dim=1)  # (N, C)*4
+        i2h_iuo = torch.cat([i2h_slices[0], i2h_slices[2], i2h_slices[3]], dim=1)  # (N, C*3)
+
+        if children_states:
+            # sum of children hidden states, (N, C)
+            hs = torch.sum(torch.cat([state[0].unsqueeze(0) for state in children_states]), dim=0)
+            # concatenation of children hidden states, (N, K, C)
+            hc = torch.cat([state[0].unsqueeze(1) for state in children_states], dim=1)
+            # concatenation of children cell states, (N, K, C)
+            cs = torch.cat([state[1].unsqueeze(1) for state in children_states], dim=1)
+            # calculate the forget-gate activation; the addition in f_act
+            # relies on broadcasting
+            i2h_f_slice = i2h_slices[1]
+            f_act = i2h_f_slice + hc2h_bias.unsqueeze(0).expand_as(i2h_f_slice) + torch.matmul(hc, hc2h_weight)  # (N, K, C)
+            forget_gates = F.sigmoid(f_act)  # (N, K, C)
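+            # illustrative shapes, assuming N=1, C=150, K=2 (numbers chosen
+            # only for this example): i2h_f_slice is (1, 150), hc is
+            # (1, 2, 150), and torch.matmul(hc, hc2h_weight) is (1, 2, 150);
+            # the (N, C) terms broadcast over the K axis, so each child gets
+            # its own forget gate from a single input projection.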
         else:
-            child_c = Var(torch.Tensor(tree.num_children, 1, self.mem_dim))
-            child_h = Var(torch.Tensor(tree.num_children, 1, self.mem_dim))
-            if self.cudaFlag:
-                child_c, child_h = child_c.cuda(), child_h.cuda()
-            for idx in range(tree.num_children):
-                child_c[idx], child_h[idx] = tree.children[idx].state
-        return child_c, child_h
+            # for leaf nodes, the summation of children hidden states is zero
+            # (on PyTorch > 0.2 you can use torch.zeros_like for this)
+            hs = Var(i2h_slices[0].data.new(*i2h_slices[0].size()).fill_(0))
+
+        # FC for i, u, o gates, from summation of children states to hidden state
+        hs2h_iuo = F.linear(hs, hs2h_weight, hs2h_bias)
+        i2h_iuo = i2h_iuo + hs2h_iuo
+
+        iuo_act_slices = torch.split(i2h_iuo, i2h_iuo.size(1) // 3, dim=1)  # (N, C)*3
+        i_act, u_act, o_act = iuo_act_slices[0], iuo_act_slices[1], iuo_act_slices[2]  # (N, C) each
+
+        # calculate gate outputs
+        in_gate = F.sigmoid(i_act)
+        in_transform = F.tanh(u_act)
+        out_gate = F.sigmoid(o_act)
+
+        # calculate cell state and hidden state
+        next_c = in_gate * in_transform
+        if children_states:
+            next_c = torch.sum(forget_gates * cs, dim=1) + next_c
+        next_h = out_gate * torch.tanh(next_c)
+
+        return next_h, [next_h, next_c]
+
 # module for distance-angle similarity
 class Similarity(nn.Module):
-    def __init__(self, cuda, mem_dim, hidden_dim, num_classes):
+    def __init__(self, sim_hidden_size, rnn_hidden_size, num_classes):
         super(Similarity, self).__init__()
-        self.cudaFlag = cuda
-        self.mem_dim = mem_dim
-        self.hidden_dim = hidden_dim
-        self.num_classes = num_classes
-        self.wh = nn.Linear(2*self.mem_dim, self.hidden_dim)
-        self.wp = nn.Linear(self.hidden_dim, self.num_classes)
+        self.wh = nn.Linear(2*rnn_hidden_size, sim_hidden_size)
+        self.wp = nn.Linear(sim_hidden_size, num_classes)
 
     def forward(self, lvec, rvec):
-        mult_dist = F.torch.mul(lvec, rvec)
-        abs_dist = F.torch.abs(F.torch.add(lvec, -rvec))
-        vec_dist = F.torch.cat((mult_dist, abs_dist),1)
-        out = F.sigmoid(self.wh(vec_dist))
-        # out = F.sigmoid(out)
-        out = F.log_softmax(self.wp(out))
+        # lvec and rvec will be TreeLSTM cell states at the roots
+        mult_dist = lvec * rvec
+        abs_dist = torch.abs(lvec - rvec)
+        vec_dist = torch.cat([mult_dist, abs_dist], dim=1)
+        out = F.log_softmax(self.wp(torch.sigmoid(self.wh(vec_dist))))
         return out
 
-# puttinh the whole model together
+
+# putting the whole model together
 class SimilarityTreeLSTM(nn.Module):
-    def __init__(self, cuda, vocab_size, in_dim, mem_dim, hidden_dim, num_classes, sparsity):
+    def __init__(self, sim_hidden_size, rnn_hidden_size,
+                 embed_in_size, embed_dim, num_classes):
         super(SimilarityTreeLSTM, self).__init__()
-        self.cudaFlag = cuda
-        self.childsumtreelstm = ChildSumTreeLSTM(cuda, vocab_size, in_dim, mem_dim, sparsity)
-        self.similarity = Similarity(cuda, mem_dim, hidden_dim, num_classes)
-
-    def forward(self, ltree, linputs, rtree, rinputs):
-        lstate, lhidden = self.childsumtreelstm(ltree, linputs)
-        rstate, rhidden = self.childsumtreelstm(rtree, rinputs)
+        self.embed = nn.Embedding(embed_in_size, embed_dim)
+        self.childsumtreelstm = ChildSumLSTMCell(rnn_hidden_size, input_size=embed_dim)
+        self.similarity = Similarity(sim_hidden_size, rnn_hidden_size, num_classes)
+
+    def forward(self, l_inputs, r_inputs, l_tree, r_tree):
+        l_inputs = self.embed(l_inputs)
+        r_inputs = self.embed(r_inputs)
+        # get cell states at the roots; the cell returns (h, [h, c])
+        lstate = self.childsumtreelstm(l_inputs, l_tree)[1][1]
+        rstate = self.childsumtreelstm(r_inputs, r_tree)[1][1]
         output = self.similarity(lstate, rstate)
         return output
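Taken together, model.py now exposes a recursive cell plus a similarity head. A minimal end-to-end sketch of a forward pass; the dimensions and token ids are made up for illustration, while `Tree` and `SimilarityTreeLSTM` are the classes defined in the diff above:

```python
import torch
from torch.autograd import Variable
from model import Tree, SimilarityTreeLSTM

# toy three-node trees; node idx i reads row i of that sentence's inputs
ltree, rtree = Tree(0), Tree(0)
ltree.children.extend([Tree(1), Tree(2)])
rtree.children.extend([Tree(1), Tree(2)])

model = SimilarityTreeLSTM(sim_hidden_size=50, rnn_hidden_size=150,
                           embed_in_size=100, embed_dim=300, num_classes=5)
linput = Variable(torch.LongTensor([3, 7, 9]))  # one token id per node
rinput = Variable(torch.LongTensor([3, 8, 9]))

# returns (1, num_classes) log-probabilities
output = model(linput, rinput, ltree, rtree)
```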
diff --git a/trainer.py b/trainer.py
index 7ffc4a9..d77cf8b 100644
--- a/trainer.py
+++ b/trainer.py
@@ -25,7 +25,7 @@ def train(self, dataset):
             if self.args.cuda:
                 linput, rinput = linput.cuda(), rinput.cuda()
                 target = target.cuda()
-            output = self.model(ltree,linput,rtree,rinput)
+            output = self.model(linput, rinput, ltree, rtree)
             err = self.criterion(output, target)
             loss += err.data[0]
             err.backward()
@@ -49,7 +49,7 @@ def test(self, dataset):
             if self.args.cuda:
                 linput, rinput = linput.cuda(), rinput.cuda()
                 target = target.cuda()
-            output = self.model(ltree,linput,rtree,rinput)
+            output = self.model(linput, rinput, ltree, rtree)
             err = self.criterion(output, target)
             loss += err.data[0]
             output = output.data.squeeze().cpu()
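Because `Similarity` ends in `log_softmax`, the `nn.KLDivLoss` criterion set up in main.py expects the target to be a probability distribution over the `num_classes` bins. One training step, continuing the sketch above; the `target` here is hypothetical, since the real target construction lives in the dataset code, which this diff does not touch:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

criterion = nn.KLDivLoss()
optimizer = optim.Adagrad(model.parameters(), lr=0.05)  # illustrative settings

# hypothetical (1, num_classes) probability target for one sentence pair
target = Variable(torch.Tensor([[0.0, 0.0, 0.4, 0.6, 0.0]]))

optimizer.zero_grad()
output = model(linput, rinput, ltree, rtree)  # log-probabilities
loss = criterion(output, target)
loss.backward()
optimizer.step()
```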