11from collections .abc import Sequence
22
33import torch
4- from class_resolver import Hint
54from torch import nn
65
76from torchdrug import core , layers
87from torchdrug .core import Registry as R
9- from torchdrug .layers import Readout , readout_resolver
108
119
1210@R .register ("models.GIN" )
@@ -32,8 +30,7 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable):
3230 """
3331
3432 def __init__ (self , input_dim = None , hidden_dims = None , edge_input_dim = None , num_mlp_layer = 2 , eps = 0 , learn_eps = False ,
35- short_cut = False , batch_norm = False , activation = "relu" , concat_hidden = False ,
36- readout : Hint [Readout ] = "sum" ):
33+ short_cut = False , batch_norm = False , activation = "relu" , concat_hidden = False , readout = "sum" ):
3734 super (GraphIsomorphismNetwork , self ).__init__ ()
3835
3936 if not isinstance (hidden_dims , Sequence ):
@@ -50,7 +47,14 @@ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_ml
5047 self .layers .append (layers .GraphIsomorphismConv (self .dims [i ], self .dims [i + 1 ], edge_input_dim ,
5148 layer_hidden_dims , eps , learn_eps , batch_norm , activation ))
5249
53- self .readout = readout_resolver .make (readout )
50+ if readout == "sum" :
51+ self .readout = layers .SumReadout ()
52+ elif readout == "mean" :
53+ self .readout = layers .MeanReadout ()
54+ elif readout == "max" :
55+ self .readout = layers .MaxReadout ()
56+ else :
57+ raise ValueError ("Unknown readout `%s`" % readout )
5458
5559 def forward (self , graph , input , all_loss = None , metric = None ):
5660 """
0 commit comments