Skip to content

Commit ca856c8

Browse files
authored
Enable max readouts
1 parent a93356e commit ca856c8

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

torchdrug/models/gat.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class GraphAttentionNetwork(nn.Module, core.Configurable):
2525
batch_norm (bool, optional): apply batch normalization or not
2626
activation (str or function, optional): activation function
2727
concat_hidden (bool, optional): concat hidden representations from all layers as output
28-
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
28+
readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
2929
"""
3030

3131
def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False,
@@ -49,6 +49,8 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, nega
4949
self.readout = layers.SumReadout()
5050
elif readout == "mean":
5151
self.readout = layers.MeanReadout()
52+
elif readout == "max":
53+
self.readout = layers.MaxReadout()
5254
else:
5355
raise ValueError("Unknown readout `%s`" % readout)
5456

@@ -85,4 +87,4 @@ def forward(self, graph, input, all_loss=None, metric=None):
8587
return {
8688
"graph_feature": graph_feature,
8789
"node_feature": node_feature
88-
}
90+
}

0 commit comments

Comments
 (0)