diff --git a/bnn/model.py b/bnn/model.py index 4e4851d..497aab7 100644 --- a/bnn/model.py +++ b/bnn/model.py @@ -88,7 +88,7 @@ def load_full_epistemic_uncertainty_model(encoder, input_shape, checkpoint, epis def create_bayesian_model(encoder, input_shape, output_classes): - encoder_model = resnet50(encoder, input_shape) + encoder_model = create_encoder_model(encoder, input_shape) input_tensor = Input(shape=encoder_model.output_shape[1:]) x = BatchNormalization(name='post_encoder')(input_tensor) x = Dropout(0.5)(x)