@@ -103,18 +103,33 @@ def sigmoid(x):
103103 assert_approx_equal (eval_loss , loss )
104104
105105
106- @pytest .fixture (scope = 'module' )
107- def conv2d_model ():
108- """Initialize to be tested conv model. Executed once.
106+ @pytest .fixture (scope = 'module' , params = ['channels_first' , 'channels_last' ])
107+ def conv2d_model (request ):
108+ """Initialize to be tested conv model. Executed once per param:
109+ The whole tests are repeated for respectively
110+ `channels_first` and `channels_last`.
109111 """
112+ assert request .param in {'channels_last' , 'channels_first' }
113+ K .set_image_data_format (request .param )
114+
110115 # DATA
111116 in_dim = 20
112117 init_prop = .1
113118 np .random .seed (1 )
114- X = np .random .randn (1 , in_dim , in_dim , 1 )
119+ if K .image_data_format () == 'channels_last' :
120+ X = np .random .randn (1 , in_dim , in_dim , 1 )
121+ elif K .image_data_format () == 'channels_first' :
122+ X = np .random .randn (1 , 1 , in_dim , in_dim )
123+ else :
124+ raise ValueError ('Unknown data_format:' , K .image_data_format ())
115125
116126 # MODEL
117- inputs = Input (shape = (in_dim , in_dim , 1 ,))
127+ if K .image_data_format () == 'channels_last' :
128+ inputs = Input (shape = (in_dim , in_dim , 1 ,))
129+ elif K .image_data_format () == 'channels_first' :
130+ inputs = Input (shape = (1 , in_dim , in_dim ,))
131+ else :
132+ raise ValueError ('Unknown data_format:' , K .image_data_format ())
118133 conv2d = Conv2D (1 , (3 , 3 ))
119134 # Model, normal
120135 cd = ConcreteDropout (conv2d , in_dim , prob_init = (init_prop , init_prop ))
0 commit comments