@@ -1,9 +1,11 @@
 from collections.abc import Sequence

 import torch
+from class_resolver import Hint
 from torch import nn

 from torchdrug import core, layers
+from torchdrug.layers import readout_resolver, Readout
 from torchdrug.core import Registry as R


@@ -23,11 +25,11 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable):
         batch_norm (bool, optional): apply batch normalization or not
         activation (str or function, optional): activation function
         concat_hidden (bool, optional): concat hidden representations from all layers as output
-        readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
+        readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
     """

     def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False,
-                 activation="relu", concat_hidden=False, readout="sum"):
+                 activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
         super(GraphConvolutionalNetwork, self).__init__()

         if not isinstance(hidden_dims, Sequence):
@@ -42,14 +44,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False,
         for i in range(len(self.dims) - 1):
             self.layers.append(layers.GraphConv(self.dims[i], self.dims[i + 1], edge_input_dim, batch_norm, activation))

-        if readout == "sum":
-            self.readout = layers.SumReadout()
-        elif readout == "mean":
-            self.readout = layers.MeanReadout()
-        elif readout == "max":
-            self.readout = layers.MaxReadout()
-        else:
-            raise ValueError("Unknown readout `%s`" % readout)
+        self.readout = readout_resolver.make(readout)

     def forward(self, graph, input, all_loss=None, metric=None):
         """