@@ -134,6 +134,7 @@ def __init__(
134134 super (SelectAdaptivePool2d , self ).__init__ ()
135135 assert input_fmt in ('NCHW' , 'NHWC' )
136136 self .pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
137+ pool_type = pool_type .lower ()
137138 if not pool_type :
138139 self .pool = nn .Identity () # pass through
139140 self .flatten = nn .Flatten (1 ) if flatten else nn .Identity ()
@@ -145,8 +146,10 @@ def __init__(
145146 self .pool = FastAdaptiveAvgMaxPool (flatten , input_fmt = input_fmt )
146147 elif pool_type .endswith ('max' ):
147148 self .pool = FastAdaptiveMaxPool (flatten , input_fmt = input_fmt )
148- else :
149+ elif pool_type == 'fast' or pool_type . endswith ( 'avg' ) :
149150 self .pool = FastAdaptiveAvgPool (flatten , input_fmt = input_fmt )
151+ else :
152+ assert False , 'Invalid pool type: %s' % pool_type
150153 self .flatten = nn .Identity ()
151154 else :
152155 assert input_fmt == 'NCHW'
@@ -156,8 +159,10 @@ def __init__(
156159 self .pool = AdaptiveCatAvgMaxPool2d (output_size )
157160 elif pool_type == 'max' :
158161 self .pool = nn .AdaptiveMaxPool2d (output_size )
159- else :
162+ elif pool_type == 'avg' :
160163 self .pool = nn .AdaptiveAvgPool2d (output_size )
164+ else :
165+ assert False , 'Invalid pool type: %s' % pool_type
161166 self .flatten = nn .Flatten (1 ) if flatten else nn .Identity ()
162167
163168 def is_identity (self ):
0 commit comments