Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
650088b
Add topk
dcherian Jul 27, 2024
889be0c
Negative k
dcherian Jul 28, 2024
996ff2a
dask support
dcherian Jul 28, 2024
776d233
test
dcherian Jul 28, 2024
a5eb7b9
wip
dcherian Jul 28, 2024
4fa9a4c
fix
dcherian Jul 28, 2024
4b04fde
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2024
93800aa
Handle dtypes.NA properly for datetime/timedelta
dcherian Jul 31, 2024
80c67f4
Fix
dcherian Jul 31, 2024
7056d18
Merge branch 'main' into topk
dcherian Aug 7, 2024
44f5f3f
Merge branch 'main' into topk
dcherian Jan 7, 2025
c924017
Fixes
dcherian Jan 7, 2025
7a794ba
one more fix
dcherian Jan 7, 2025
eec4dd4
fix
dcherian Jan 7, 2025
6ac9a1f
one more fix
dcherian Jan 7, 2025
83594e8
Fixes.
dcherian Jan 7, 2025
740f85f
WIP
dcherian Jan 7, 2025
5d64fd9
Merge branch 'main' into topk
dcherian Jan 7, 2025
e177efd
fixes
dcherian Jan 7, 2025
9393470
fix
dcherian Jan 7, 2025
17eb915
cleanup
dcherian Jan 7, 2025
dc0df3e
works?
dcherian Jan 7, 2025
83ae5d8
fix quantile
dcherian Jan 7, 2025
95d20b8
optimize xrutils.topk
dcherian Jan 7, 2025
0b9fafc
Merge branch 'main' into topk
dcherian Jan 8, 2025
caa98b8
Update tests/test_properties.py
dcherian Jan 8, 2025
820d46c
generalize new_dims_func
dcherian Jan 13, 2025
17a4d5d
Merge branch 'main' into topk
dcherian Jan 13, 2025
6aa923a
Revert "generalize new_dims_func"
dcherian Jan 13, 2025
16b0bac
Merge branch 'main' into topk
dcherian Jan 13, 2025
2c6d486
Support bool
dcherian Jan 13, 2025
0dcd87c
more skipping
dcherian Jan 13, 2025
9b874ea
fix
dcherian Jan 14, 2025
adebbec
more xfail
dcherian Jan 15, 2025
ace2af5
Merge branch 'main' into topk
dcherian Jan 19, 2025
4f35230
cleanup
dcherian Jan 19, 2025
cd2f150
one more xfail
dcherian Jan 19, 2025
70e6f22
typing
dcherian Jan 19, 2025
5d45603
minor docs
dcherian Jan 19, 2025
096f6b9
disable log in CI
dcherian Jan 19, 2025
0277cb9
Fix boolean
dcherian Jan 19, 2025
6c7e84a
bool -> bool_
dcherian Jan 20, 2025
43c3408
update int limits
dcherian Jan 20, 2025
01eabfb
fix rtd
dcherian Jan 20, 2025
6e4ce69
Add note
dcherian Jan 20, 2025
4500c7e
Merge branch 'main' into topk
dcherian Jan 24, 2025
8f60477
Add unit test
dcherian Jan 24, 2025
15fcfa1
WIP
dcherian Jan 24, 2025
a5bcc5b
fix
dcherian Jan 24, 2025
489c843
Merge branch 'main' into topk
dcherian Mar 18, 2025
91e1d07
Switch DUMMY_AXIS to 0
dcherian Mar 18, 2025
2d868fe
More support for edge cases
dcherian Mar 18, 2025
d244d60
minor
dcherian Mar 18, 2025
8319f7f
[WIP] failing test
dcherian Mar 18, 2025
d21eec5
Merge branch 'main' into topk
dcherian Jul 16, 2025
dfb1e88
fix expected
dcherian Mar 26, 2025
8b31f5d
Revert "[WIP] failing test"
dcherian Mar 26, 2025
fce4f2b
[revert] failing test
dcherian Mar 26, 2025
0f7ee05
fix
dcherian Jul 16, 2025
4c3e6d3
Fix topk extraction for groups with fewer than k elements
dcherian Nov 30, 2025
902e60e
Add nantopk function for NaN-aware topk combine
dcherian Nov 30, 2025
6e7a035
Configure topk aggregation for dask map-reduce
dcherian Nov 30, 2025
e2c7e42
Force simple_combine for aggregations with new_dims_func
dcherian Nov 30, 2025
0893ba0
Raise NotImplementedError for topk with reindex=False
dcherian Nov 30, 2025
31591ba
Add comprehensive unit tests for topk with NaN handling
dcherian Nov 30, 2025
04740fa
Remove temporary and auto-generated files from git tracking
dcherian Nov 30, 2025
f61c856
Remove docs/oisst.ipynb from git tracking
dcherian Nov 30, 2025
a8f07a5
Merge main into topk branch
dcherian Nov 30, 2025
37d2cc7
Fix mypy error in topk function
dcherian Nov 30, 2025
cbf9596
move _DUMMY_AXIS to dask.py
dcherian Dec 1, 2025
bacda97
Avoid squeezing
dcherian Dec 1, 2025
8d1606d
Properly handle empty input
dcherian Dec 1, 2025
35a8a22
Fix combining
dcherian Dec 1, 2025
2ca52b1
Fix axis parameter for single-element tuples in _simple_combine
dcherian Dec 1, 2025
0193434
Fix topk tokenization to include finalize_kwargs
dcherian Dec 1, 2025
15aa174
FIx test
dcherian Dec 1, 2025
f82c3bd
Fix test_groupby_reduce_all for topk with chunked arrays
dcherian Dec 1, 2025
6716e63
Fix _var_combine to handle integer axis parameter
dcherian Dec 1, 2025
a863f5b
revert a change
dcherian Dec 1, 2025
b38b606
Fix must_use_simple_combine logic for argmax/argmin
dcherian Dec 1, 2025
17b179d
Update CLAUDE.md with implementation details learned
dcherian Dec 1, 2025
ac1d9e1
Revert test_groupby_reduce_all parameterize to match main
dcherian Dec 1, 2025
73ba70c
tweak
dcherian Dec 1, 2025
16bdec3
Remove Claude settings files from git tracking
dcherian Dec 1, 2025
5785714
Revert "tweak"
dcherian Dec 1, 2025
dd1c03a
Fix must_use_simple_combine logic to use num_new_vector_dims
dcherian Dec 1, 2025
de431d2
Fix xarray topk support and add tests
dcherian Dec 1, 2025
c640f2e
Move topk tests to test_xarray.py
dcherian Dec 1, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ jobs:
id: status
run: |
uv run --no-dev python -c "import xarray; xarray.show_versions()" || true
uv run --no-dev pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci
uv run --no-dev pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci --log-disable=flox
- name: Upload code coverage to Codecov
uses: codecov/codecov-action@v5.5.1
with:
Expand Down
19 changes: 19 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,22 @@ venv.bak/

# Git worktrees
worktrees/

# Auto-generated version file
flox/_version.py

# Temporary files
Untitled.ipynb
*.rej
*.py.rej
mutmut-cache
.mutmut-cache
mydask.png
profile.json
profile.html
test.png
uv.lock
devel/

# Claude Code
.claude/
67 changes: 67 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,70 @@ asv preview
- Integration testing with xarray upstream development branch
- **Python Support**: Minimum version 3.11 (updated from 3.10)
- **Git Worktrees**: `worktrees/` directory is ignored for development workflows
- **Running Tests**: Always use `uv run pytest` to run tests (not just `pytest`)

## Key Implementation Details

### Map-Reduce Combine Strategies (`flox/dask.py`)

There are two strategies for combining intermediate results in dask's tree reduction:

1. **`_simple_combine`**: Used for most reductions. Tree-reduces the reduction itself (not the groupby-reduction) for performance. Requirements:

- All blocks must contain all groups after blockwise step (reindex.blockwise=True)
- Must know expected_groups
- Inserts DUMMY_AXIS=-2 via `_expand_dims`, reduces along it, then squeezes it out
- Used when: not an arg reduction, not first/last with non-float dtype, and labels are known

1. **`_grouped_combine`**: More general solution that tree-reduces the groupby-reduction itself. Used for:

- Arg reductions (argmax, argmin, etc.)
- When labels are unknown (dask arrays without expected_groups)
- First/last reductions with non-float dtypes

### Aggregations with New Dimensions

Some aggregations add new dimensions to the output (e.g., topk, quantile):

- **`new_dims_func`**: Function that returns tuple of Dim objects for new dimensions
- These MUST use `_simple_combine` because intermediate results have an extra dimension that needs to be reduced along DUMMY_AXIS
- Check if `new_dims_func(**finalize_kwargs)` returns non-empty tuple to determine if aggregation actually adds dimensions
- **Note**: argmax/argmin have `new_dims_func` but return empty tuple, so they use `_grouped_combine`

### topk Implementation

The topk aggregation is special:

- Uses `_simple_combine` (has non-empty new_dims_func)
- First intermediate (topk values) combines along axis 0, not DUMMY_AXIS
- Does NOT squeeze out DUMMY_AXIS in final aggregate step
- `_expand_dims` only expands non-topk intermediates (the second one, nanlen)

### Axis Parameter Handling

- **`_simple_combine`**: Always receives axis as tuple (e.g., `(-2,)` for DUMMY_AXIS)
- **numpy functions**: Most accept both tuple and integer axis (e.g., np.max, np.sum)
- **Exception**: argmax/argmin don't accept tuple axis, but these use `_grouped_combine`
- **Custom functions**: Like `_var_combine` should normalize axis to tuple if needed for iteration

### Test Organization

- **`test_groupby_reduce_all`**: Comprehensive test for all aggregations with various parameters (nby, chunks, etc.)

- Tests both with and without NaN handling
- For topk: sorts results along axis 0 before comparison (k dimension is at axis 0)
- Uses `np.moveaxis` not `np.swapaxes` for topk to avoid swapping other dimensions

- **`test_groupby_reduce_axis_subset_against_numpy`**: Tests reductions over subsets of axes

- Compares dask results against numpy results
- Tests various axis combinations: None, single int, tuples
- Skip arg reductions with axis=None or multiple axes (not supported)

### Common Pitfalls

1. **Axis transformations for topk**: Use `np.moveaxis(expected, src, 0)` not `np.swapaxes(expected, src, 0)` to move k dimension to position 0 without reordering other dimensions

1. **new_dims_func checking**: Check if it returns non-empty dimensions, not just if it exists (argmax has one that returns `()`)

1. **Axis parameter types**: Custom combine functions should handle both tuple and integer axis by normalizing at the start
1 change: 1 addition & 0 deletions devel
15 changes: 8 additions & 7 deletions docs/source/aggregations.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,16 @@ the `func` kwarg:
- `"mean"`, `"nanmean"`
- `"var"`, `"nanvar"`
- `"std"`, `"nanstd"`
- `"argmin"`
- `"argmax"`
- `"argmin"`, `"nanargmin"`
- `"argmax"`, `"nanargmax"`
- `"first"`, `"nanfirst"`
- `"last"`, `"nanlast"`
- `"median"`, `"nanmedian"`
- `"mode"`, `"nanmode"`
- `"quantile"`, `"nanquantile"`
- `"topk"`

```{tip}
We would like to add support for `cumsum`, `cumprod` ([issue](https://github.com/xarray-contrib/flox/issues/91)). Contributions are welcome!
```

## Custom Aggregations
## Custom Reductions

`flox` also allows you to specify a custom Aggregation (again inspired by dask.dataframe),
though this might not be fully functional at the moment. See `aggregations.py` for examples.
Expand All @@ -46,3 +43,7 @@ mean = Aggregation(
final_fill_value=np.nan,
)
```

## Custom Scans

Coming soon!
153 changes: 107 additions & 46 deletions flox/aggregate_flox.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,32 @@ def _lerp(a, b, *, t, dtype, out=None):
return out


def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None):
def quantile_or_topk(
array,
inv_idx,
*,
q=None,
k=None,
axis,
skipna,
group_idx,
dtype=None,
out=None,
fill_value=None,
):
assert q is not None or k is not None
assert axis == -1

inv_idx = np.concatenate((inv_idx, [array.shape[-1]]))

array_validmask = notnull(array)
actual_sizes = np.add.reduceat(array_validmask, inv_idx[:-1], axis=axis)
newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,)
full_sizes = np.reshape(np.diff(inv_idx), newshape)
nanmask = full_sizes != actual_sizes
if k is not None:
nanmask = actual_sizes < abs(k)
else:
full_sizes = np.reshape(np.diff(inv_idx), newshape)
nanmask = full_sizes != actual_sizes

# The approach here is to use (complex_array.partition) because
# 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary
Expand All @@ -72,36 +90,63 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
# So we determine which indices we need using the fact that NaNs get sorted to the end.
# This *was* partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/
# but not any more now that I use partition and avoid replacing NaNs
qin = q
q = np.atleast_1d(qin)
q = np.reshape(q, (len(q),) + (1,) * array.ndim)
if k is not None:
is_scalar_param = False
param = np.sort(np.arange(abs(k)) * np.sign(k))
else:
is_scalar_param = is_scalar(q)
param = np.atleast_1d(q)
param = np.reshape(param, (param.size,) + (1,) * array.ndim)

# This is numpy's method="linear"
# TODO: could support all the interpolations here
offset = actual_sizes.cumsum(axis=-1)
actual_sizes -= 1
virtual_index = q * actual_sizes
# virtual_index is relative to group starts, so now offset that
virtual_index[..., 1:] += offset[..., :-1]

is_scalar_q = is_scalar(qin)
if is_scalar_q:
virtual_index = virtual_index.squeeze(axis=0)
idxshape = array.shape[:-1] + (actual_sizes.shape[-1],)
else:
idxshape = (q.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)
# For topk(.., k=+1 or -1), we always return the singleton dimension.
idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)

lo_ = np.floor(
virtual_index,
casting="unsafe",
out=np.empty(virtual_index.shape, dtype=np.int64),
)
hi_ = np.ceil(
virtual_index,
casting="unsafe",
out=np.empty(virtual_index.shape, dtype=np.int64),
)
kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))
if q is not None:
# This is numpy's method="linear"
# TODO: could support all the interpolations here
actual_sizes -= 1
virtual_index = param * actual_sizes
# virtual_index is relative to group starts, so now offset that
virtual_index[..., 1:] += offset[..., :-1]

if is_scalar_param:
virtual_index = virtual_index.squeeze(axis=0)
idxshape = array.shape[:-1] + (actual_sizes.shape[-1],)

lo_ = np.floor(virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64))
hi_ = np.ceil(virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64))
kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))

else:
virtual_index = (actual_sizes - k) if k > 0 else (np.zeros_like(actual_sizes) + abs(k) - 1)
# virtual_index is relative to group starts, so now offset that
virtual_index[..., 1:] += offset[..., :-1]
k_offset = param.reshape((abs(k),) + (1,) * virtual_index.ndim)
lo_ = k_offset + virtual_index[np.newaxis, ...]
# For groups with fewer than k elements, clamp extraction indices to valid range
# and mark out-of-bounds positions for filling with fill_value.
# Compute group boundaries: starts = [0, offset[:-1]], ends = offset
# We prepend 0 to offset[:-1] to get group start positions
group_starts = np.insert(offset[..., :-1], 0, 0, axis=-1)

# Mark positions outside group boundaries (before clamping to detect invalid indices)
# Broadcasting happens implicitly in comparison
badmask = (lo_ < group_starts) | (lo_ >= offset)

# Clamp lo_ in-place to [group_starts, array.shape[axis]-1]
# Using out= avoids intermediate array allocations
np.clip(lo_, group_starts, array.shape[axis] - 1, out=lo_)
# Note: we don't include nanmask here because for intermediate chunk results,
# we want to keep partial results. nanmask is used separately for final output.
# kth must include ALL indices we'll extract, not just the starting index per group.
# np.partition only guarantees correct values at kth positions; other positions may
# have elements from different groups due to how introselect works with complex numbers.
kth = np.unique(np.concatenate([np.unique(offset), np.unique(lo_)]))
kth = kth[kth >= 0]
kth[kth >= array.shape[axis]] = array.shape[axis] - 1

# partition the complex array in-place
labels_broadcast = np.broadcast_to(group_idx, array.shape)
Expand All @@ -111,20 +156,33 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
# a simple (labels + 1j * array) will yield `nan+inf * 1j` instead of `0 + inf * j`
cmplx.real = labels_broadcast
cmplx.partition(kth=kth, axis=-1)
if is_scalar_q:
a_ = cmplx.imag
else:
a_ = np.broadcast_to(cmplx.imag, (q.shape[0],) + array.shape)

# get bounds, Broadcast to (num quantiles, ..., num labels)
loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis)
hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis)
a_ = cmplx.imag
if not is_scalar_param:
a_ = np.broadcast_to(cmplx.imag, (param.shape[0],) + array.shape)

# TODO: could support all the interpolations here
gamma = np.broadcast_to(virtual_index, idxshape) - lo_
result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
if not skipna and np.any(nanmask):
result[..., nanmask] = np.nan
if array.dtype.kind in "Mm":
a_ = a_.view(array.dtype)

loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis)
if q is not None:
# get bounds, Broadcast to (num quantiles, ..., num labels)
hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis)

# TODO: could support all the interpolations here
gamma = np.broadcast_to(virtual_index, idxshape) - lo_
result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
if not skipna and np.any(nanmask):
result[..., nanmask] = fill_value
else:
result = loval
if badmask.any():
result[badmask] = fill_value

if k is not None:
result = result.astype(dtype, copy=False)
if out is not None:
np.copyto(out, result)
return result


Expand Down Expand Up @@ -158,12 +216,14 @@ def _np_grouped_op(

if out is None:
q = kwargs.get("q", None)
if q is None:
k = kwargs.get("k", None)
if q is None and k is None:
out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
else:
nq = len(np.atleast_1d(q))
nq = len(np.atleast_1d(q)) if q is not None else abs(k)
out = np.full((nq,) + array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
kwargs["group_idx"] = group_idx
kwargs["fill_value"] = fill_value

if (len(uniques) == size) and (uniques == np.arange(size, like=aux)).all():
# The previous version of this if condition
Expand Down Expand Up @@ -200,10 +260,11 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF)
min = partial(_np_grouped_op, op=np.minimum.reduceat)
nanmin = partial(_nan_grouped_op, func=min, fillna=dtypes.INF)
quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False))
nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True))
median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=False))
nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=True))
topk = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True))
quantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=False))
nanquantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True))
median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=False))
nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=True))
# TODO: all, any


Expand Down
Loading
Loading