@@ -25,7 +25,7 @@ class GraphAttentionNetwork(nn.Module, core.Configurable):
2525 batch_norm (bool, optional): apply batch normalization or not
2626 activation (str or function, optional): activation function
2727 concat_hidden (bool, optional): concat hidden representations from all layers as output
28- readout (str, optional): readout function. Available functions are ``sum`` and ``mean ``.
28+ readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
2929 """
3030
3131 def __init__ (self , input_dim , hidden_dims , edge_input_dim = None , num_head = 1 , negative_slope = 0.2 , short_cut = False ,
@@ -49,6 +49,8 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, nega
4949 self .readout = layers .SumReadout ()
5050 elif readout == "mean" :
5151 self .readout = layers .MeanReadout ()
52+ elif readout == "max" :
53+ self .readout = layers .MaxReadout ()
5254 else :
5355 raise ValueError ("Unknown readout `%s`" % readout )
5456
@@ -85,4 +87,4 @@ def forward(self, graph, input, all_loss=None, metric=None):
8587 return {
8688 "graph_feature" : graph_feature ,
8789 "node_feature" : node_feature
88- }
90+ }
0 commit comments