@@ -23,7 +23,7 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable):
2323 batch_norm (bool, optional): apply batch normalization or not
2424 activation (str or function, optional): activation function
2525 concat_hidden (bool, optional): concat hidden representations from all layers as output
26- readout (str, optional): readout function. Available functions are ``sum`` and ``mean ``.
26+ readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
2727 """
2828
2929 def __init__ (self , input_dim , hidden_dims , edge_input_dim = None , short_cut = False , batch_norm = False ,
@@ -46,6 +46,8 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False,
4646 self .readout = layers .SumReadout ()
4747 elif readout == "mean" :
4848 self .readout = layers .MeanReadout ()
49+ elif readout == "max" :
50+ self .readout = layers .MaxReadout ()
4951 else :
5052 raise ValueError ("Unknown readout `%s`" % readout )
5153
@@ -102,7 +104,7 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
102104 batch_norm (bool, optional): apply batch normalization or not
103105 activation (str or function, optional): activation function
104106 concat_hidden (bool, optional): concat hidden representations from all layers as output
105- readout (str, optional): readout function. Available functions are ``sum`` and ``mean ``.
107+ readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
106108 """
107109
108110 def __init__ (self , input_dim , hidden_dims , num_relation , edge_input_dim = None , short_cut = False , batch_norm = False ,
@@ -127,6 +129,8 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh
127129 self .readout = layers .SumReadout ()
128130 elif readout == "mean" :
129131 self .readout = layers .MeanReadout ()
132+ elif readout == "max" :
133+ self .readout = layers .MaxReadout ()
130134 else :
131135 raise ValueError ("Unknown readout `%s`" % readout )
132136
@@ -165,4 +169,4 @@ def forward(self, graph, input, all_loss=None, metric=None):
165169 return {
166170 "graph_feature" : graph_feature ,
167171 "node_feature" : node_feature
168- }
172+ }
0 commit comments