Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions xbatcher/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,16 @@ def __init__(
self._all_sliced_dims: dict[Hashable, int] = dict(
**self._unique_batch_dims, **self.input_dims
)

# Check that duplicate dims imply whole patches per batch
for dim, length in self.batch_dims.items():
input_dim_length = self.input_dims.get(dim)
if input_dim_length is not None and length % input_dim_length != 0:
raise ValueError(
f'Input and batch dimension sizes imply partial batches '
f'on dimension {dim}. Input size: {input_dim_length}; Batch size: {length}'
)

self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds)

def _gen_batch_selectors(self, ds: xr.DataArray | xr.Dataset) -> BatchSelectorSet:
Expand Down
14 changes: 14 additions & 0 deletions xbatcher/tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,20 @@ def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d):
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)


def test_batch_3d_uneven_batch_input_dim(sample_ds_3d):
"""
Test for error when a batch dimension is not a multiple of the
corresponding input dimension.
"""
with pytest.raises(ValueError, match='imply partial batches'):
_ = BatchGenerator(
sample_ds_3d,
input_dims={'x': 5, 'y': 10},
batch_dims={'x': 11, 'y': 21},
concat_input_dims=True,
)


@pytest.mark.parametrize('input_size', [5, 10])
def test_batch_3d_2d_input(sample_ds_3d, input_size):
"""
Expand Down
Loading