@@ -35,15 +35,20 @@ def get_encoder_names():
3535 return list (encoders .keys ())
3636
3737
38- def get_preprocessing_fn (encoder_name , pretrained = 'imagenet' ):
38+ def get_preprocessing_params (encoder_name , pretrained = 'imagenet' ):
3939 settings = encoders [encoder_name ]['pretrained_settings' ]
4040
4141 if pretrained not in settings .keys ():
4242 raise ValueError ('Avaliable pretrained options {}' .format (settings .keys ()))
43-
44- input_space = settings [pretrained ].get ('input_space' )
45- input_range = settings [pretrained ].get ('input_range' )
46- mean = settings [pretrained ].get ('mean' )
47- std = settings [pretrained ].get ('std' )
4843
49- return functools .partial (preprocess_input , mean = mean , std = std , input_space = input_space , input_range = input_range )
44+ formatted_settings = {}
45+ formatted_settings ['input_space' ] = settings [pretrained ].get ('input_space' )
46+ formatted_settings ['input_range' ] = settings [pretrained ].get ('input_range' )
47+ formatted_settings ['mean' ] = settings [pretrained ].get ('mean' )
48+ formatted_settings ['std' ] = settings [pretrained ].get ('std' )
49+ return formatted_settings
50+
51+
52+ def get_preprocessing_fn (encoder_name , pretrained = 'imagenet' ):
53+ params = get_preprocessing_params (encoder_name , pretrained = pretrained )
54+ return functools .partial (preprocess_input , ** params )
0 commit comments