Skip to content

Commit 2c5814f

Browse files
Resolve "Add cmethods attributes to output files" (#177)
1 parent fc6716c commit 2c5814f

File tree

6 files changed

+121
-5
lines changed

6 files changed

+121
-5
lines changed

cmethods/core.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from __future__ import annotations
2727

28+
from datetime import datetime, timezone
2829
from typing import TYPE_CHECKING, Callable, Dict, Optional
2930

3031
import xarray as xr
@@ -49,6 +50,46 @@
4950
}
5051

5152

53+
def _add_cmethods_metadata(
54+
result: xr.Dataset | xr.DataArray,
55+
method: str,
56+
**kwargs,
57+
) -> xr.Dataset | xr.DataArray:
58+
"""
59+
Add metadata to the result indicating it was processed by python-cmethods.
60+
61+
:param result: The bias-corrected dataset or dataarray
62+
:param method: The method used for bias correction
63+
:param kwargs: Additional method parameters
64+
:return: Result with added metadata
65+
"""
66+
try:
67+
from importlib.metadata import version # noqa: PLC0415
68+
69+
pkg_version = version("python-cmethods")
70+
except Exception: # noqa: BLE001
71+
pkg_version = "unknown"
72+
73+
attrs_to_add = {
74+
"cmethods_version": pkg_version,
75+
"cmethods_method": method,
76+
"cmethods_timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC"),
77+
"cmethods_source": "https://github.com/btschwertfeger/python-cmethods",
78+
}
79+
80+
if kind := kwargs.get("kind"):
81+
attrs_to_add["cmethods_kind"] = kind
82+
if n_quantiles := kwargs.get("n_quantiles"):
83+
attrs_to_add["cmethods_n_quantiles"] = str(n_quantiles)
84+
if group := kwargs.get("group"):
85+
attrs_to_add["cmethods_group"] = str(group)
86+
87+
if isinstance(result, (xr.Dataset, xr.DataArray)):
88+
result.attrs.update(attrs_to_add)
89+
90+
return result
91+
92+
5293
def apply_ufunc(
5394
method: str,
5495
obs: xr.xarray.core.dataarray.DataArray,
@@ -144,6 +185,8 @@ def adjust(
144185
:return: The bias corrected/adjusted data set
145186
:rtype: xr.xarray.core.dataarray.DataArray | xr.xarray.core.dataarray.Dataset
146187
"""
188+
metadata_kwargs = {k: v for k, v in kwargs.items() if k in {"kind", "n_quantiles", "group"}}
189+
147190
kwargs["adjust_called"] = True
148191
ensure_xr_dataarray(obs=obs, simh=simh, simp=simp)
149192

@@ -159,7 +202,8 @@ def adjust(
159202
# mock this function or apply ``CMethods.__apply_ufunc` directly
160203
# on your data sets.
161204
if kwargs.get("group") is None:
162-
return apply_ufunc(method, obs, simh, simp, **kwargs).to_dataset()
205+
result = apply_ufunc(method, obs, simh, simp, **kwargs).to_dataset()
206+
return _add_cmethods_metadata(result, method, **metadata_kwargs)
163207

164208
if method not in SCALING_METHODS:
165209
raise ValueError(
@@ -204,7 +248,7 @@ def adjust(
204248

205249
result = monthly_result if result is None else xr.merge([result, monthly_result])
206250

207-
return result
251+
return _add_cmethods_metadata(result, method, **metadata_kwargs)
208252

209253

210254
__all__ = ["adjust"]

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,8 @@ min-file-size = 1024
264264
"TID252", # ban relative imports
265265
"PTH118", # `os.path.join()` should be replaced by `Path` with `/` operator,
266266
"PTH120", # `os.path.dirname()` should be replaced by `Path.parent`
267+
"PLR2004", # magic value in comparison
268+
"PLC2701" # Private name import
267269
]
268270

269271
[tool.ruff.lint.flake8-quotes]

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
dask[distributed]
2+
matplotlib
23
pytest
34
pytest-cov
45
pytest-retry

tests/helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def get_dataset(data, time, kind: str) -> xr.Dataset:
101101
.to_dataset(name=kind)
102102
)
103103

104-
if kind == "+": # noqa: PLR2004
104+
if kind == "+":
105105
some_data = [get_hist_temp_for_lat(val) for val in latitudes]
106106
data = np.array(
107107
[

tests/test_misc.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515

1616
import numpy as np
1717
import pytest
18+
import xarray as xr
1819

1920
from cmethods import adjust
21+
from cmethods.core import _add_cmethods_metadata
2022
from cmethods.distribution import (
2123
detrended_quantile_mapping,
2224
quantile_delta_mapping,
@@ -130,3 +132,70 @@ def test_adjust_failing_no_group_for_distribution(datasets: dict) -> None:
130132
n_quantiles=100,
131133
group="time.month",
132134
)
135+
136+
137+
def test_add_cmethods_metadata_with_dataarray() -> None:
138+
"""Test that _add_cmethods_metadata adds correct attributes to a DataArray"""
139+
data = xr.DataArray(
140+
np.array([1, 2, 3, 4, 5]),
141+
dims=["time"],
142+
coords={"time": np.arange(5)},
143+
)
144+
145+
result = _add_cmethods_metadata(
146+
data,
147+
method="linear_scaling",
148+
kind="+",
149+
n_quantiles=100,
150+
group="time.month",
151+
)
152+
153+
assert "cmethods_version" in result.attrs
154+
assert "cmethods_method" in result.attrs
155+
assert "cmethods_timestamp" in result.attrs
156+
assert "cmethods_source" in result.attrs
157+
158+
assert result.attrs["cmethods_method"] == "linear_scaling"
159+
assert result.attrs["cmethods_kind"] == "+"
160+
assert result.attrs["cmethods_n_quantiles"] == "100"
161+
assert result.attrs["cmethods_group"] == "time.month"
162+
assert result.attrs["cmethods_source"] == "https://github.com/btschwertfeger/python-cmethods"
163+
assert "UTC" in result.attrs["cmethods_timestamp"]
164+
165+
166+
def test_add_cmethods_metadata_with_dataset() -> None:
167+
"""Test that _add_cmethods_metadata adds correct attributes to a Dataset"""
168+
data = xr.Dataset(
169+
{
170+
"temperature": xr.DataArray(
171+
np.array([1, 2, 3, 4, 5]),
172+
dims=["time"],
173+
coords={"time": np.arange(5)},
174+
),
175+
},
176+
)
177+
178+
result = _add_cmethods_metadata(data, method="quantile_mapping")
179+
180+
assert "cmethods_version" in result.attrs
181+
assert "cmethods_method" in result.attrs
182+
assert "cmethods_timestamp" in result.attrs
183+
assert "cmethods_source" in result.attrs
184+
assert result.attrs["cmethods_method"] == "quantile_mapping"
185+
186+
187+
def test_add_cmethods_metadata_optional_params() -> None:
188+
"""Test that _add_cmethods_metadata handles optional parameters correctly"""
189+
data = xr.DataArray(
190+
np.array([1, 2, 3]),
191+
dims=["time"],
192+
coords={"time": np.arange(3)},
193+
)
194+
195+
result = _add_cmethods_metadata(data, method="variance_scaling")
196+
197+
assert "cmethods_method" in result.attrs
198+
assert result.attrs["cmethods_method"] == "variance_scaling"
199+
assert "cmethods_kind" not in result.attrs
200+
assert "cmethods_n_quantiles" not in result.attrs
201+
assert "cmethods_group" not in result.attrs

tests/test_zarr_dask_compatibility.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_3d_scaling_zarr(
3737
kind: str,
3838
dask_cluster: Any, # noqa: ARG001
3939
) -> None:
40-
variable: str = "tas" if kind == "+" else "pr" # noqa: PLR2004
40+
variable: str = "tas" if kind == "+" else "pr"
4141
obsh: xr.DataArray = datasets_from_zarr[kind]["obsh"][variable]
4242
obsp: xr.DataArray = datasets_from_zarr[kind]["obsp"][variable]
4343
simh: xr.DataArray = datasets_from_zarr[kind]["simh"][variable]
@@ -81,7 +81,7 @@ def test_3d_distribution_zarr(
8181
kind: str,
8282
dask_cluster: Any, # noqa: ARG001
8383
) -> None:
84-
variable: str = "tas" if kind == "+" else "pr" # noqa: PLR2004
84+
variable: str = "tas" if kind == "+" else "pr"
8585
obsh: XRData_t = datasets_from_zarr[kind]["obsh"][variable]
8686
obsp: XRData_t = datasets_from_zarr[kind]["obsp"][variable]
8787
simh: XRData_t = datasets_from_zarr[kind]["simh"][variable]

0 commit comments

Comments
 (0)