55from torch import nn
66
77from torchdrug import core , layers
8- from torchdrug .layers import readout_resolver , Readout
98from torchdrug .core import Registry as R
9+ from torchdrug .layers import Readout , readout_resolver
1010
1111
1212@R .register ("models.GCN" )
@@ -99,11 +99,11 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
9999 batch_norm (bool, optional): apply batch normalization or not
100100 activation (str or function, optional): activation function
101101 concat_hidden (bool, optional): concat hidden representations from all layers as output
102- readout (str, optional) : readout function. Available functions are ``sum``, ``mean``, and ``max``.
102+ readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
103103 """
104104
105105 def __init__ (self , input_dim , hidden_dims , num_relation , edge_input_dim = None , short_cut = False , batch_norm = False ,
106- activation = "relu" , concat_hidden = False , readout = "sum" ):
106+ activation = "relu" , concat_hidden = False , readout : Hint [ Readout ] = "sum" ):
107107 super (RelationalGraphConvolutionalNetwork , self ).__init__ ()
108108
109109 if not isinstance (hidden_dims , Sequence ):
@@ -120,14 +120,7 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh
120120 self .layers .append (layers .RelationalGraphConv (self .dims [i ], self .dims [i + 1 ], num_relation , edge_input_dim ,
121121 batch_norm , activation ))
122122
123- if readout == "sum" :
124- self .readout = layers .SumReadout ()
125- elif readout == "mean" :
126- self .readout = layers .MeanReadout ()
127- elif readout == "max" :
128- self .readout = layers .MaxReadout ()
129- else :
130- raise ValueError ("Unknown readout `%s`" % readout )
123+ self .readout = readout_resolver .make (readout )
131124
132125 def forward (self , graph , input , all_loss = None , metric = None ):
133126 """
0 commit comments