diff --git a/changes/3547.feature.md b/changes/3547.feature.md new file mode 100644 index 0000000000..d933865bd2 --- /dev/null +++ b/changes/3547.feature.md @@ -0,0 +1 @@ +Add a `array.max_bytes_per_shard_for_auto_sharding` to [`zarr.config`][] to allow users to set a maximum number of bytes per-shard when `shards="auto"` in, for example, [`zarr.create_array`][]. \ No newline at end of file diff --git a/docs/user-guide/performance.md b/docs/user-guide/performance.md index 88d8e69936..6550064bc1 100644 --- a/docs/user-guide/performance.md +++ b/docs/user-guide/performance.md @@ -81,6 +81,8 @@ z6 = zarr.create_array(store={}, shape=(10000, 10000, 1000), shards=(1000, 1000, print(z6.info) ``` +`shards` can be `"auto"` as well, in which case the `array.max_bytes_per_shard_for_auto_sharding` setting can be used to control the size of shards; otherwise, a default is used. + ### Chunk memory layout The order of bytes **within each chunk** of an array can be changed via the diff --git a/src/zarr/core/chunk_grids.py b/src/zarr/core/chunk_grids.py index cf5bd8cdbe..ed9fb5755b 100644 --- a/src/zarr/core/chunk_grids.py +++ b/src/zarr/core/chunk_grids.py @@ -12,6 +12,7 @@ import numpy as np +import zarr from zarr.abc.metadata import Metadata from zarr.core.common import ( JSON, @@ -202,6 +203,22 @@ def get_nchunks(self, array_shape: tuple[int, ...]) -> int: ) +def _guess_num_chunks_per_axis_shard( + chunk_shape: tuple[int, ...], item_size: int, max_bytes: int, array_shape: tuple[int, ...] 
+) -> int: + bytes_per_chunk = np.prod(chunk_shape) * item_size + if max_bytes < bytes_per_chunk: + return 1 + num_axes = len(chunk_shape) + chunks_per_shard = 1 + # First check for byte size, second check to make sure we don't go bigger than the array shape + while (bytes_per_chunk * ((chunks_per_shard + 1) ** num_axes)) <= max_bytes and all( + c * (chunks_per_shard + 1) <= a for c, a in zip(chunk_shape, array_shape, strict=True) + ): + chunks_per_shard += 1 + return chunks_per_shard + + def _auto_partition( *, array_shape: tuple[int, ...], @@ -237,12 +254,24 @@ def _auto_partition( stacklevel=2, ) _shards_out = () + max_bytes_per_shard_for_auto_sharding = zarr.config.get( + "array.max_bytes_per_shard_for_auto_sharding", None + ) + num_chunks_per_shard_axis = ( + _guess_num_chunks_per_axis_shard( + chunk_shape=_chunks_out, + item_size=item_size, + max_bytes=max_bytes_per_shard_for_auto_sharding, + array_shape=array_shape, + ) + if (has_auto_shard := (max_bytes_per_shard_for_auto_sharding is not None)) + else 2 + ) for a_shape, c_shape in zip(array_shape, _chunks_out, strict=True): - # TODO: make a better heuristic than this. - # for each axis, if there are more than 8 chunks along that axis, then put - # 2 chunks in each shard for that axis. - if a_shape // c_shape > 8: - _shards_out += (c_shape * 2,) + # The previous heuristic was `a_shape // c_shape > 8` and now, with max_bytes_per_shard_for_auto_sharding, we only check that the shard size is less than the array size. 
@pytest.mark.parametrize(
    ("array_shape", "chunk_shape", "max_bytes_per_shard_for_auto_sharding", "expected_shards"),
    [
        pytest.param(
            (256, 256),
            (32, 32),
            129 * 129,
            (128, 128),
            id="2d_chunking_max_bytes_does_not_evenly_divide",
        ),
        pytest.param(
            (256, 256), (32, 32), 64 * 64, (64, 64), id="2d_chunking_max_bytes_evenly_divides"
        ),
        pytest.param(
            (256, 256),
            (64, 32),
            128 * 128,
            (128, 64),
            id="2d_non_square_chunking_max_bytes_evenly_divides",
        ),
        pytest.param((256,), (2,), 255, (254,), id="max_bytes_just_below_array_shape"),
        pytest.param((256,), (2,), 256, (256,), id="max_bytes_equal_to_array_shape"),
        pytest.param((256,), (2,), 16, (16,), id="max_bytes_normal_val"),
        pytest.param((256,), (2,), 2, (2,), id="max_bytes_same_as_chunk"),
        pytest.param((256,), (2,), 1, (2,), id="max_bytes_less_than_chunk"),
        pytest.param((256,), (2,), None, (4,), id="use_default_auto_setting"),
        pytest.param((4,), (2,), None, (2,), id="small_array_shape_does_not_shard"),
    ],
)
def test_auto_partition_auto_shards(
    array_shape: tuple[int, ...],
    chunk_shape: tuple[int, ...],
    max_bytes_per_shard_for_auto_sharding: int | None,
    expected_shards: tuple[int, ...],
) -> None:
    """
    Test that ``shard_shape="auto"`` yields the expected shard shape.

    Each case sets the ``array.max_bytes_per_shard_for_auto_sharding`` config value
    and compares the shard shape returned by ``_auto_partition`` with
    ``expected_shards``. The dtype is ``uint8`` (one byte per element), so a byte
    budget equals an element count. Cases with ``None`` cover the fallback used
    when no budget is configured.
    """
    dtype = np.dtype("uint8")
    with pytest.warns(
        ZarrUserWarning,
        match="Automatic shard shape inference is experimental and may change without notice.",
    ):
        # Scope the config override to this call so other tests see the default value.
        with zarr.config.set(
            {"array.max_bytes_per_shard_for_auto_sharding": max_bytes_per_shard_for_auto_sharding}
        ):
            auto_shards, _ = _auto_partition(
                array_shape=array_shape,
                chunk_shape=chunk_shape,
                shard_shape="auto",
                item_size=dtype.itemsize,
            )
    assert auto_shards == expected_shards