Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3547.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add an `array.max_bytes_per_shard_for_auto_sharding` setting to [`zarr.config`][] that lets users cap the number of bytes per shard when `shards="auto"` is passed to array-creation functions such as [`zarr.create_array`][].
2 changes: 2 additions & 0 deletions docs/user-guide/performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ z6 = zarr.create_array(store={}, shape=(10000, 10000, 1000), shards=(1000, 1000,
print(z6.info)
```

`shards` can also be `"auto"`, in which case the `array.max_bytes_per_shard_for_auto_sharding` setting controls the maximum shard size in bytes; if the setting is unset, a default sharding heuristic is used instead.

### Chunk memory layout

The order of bytes **within each chunk** of an array can be changed via the
Expand Down
39 changes: 34 additions & 5 deletions src/zarr/core/chunk_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import numpy as np

import zarr
from zarr.abc.metadata import Metadata
from zarr.core.common import (
JSON,
Expand Down Expand Up @@ -202,6 +203,22 @@ def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
)


def _guess_num_chunks_per_axis_shard(
chunk_shape: tuple[int, ...], item_size: int, max_bytes: int, array_shape: tuple[int, ...]
) -> int:
bytes_per_chunk = np.prod(chunk_shape) * item_size
if max_bytes < bytes_per_chunk:
return 1
num_axes = len(chunk_shape)
chunks_per_shard = 1
# First check for byte size, second check to make sure we don't go bigger than the array shape
while (bytes_per_chunk * ((chunks_per_shard + 1) ** num_axes)) <= max_bytes and all(
c * (chunks_per_shard + 1) <= a for c, a in zip(chunk_shape, array_shape, strict=True)
):
chunks_per_shard += 1
return chunks_per_shard


def _auto_partition(
*,
array_shape: tuple[int, ...],
Expand Down Expand Up @@ -237,12 +254,24 @@ def _auto_partition(
stacklevel=2,
)
_shards_out = ()
max_bytes_per_shard_for_auto_sharding = zarr.config.get(
"array.max_bytes_per_shard_for_auto_sharding", None
)
num_chunks_per_shard_axis = (
_guess_num_chunks_per_axis_shard(
chunk_shape=_chunks_out,
item_size=item_size,
max_bytes=max_bytes_per_shard_for_auto_sharding,
array_shape=array_shape,
)
if (has_auto_shard := (max_bytes_per_shard_for_auto_sharding is not None))
else 2
)
for a_shape, c_shape in zip(array_shape, _chunks_out, strict=True):
# TODO: make a better heuristic than this.
# for each axis, if there are more than 8 chunks along that axis, then put
# 2 chunks in each shard for that axis.
if a_shape // c_shape > 8:
_shards_out += (c_shape * 2,)
# The previous heuristic was `a_shape // c_shape > 8` and now, with max_bytes_per_shard_for_auto_sharding, we only check that the shard size is less than the array size.
can_shard_axis = a_shape // c_shape > 8 if not has_auto_shard else True
if can_shard_axis:
_shards_out += (c_shape * num_chunks_per_shard_axis,)
else:
_shards_out += (c_shape,)
elif isinstance(shard_shape, dict):
Expand Down
1 change: 1 addition & 0 deletions src/zarr/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def enable_gpu(self) -> ConfigSet:
"array": {
"order": "C",
"write_empty_chunks": False,
"max_bytes_per_shard_for_auto_sharding": None,
},
"async": {"concurrency": 64, "timeout": None},
"threading": {"max_workers": None},
Expand Down
55 changes: 40 additions & 15 deletions tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,33 +966,58 @@ async def test_nbytes(


@pytest.mark.parametrize(
("array_shape", "chunk_shape"),
[((256,), (2,))],
("array_shape", "chunk_shape", "max_bytes_per_shard_for_auto_sharding", "expected_shards"),
[
pytest.param(
(256, 256),
(32, 32),
129 * 129,
(128, 128),
id="2d_chunking_max_byes_does_not_evenly_divide",
),
pytest.param(
(256, 256), (32, 32), 64 * 64, (64, 64), id="2d_chunking_max_byes_evenly_divides"
),
pytest.param(
(256, 256),
(64, 32),
128 * 128,
(128, 64),
id="2d_non_square_chunking_max_byes_evenly_divides",
),
pytest.param((256,), (2,), 255, (254,), id="max_bytes_just_below_array_shape"),
pytest.param((256,), (2,), 256, (256,), id="max_bytes_equal_to_array_shape"),
pytest.param((256,), (2,), 16, (16,), id="max_bytes_normal_val"),
pytest.param((256,), (2,), 2, (2,), id="max_bytes_same_as_chunk"),
pytest.param((256,), (2,), 1, (2,), id="max_bytes_less_than_chunk"),
pytest.param((256,), (2,), None, (4,), id="use_default_auto_setting"),
pytest.param((4,), (2,), None, (2,), id="small_array_shape_does_not_shard"),
],
)
def test_auto_partition_auto_shards(
array_shape: tuple[int, ...], chunk_shape: tuple[int, ...]
array_shape: tuple[int, ...],
chunk_shape: tuple[int, ...],
max_bytes_per_shard_for_auto_sharding: int | None,
expected_shards: tuple[int, ...],
) -> None:
"""
Test that automatically picking a shard size returns a tuple of 2 * the chunk shape for any axis
where there are 8 or more chunks.
"""
dtype = np.dtype("uint8")
expected_shards: tuple[int, ...] = ()
for cs, a_len in zip(chunk_shape, array_shape, strict=False):
if a_len // cs >= 8:
expected_shards += (2 * cs,)
else:
expected_shards += (cs,)
with pytest.warns(
ZarrUserWarning,
match="Automatic shard shape inference is experimental and may change without notice.",
):
auto_shards, _ = _auto_partition(
array_shape=array_shape,
chunk_shape=chunk_shape,
shard_shape="auto",
item_size=dtype.itemsize,
)
with zarr.config.set(
{"array.max_bytes_per_shard_for_auto_sharding": max_bytes_per_shard_for_auto_sharding}
):
auto_shards, _ = _auto_partition(
array_shape=array_shape,
chunk_shape=chunk_shape,
shard_shape="auto",
item_size=dtype.itemsize,
)
assert auto_shards == expected_shards


Expand Down
1 change: 1 addition & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_config_defaults_set() -> None:
"array": {
"order": "C",
"write_empty_chunks": False,
"max_bytes_per_shard_for_auto_sharding": None,
},
"async": {"concurrency": 64, "timeout": None},
"threading": {"max_workers": None},
Expand Down