@@ -17,7 +17,6 @@ class NeuralLogicProgramming(nn.Module, core.Configurable):
1717 https://papers.nips.cc/paper/2017/file/0e55666a4ad822e0e34299df3591d979-Paper.pdf
1818
1919 Parameters:
20- num_entity (int): number of entities
2120 num_relation (int): number of relations
2221 hidden_dim (int): dimension of hidden units in LSTM
2322 num_step (int): number of recurrent steps
@@ -26,17 +25,15 @@ class NeuralLogicProgramming(nn.Module, core.Configurable):
2625
2726 eps = 1e-10
2827
29- def __init__ (self , num_entity , num_relation , hidden_dim , num_step , num_lstm_layer = 1 ):
28+ def __init__ (self , num_relation , hidden_dim , num_step , num_lstm_layer = 1 ):
3029 super (NeuralLogicProgramming , self ).__init__ ()
3130
3231 num_relation = int (num_relation )
33- self .num_entity = num_entity
3432 self .num_relation = num_relation
3533 self .num_step = num_step
3634
3735 self .query = nn .Embedding (num_relation * 2 + 1 , hidden_dim )
3836 self .lstm = nn .LSTM (hidden_dim , hidden_dim , num_lstm_layer )
39- self .key_linear = nn .Linear (hidden_dim , hidden_dim )
4037 self .weight_linear = nn .Linear (hidden_dim , num_relation * 2 )
4138 self .linear = nn .Linear (1 , 1 )
4239
@@ -56,7 +53,7 @@ def get_t_output(self, graph, h_index, r_index):
5653 query = self .query (q_index )
5754
5855 hidden , hx = self .lstm (query )
59- memory = functional .one_hot (h_index , self .num_entity ).unsqueeze (0 )
56+ memory = functional .one_hot (h_index , graph .num_entity ).unsqueeze (0 )
6057
6158 for i in range (self .num_step ):
6259 key = hidden [i ]
0 commit comments