From 8305a200a163d8f9bd440ebde745782273e182f4 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 12 Oct 2025 15:03:53 -0700 Subject: [PATCH 1/3] simplify implementation of merge() for DataTree --- xarray/structure/concat.py | 109 +++++++++++++++++++++++++++++++++++- xarray/structure/merge.py | 14 +---- xarray/tests/test_concat.py | 17 ++++++ 3 files changed, 126 insertions(+), 14 deletions(-) diff --git a/xarray/structure/concat.py b/xarray/structure/concat.py index fbc90ff6a50..9510e88bd08 100644 --- a/xarray/structure/concat.py +++ b/xarray/structure/concat.py @@ -10,6 +10,7 @@ from xarray.core.coordinates import Coordinates from xarray.core.duck_array_ops import lazy_array_equiv from xarray.core.indexes import Index, PandasIndex +from xarray.core.treenode import group_subtrees from xarray.core.types import T_DataArray, T_Dataset, T_Variable from xarray.core.utils import emit_user_level_warning from xarray.core.variable import Variable @@ -30,6 +31,7 @@ ) if TYPE_CHECKING: + from xarray.core.datatree import DataTree from xarray.core.types import ( CombineAttrsOptions, CompatOptions, @@ -41,6 +43,21 @@ # TODO: replace dim: Any by 1D array_likes +@overload +def concat( + objs: Iterable[DataTree], + dim: Hashable | T_Variable | T_DataArray | pd.Index | Any, + data_vars: T_DataVars | CombineKwargDefault = "minimal", + coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault = "minimal", + compat: CompatOptions = "override", + positions: Iterable[Iterable[int]] | None = None, + fill_value: object = dtypes.NA, + join: JoinOptions = "exact", + combine_attrs: CombineAttrsOptions = "override", + create_index_for_new_dim: bool = True, +) -> DataTree: ... + + @overload def concat( objs: Iterable[T_Dataset], @@ -265,6 +282,7 @@ def concat( # dimension already exists from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset + from xarray.core.datatree import DataTree try: first_obj, objs = utils.peek_at(objs) @@ -278,7 +296,20 @@ def concat( f"compat={compat!r} invalid: must be 'broadcast_equals', 'equals', 'identical', 'no_conflicts' or 'override'" ) - if isinstance(first_obj, DataArray): + if isinstance(first_obj, DataTree): + return _datatree_concat( + objs, # type: ignore[arg-type] + dim=dim, + data_vars=data_vars, + coords=coords, + compat=compat, + positions=positions, + fill_value=fill_value, + join=join, + combine_attrs=combine_attrs, + create_index_for_new_dim=create_index_for_new_dim, + ) + elif isinstance(first_obj, DataArray): return _dataarray_concat( objs, dim=dim, @@ -311,6 +342,82 @@ def concat( ) +def _datatree_concat( + trees: Iterable[DataTree], + dim: Hashable | T_Variable | T_DataArray | pd.Index | Any, + data_vars: T_DataVars | CombineKwargDefault, + coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault, + compat: CompatOptions, + positions: Iterable[Iterable[int]] | None, + fill_value: object, + join: JoinOptions, + combine_attrs: CombineAttrsOptions, + create_index_for_new_dim: bool, +) -> DataTree: + """Concatenate DataTree objects.""" + from xarray.core.datatree import DataTree + + if join != "exact": + raise NotImplementedError( + "Only `join='exact'` is supported for DataTree concat" + ) + if data_vars != "minimal": + raise NotImplementedError( + "Only `data_vars='minimal'` is supported for DataTree concat" + ) + if coords != "minimal": + raise NotImplementedError( + "Only `coords='minimal'` is supported for DataTree concat" + ) + if fill_value is not dtypes.NA: + raise NotImplementedError("`fill_value` is not supported for DataTree concat") + if positions is not None: + raise NotImplementedError("`positions` is not supported for DataTree concat") + + trees = list(trees) + if not all(isinstance(obj, DataTree) for obj in trees): + raise TypeError( + "concat does not support mixed type arguments when one argument " + f"is a DataTree: {trees}" + ) + + first_tree = trees[0] + for other_tree in trees[1:]: + if not first_tree.isomorphic(other_tree): + raise ValueError("All DataTree objects must be isomorphic") + + dim_name, _ = _calc_concat_dim_index(dim) + + all_dims = set() + for tree in trees: + for node in tree.subtree: + all_dims.update(node.dims) + + if dim_name not in all_dims: + raise ValueError( + f"Dimension {dim_name!r} does not exist on any of the DataTree objects being concatenated." + ) + + concatenated = {} + for path, nodes in group_subtrees(*trees): + datasets = [node.to_dataset(inherit=False) for node in nodes] + concatenated_ds = concat( + datasets, + dim=dim, + data_vars=data_vars, + coords=coords, + compat=compat, + positions=positions, + fill_value=fill_value, + join=join, + combine_attrs=combine_attrs, + create_index_for_new_dim=create_index_for_new_dim, + ) + concatenated[path] = concatenated_ds + + return DataTree.from_dict(concatenated) + + def _calc_concat_dim_index( dim_or_data: Hashable | Any, ) -> tuple[Hashable, PandasIndex | None]: diff --git a/xarray/structure/merge.py b/xarray/structure/merge.py index cb436ca1027..6c1c3166a5c 100644 --- a/xarray/structure/merge.py +++ b/xarray/structure/merge.py @@ -813,7 +813,7 @@ def merge_trees( ) node_lists: defaultdict[str, list[DataTree]] = defaultdict(list) - for tree in trees: + for tree in list(trees): for key, node in tree.subtree_with_keys: node_lists[key].append(node) @@ -839,18 +839,6 @@ def level(kv): join=join, combine_attrs=combine_attrs, ) - # Remove inherited coordinates/indexes/dimensions. - for var_name in list(merge_result.coord_names): - if not any(var_name in node._coord_variables for node in nodes): - del merge_result.variables[var_name] - merge_result.coord_names.remove(var_name) - for index_name in list(merge_result.indexes): - if not any(index_name in node._node_indexes for node in nodes): - del merge_result.indexes[index_name] - for dim in list(merge_result.dims): - if not any(dim in node._node_dims for node in nodes): - del merge_result.dims[dim] - merged_ds = Dataset._construct_direct(**merge_result._asdict()) result[key] = DataTree(dataset=merged_ds) diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 2c3692863f0..ee9ecd3fff0 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -13,6 +13,7 @@ from xarray.core import dtypes, types from xarray.core.coordinates import Coordinates from xarray.core.indexes import PandasIndex +from xarray.datatree import DataTree from xarray.structure import merge from xarray.tests import ( ConcatenatableArray, @@ -1441,6 +1442,22 @@ def test_concat_index_not_same_dim() -> None: concat([ds1, ds2], dim="x") +class TestConcatDataTree: + def test_concat_simple(self) -> None: + ds1 = Dataset({"a": ("x", [1])}) + dt1 = DataTree.from_dict({"foo": ds1}) + + ds2 = Dataset({"a": ("x", [2])}) + dt2 = DataTree.from_dict({"foo": ds2}) + + actual = concat([dt1, dt2], dim="x") + + expected_ds = Dataset({"a": ("x", [1, 2])}) + expected = DataTree.from_dict({"foo": expected_ds}) + + assert_identical(actual, expected) + + class TestNewDefaults: def test_concat_second_empty_with_scalar_data_var_only_on_first(self) -> None: ds1 = Dataset(data_vars={"a": ("y", [0.1]), "b": 0.1}, coords={"x": 0.1}) From 4fcf3bf9b6a635a10caa3ebfbe4e9dcb18a0f36e Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 12 Oct 2025 15:08:53 -0700 Subject: [PATCH 2/3] revert inadvertently included changes --- xarray/structure/concat.py | 109 +----------------------------------- xarray/tests/test_concat.py | 17 ------ 2 files changed, 1 insertion(+), 125 deletions(-) diff --git a/xarray/structure/concat.py b/xarray/structure/concat.py index 9510e88bd08..fbc90ff6a50 100644 --- a/xarray/structure/concat.py +++ b/xarray/structure/concat.py @@ -10,7 +10,6 @@ from xarray.core.coordinates import Coordinates from xarray.core.duck_array_ops import lazy_array_equiv from xarray.core.indexes import Index, PandasIndex -from xarray.core.treenode import group_subtrees from xarray.core.types import T_DataArray, T_Dataset, T_Variable from xarray.core.utils import emit_user_level_warning from xarray.core.variable import Variable @@ -31,7 +30,6 @@ ) if TYPE_CHECKING: - from xarray.core.datatree import DataTree from xarray.core.types import ( CombineAttrsOptions, CompatOptions, @@ -43,21 +41,6 @@ # TODO: replace dim: Any by 1D array_likes -@overload -def concat( - objs: Iterable[DataTree], - dim: Hashable | T_Variable | T_DataArray | pd.Index | Any, - data_vars: T_DataVars | CombineKwargDefault = "minimal", - coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault = "minimal", - compat: CompatOptions = "override", - positions: Iterable[Iterable[int]] | None = None, - fill_value: object = dtypes.NA, - join: JoinOptions = "exact", - combine_attrs: CombineAttrsOptions = "override", - create_index_for_new_dim: bool = True, -) -> DataTree: ... - - @overload def concat( objs: Iterable[T_Dataset], @@ -282,7 +265,6 @@ def concat( # dimension already exists from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.datatree import DataTree try: first_obj, objs = utils.peek_at(objs) @@ -296,20 +278,7 @@ def concat( f"compat={compat!r} invalid: must be 'broadcast_equals', 'equals', 'identical', 'no_conflicts' or 'override'" ) - if isinstance(first_obj, DataTree): - return _datatree_concat( - objs, # type: ignore[arg-type] - dim=dim, - data_vars=data_vars, - coords=coords, - compat=compat, - positions=positions, - fill_value=fill_value, - join=join, - combine_attrs=combine_attrs, - create_index_for_new_dim=create_index_for_new_dim, - ) - elif isinstance(first_obj, DataArray): + if isinstance(first_obj, DataArray): return _dataarray_concat( objs, dim=dim, @@ -342,82 +311,6 @@ def concat( ) -def _datatree_concat( - trees: Iterable[DataTree], - dim: Hashable | T_Variable | T_DataArray | pd.Index | Any, - data_vars: T_DataVars | CombineKwargDefault, - coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault, - compat: CompatOptions, - positions: Iterable[Iterable[int]] | None, - fill_value: object, - join: JoinOptions, - combine_attrs: CombineAttrsOptions, - create_index_for_new_dim: bool, -) -> DataTree: - """Concatenate DataTree objects.""" - from xarray.core.datatree import DataTree - - if join != "exact": - raise NotImplementedError( - "Only `join='exact'` is supported for DataTree concat" - ) - if data_vars != "minimal": - raise NotImplementedError( - "Only `data_vars='minimal'` is supported for DataTree concat" - ) - if coords != "minimal": - raise NotImplementedError( - "Only `coords='minimal'` is supported for DataTree concat" - ) - if fill_value is not dtypes.NA: - raise NotImplementedError("`fill_value` is not supported for DataTree concat") - if positions is not None: - raise NotImplementedError("`positions` is not supported for DataTree concat") - - trees = list(trees) - if not all(isinstance(obj, DataTree) for obj in trees): - raise TypeError( - "concat does not support mixed type arguments when one argument " - f"is a DataTree: {trees}" - ) - - first_tree = trees[0] - for other_tree in trees[1:]: - if not first_tree.isomorphic(other_tree): - raise ValueError("All DataTree objects must be isomorphic") - - dim_name, _ = _calc_concat_dim_index(dim) - - all_dims = set() - for tree in trees: - for node in tree.subtree: - all_dims.update(node.dims) - - if dim_name not in all_dims: - raise ValueError( - f"Dimension {dim_name!r} does not exist on any of the DataTree objects being concatenated." - ) - - concatenated = {} - for path, nodes in group_subtrees(*trees): - datasets = [node.to_dataset(inherit=False) for node in nodes] - concatenated_ds = concat( - datasets, - dim=dim, - data_vars=data_vars, - coords=coords, - compat=compat, - positions=positions, - fill_value=fill_value, - join=join, - combine_attrs=combine_attrs, - create_index_for_new_dim=create_index_for_new_dim, - ) - concatenated[path] = concatenated_ds - - return DataTree.from_dict(concatenated) - - def _calc_concat_dim_index( dim_or_data: Hashable | Any, ) -> tuple[Hashable, PandasIndex | None]: diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index ee9ecd3fff0..2c3692863f0 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -13,7 +13,6 @@ from xarray.core import dtypes, types from xarray.core.coordinates import Coordinates from xarray.core.indexes import PandasIndex -from xarray.datatree import DataTree from xarray.structure import merge from xarray.tests import ( ConcatenatableArray, @@ -1442,22 +1441,6 @@ def test_concat_index_not_same_dim() -> None: concat([ds1, ds2], dim="x") -class TestConcatDataTree: - def test_concat_simple(self) -> None: - ds1 = Dataset({"a": ("x", [1])}) - dt1 = DataTree.from_dict({"foo": ds1}) - - ds2 = Dataset({"a": ("x", [2])}) - dt2 = DataTree.from_dict({"foo": ds2}) - - actual = concat([dt1, dt2], dim="x") - - expected_ds = Dataset({"a": ("x", [1, 2])}) - expected = DataTree.from_dict({"foo": expected_ds}) - - assert_identical(actual, expected) - - class TestNewDefaults: def test_concat_second_empty_with_scalar_data_var_only_on_first(self) -> None: ds1 = Dataset(data_vars={"a": ("y", [0.1]), "b": 0.1}, coords={"x": 0.1}) From 10490585279e7bc7374ee38174a93540dda2a96d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 12 Oct 2025 15:09:47 -0700 Subject: [PATCH 3/3] fix merge_trees --- xarray/structure/merge.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/structure/merge.py b/xarray/structure/merge.py index 6c1c3166a5c..e5f3c0959bd 100644 --- a/xarray/structure/merge.py +++ b/xarray/structure/merge.py @@ -795,7 +795,7 @@ def merge_core( def merge_trees( - trees: Iterable[DataTree], + trees: Sequence[DataTree], compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT, join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT, fill_value: object = dtypes.NA, @@ -813,7 +813,7 @@ def merge_trees( ) node_lists: defaultdict[str, list[DataTree]] = defaultdict(list) - for tree in list(trees): + for tree in trees: for key, node in tree.subtree_with_keys: node_lists[key].append(node)