diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index 2616ab3..8f1055f 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -89,6 +89,16 @@ def __init__(
         self._all_sliced_dims: dict[Hashable, int] = dict(
             **self._unique_batch_dims, **self.input_dims
         )
+
+        # Check that duplicate dims imply whole patches per batch
+        for dim, length in self.batch_dims.items():
+            input_dim_length = self.input_dims.get(dim)
+            if input_dim_length is not None and length % input_dim_length != 0:
+                raise ValueError(
+                    f'Input and batch dimension sizes imply partial batches '
+                    f'on dimension {dim}. Input size: {input_dim_length}; Batch size: {length}'
+                )
+
         self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds)
 
     def _gen_batch_selectors(self, ds: xr.DataArray | xr.Dataset) -> BatchSelectorSet:
diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py
index 166648c..2bf4124 100644
--- a/xbatcher/tests/test_generators.py
+++ b/xbatcher/tests/test_generators.py
@@ -265,6 +265,20 @@ def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d):
         validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
 
 
+def test_batch_3d_uneven_batch_input_dim(sample_ds_3d):
+    """
+    Test for error when a batch dimension is not a multiple of the
+    corresponding input dimension.
+    """
+    with pytest.raises(ValueError, match='imply partial batches'):
+        _ = BatchGenerator(
+            sample_ds_3d,
+            input_dims={'x': 5, 'y': 10},
+            batch_dims={'x': 11, 'y': 21},
+            concat_input_dims=True,
+        )
+
+
 @pytest.mark.parametrize('input_size', [5, 10])
 def test_batch_3d_2d_input(sample_ds_3d, input_size):
     """