11from collections .abc import Sequence
22
33import torch
4- from class_resolver import Hint
54from torch import nn
65
76from torchdrug import core , layers
87from torchdrug .core import Registry as R
9- from torchdrug .layers import Readout , readout_resolver
108
119
1210@R .register ("models.GCN" )
@@ -29,7 +27,7 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable):
2927 """
3028
3129 def __init__ (self , input_dim , hidden_dims , edge_input_dim = None , short_cut = False , batch_norm = False ,
32- activation = "relu" , concat_hidden = False , readout : Hint [ Readout ] = "sum" ):
30+ activation = "relu" , concat_hidden = False , readout = "sum" ):
3331 super (GraphConvolutionalNetwork , self ).__init__ ()
3432
3533 if not isinstance (hidden_dims , Sequence ):
@@ -44,7 +42,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False,
4442 for i in range (len (self .dims ) - 1 ):
4543 self .layers .append (layers .GraphConv (self .dims [i ], self .dims [i + 1 ], edge_input_dim , batch_norm , activation ))
4644
47- self .readout = readout_resolver .make (readout )
45+ if readout == "sum" :
46+ self .readout = layers .SumReadout ()
47+ elif readout == "mean" :
48+ self .readout = layers .MeanReadout ()
49+ elif readout == "max" :
50+ self .readout = layers .MaxReadout ()
51+ else :
52+ raise ValueError ("Unknown readout `%s`" % readout )
4853
4954 def forward (self , graph , input , all_loss = None , metric = None ):
5055 """
@@ -103,7 +108,7 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
103108 """
104109
105110 def __init__ (self , input_dim , hidden_dims , num_relation , edge_input_dim = None , short_cut = False , batch_norm = False ,
106- activation = "relu" , concat_hidden = False , readout : Hint [ Readout ] = "sum" ):
111+ activation = "relu" , concat_hidden = False , readout = "sum" ):
107112 super (RelationalGraphConvolutionalNetwork , self ).__init__ ()
108113
109114 if not isinstance (hidden_dims , Sequence ):
@@ -120,7 +125,14 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh
120125 self .layers .append (layers .RelationalGraphConv (self .dims [i ], self .dims [i + 1 ], num_relation , edge_input_dim ,
121126 batch_norm , activation ))
122127
123- self .readout = readout_resolver .make (readout )
128+ if readout == "sum" :
129+ self .readout = layers .SumReadout ()
130+ elif readout == "mean" :
131+ self .readout = layers .MeanReadout ()
132+ elif readout == "max" :
133+ self .readout = layers .MaxReadout ()
134+ else :
135+ raise ValueError ("Unknown readout `%s`" % readout )
124136
125137 def forward (self , graph , input , all_loss = None , metric = None ):
126138 """
0 commit comments