Skip to content

Commit 2c5d0f7

Browse files
authored
Update gcn.py
1 parent ca856c8 commit 2c5d0f7

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

torchdrug/models/gcn.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable):
2323
batch_norm (bool, optional): apply batch normalization or not
2424
activation (str or function, optional): activation function
2525
concat_hidden (bool, optional): concat hidden representations from all layers as output
26-
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
26+
readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
2727
"""
2828

2929
def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False,
@@ -46,6 +46,8 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False,
4646
self.readout = layers.SumReadout()
4747
elif readout == "mean":
4848
self.readout = layers.MeanReadout()
49+
elif readout == "max":
50+
self.readout = layers.MaxReadout()
4951
else:
5052
raise ValueError("Unknown readout `%s`" % readout)
5153

@@ -102,7 +104,7 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
102104
batch_norm (bool, optional): apply batch normalization or not
103105
activation (str or function, optional): activation function
104106
concat_hidden (bool, optional): concat hidden representations from all layers as output
105-
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
107+
readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
106108
"""
107109

108110
def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False,
@@ -127,6 +129,8 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh
127129
self.readout = layers.SumReadout()
128130
elif readout == "mean":
129131
self.readout = layers.MeanReadout()
132+
elif readout == "max":
133+
self.readout = layers.MaxReadout()
130134
else:
131135
raise ValueError("Unknown readout `%s`" % readout)
132136

@@ -165,4 +169,4 @@ def forward(self, graph, input, all_loss=None, metric=None):
165169
return {
166170
"graph_feature": graph_feature,
167171
"node_feature": node_feature
168-
}
172+
}

0 commit comments

Comments
 (0)