@@ -189,6 +189,46 @@ def test_batch_1d_squeeze_batch_dim(sample_ds_1d, bsize):
189189 assert list (ds_batch ['foo' ].shape ) == [xbsize ]
190190
191191
192+ @pytest .mark .parametrize ('bsize' , [5 , 10 ])
193+ def test_batch_3d_squeeze_batch_dim (sample_ds_3d , bsize ):
194+ xbsize = 20
195+ bg = BatchGenerator (
196+ sample_ds_3d ,
197+ input_dims = {'y' : bsize , 'x' : xbsize },
198+ squeeze_batch_dim = False ,
199+ )
200+ for ds_batch in bg :
201+ assert list (ds_batch ['foo' ].shape ) == [10 , bsize , xbsize ]
202+
203+ bg2 = BatchGenerator (
204+ sample_ds_3d ,
205+ input_dims = {'y' : bsize , 'x' : xbsize },
206+ squeeze_batch_dim = True ,
207+ )
208+ for ds_batch in bg2 :
209+ assert list (ds_batch ['foo' ].shape ) == [10 , bsize , xbsize ]
210+
211+
212+ @pytest .mark .parametrize ('bsize' , [5 , 10 ])
213+ def test_batch_3d_squeeze_batch_dim2 (sample_ds_3d , bsize ):
214+ xbsize = 20
215+ bg = BatchGenerator (
216+ sample_ds_3d ,
217+ input_dims = {'x' : xbsize },
218+ squeeze_batch_dim = False ,
219+ )
220+ for ds_batch in bg :
221+ assert list (ds_batch ['foo' ].shape ) == [500 , xbsize ]
222+
223+ bg2 = BatchGenerator (
224+ sample_ds_3d ,
225+ input_dims = {'x' : xbsize },
226+ squeeze_batch_dim = True ,
227+ )
228+ for ds_batch in bg2 :
229+ assert list (ds_batch ['foo' ].shape ) == [500 , xbsize ]
230+
231+
192232def test_preload_batch_false (sample_ds_1d ):
193233 sample_ds_1d_dask = sample_ds_1d .chunk ({'x' : 2 })
194234 bg = BatchGenerator (
0 commit comments