Skip to content

Commit 634d8f0

Browse files
authored
draft object dtype perioddata for oc (#223)
prototype a strongly-typed, more true-to-DFNs alternative to the imod-python style OC we've had so far. writing doesn't work yet. this could be generated from an OC DFN modified to look something like this. cf #205 also refactor the converter file into a module and rename dict_to_array to structure_array.
1 parent 6d04683 commit 634d8f0

File tree

15 files changed

+249
-175
lines changed

15 files changed

+249
-175
lines changed

flopy4/mf6/codec/writer/filters.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from io import StringIO
33
from typing import Any, Literal
44

5+
import attrs
56
import numpy as np
67
import xarray as xr
78
from modflow_devtools.dfn.schema.v2 import FieldType
@@ -202,7 +203,8 @@ def dataset2list(value: xr.Dataset):
202203
return
203204

204205
# special case OC for now.
205-
is_oc = all(
206+
# TODO remove after properly handling object dtype period data arrays
207+
is_oc = any(
206208
str(v.name).startswith("save_") or str(v.name).startswith("print_")
207209
for v in value.data_vars.values()
208210
)
@@ -211,9 +213,17 @@ def dataset2list(value: xr.Dataset):
211213
if (first := next(iter(value.data_vars.values()))).ndim == 0:
212214
if is_oc:
213215
for name in value.data_vars.keys():
216+
if not (name.startswith("save_") or name.startswith("print_")):
217+
# TODO: not working yet
218+
if name == "perioddata":
219+
val = value[name]
220+
val = val.item() if val.shape == () else val
221+
yield attrs.astuple(val, recurse=True)
222+
continue
214223
val = value[name]
215224
val = val.item() if val.shape == () else val
216225
yield (*name.split("_"), val)
226+
217227
else:
218228
vals = []
219229
for name in value.data_vars.keys():

flopy4/mf6/converter/__init__.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from pathlib import Path
2+
from typing import Any
3+
4+
import cattr
5+
import xattree
6+
from cattr import Converter
7+
from cattrs.gen import make_hetero_tuple_unstructure_fn
8+
9+
from flopy4.mf6.component import Component
10+
from flopy4.mf6.context import Context
11+
from flopy4.mf6.converter.structure import structure_array
12+
from flopy4.mf6.converter.unstructure import (
13+
unstructure_component,
14+
)
15+
from flopy4.mf6.gwf.oc import Oc
16+
17+
__all__ = [
18+
"structure",
19+
"unstructure",
20+
"structure_array",
21+
"unstructure_array",
22+
"COMPONENT_CONVERTER",
23+
]
24+
25+
26+
def _make_converter() -> Converter:
27+
converter = Converter(unstruct_strat=cattr.UnstructureStrategy.AS_TUPLE)
28+
converter.register_unstructure_hook_factory(xattree.has, lambda _: xattree.asdict)
29+
converter.register_unstructure_hook(Component, unstructure_component)
30+
converter.register_unstructure_hook(
31+
Oc.PrintSaveSetting, make_hetero_tuple_unstructure_fn(Oc.PrintSaveSetting, converter)
32+
)
33+
converter.register_unstructure_hook(
34+
Oc.Steps, make_hetero_tuple_unstructure_fn(Oc.Steps, converter)
35+
)
36+
return converter
37+
38+
39+
COMPONENT_CONVERTER = _make_converter()
40+
41+
42+
def structure(data: dict[str, Any], path: Path) -> Component:
43+
component = COMPONENT_CONVERTER.structure(data, Component)
44+
if isinstance(component, Context):
45+
component.workspace = path.parent
46+
component.filename = path.name
47+
return component
48+
49+
50+
def unstructure(component: Component) -> dict[str, Any]:
51+
return COMPONENT_CONVERTER.unstructure(component)

flopy4/mf6/converter/structure.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from typing import Any
2+
3+
import numpy as np
4+
import sparse
5+
from numpy.typing import NDArray
6+
from xattree import get_xatspec
7+
8+
from flopy4.adapters import get_nn
9+
from flopy4.mf6.config import SPARSE_THRESHOLD
10+
from flopy4.mf6.constants import FILL_DNODATA
11+
12+
13+
def structure_array(value, self_, field) -> NDArray:
14+
"""
15+
Convert a sparse dictionary representation of an array to a
16+
dense numpy array or a sparse COO array.
17+
18+
TODO: generalize this not only to dictionaries but to any
19+
form that can be converted to an array (e.g. nested list)
20+
"""
21+
22+
if not isinstance(value, dict):
23+
# if not a dict, assume it's a numpy array
24+
# and let xarray deal with it if it isn't
25+
return value
26+
27+
spec = get_xatspec(type(self_)).flat
28+
field = spec[field.name]
29+
if not field.dims:
30+
raise ValueError(f"Field {field} missing dims")
31+
32+
# resolve dims
33+
explicit_dims = self_.__dict__.get("dims", {})
34+
inherited_dims = dict(self_.parent.data.dims) if self_.parent else {}
35+
dims = inherited_dims | explicit_dims
36+
shape = [dims.get(d, d) for d in field.dims]
37+
unresolved = [d for d in shape if isinstance(d, str)]
38+
if any(unresolved):
39+
raise ValueError(f"Couldn't resolve dims: {unresolved}")
40+
41+
if np.prod(shape) > SPARSE_THRESHOLD:
42+
a: dict[tuple[Any, ...], Any] = dict()
43+
44+
def set_(arr, val, *ind):
45+
arr[tuple(ind)] = val
46+
47+
def final(arr):
48+
coords = np.array(list(map(list, zip(*arr.keys()))))
49+
return sparse.COO(
50+
coords,
51+
list(arr.values()),
52+
shape=shape,
53+
fill_value=field.default or FILL_DNODATA,
54+
)
55+
else:
56+
a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore
57+
58+
def set_(arr, val, *ind):
59+
arr[ind] = val
60+
61+
def final(arr):
62+
arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA
63+
return arr
64+
65+
if "nper" in dims:
66+
for kper, period in value.items():
67+
if kper == "*":
68+
kper = 0
69+
match len(shape):
70+
case 1:
71+
set_(a, period, kper)
72+
case _:
73+
for cellid, v in period.items():
74+
nn = get_nn(cellid, **dims)
75+
set_(a, v, kper, nn)
76+
if kper == "*":
77+
break
78+
else:
79+
for cellid, v in value.items():
80+
nn = get_nn(cellid, **dims)
81+
set_(a, v, nn)
82+
83+
return final(a)

flopy4/mf6/converter.py renamed to flopy4/mf6/converter/unstructure.py

Lines changed: 5 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,17 @@
33
from pathlib import Path
44
from typing import Any
55

6-
import numpy as np
7-
import sparse
86
import xarray as xr
97
import xattree
10-
from cattrs import Converter
118
from modflow_devtools.dfn.schema.block import block_sort_key
12-
from numpy.typing import NDArray
13-
from xattree import get_xatspec
149

15-
from flopy4.adapters import get_nn
1610
from flopy4.mf6.binding import Binding
1711
from flopy4.mf6.component import Component
18-
from flopy4.mf6.config import SPARSE_THRESHOLD
19-
from flopy4.mf6.constants import FILL_DNODATA
2012
from flopy4.mf6.context import Context
2113
from flopy4.mf6.spec import FileInOut
2214

2315

24-
def path_to_tuple(name: str, value: Path, inout: FileInOut) -> tuple[str, ...]:
16+
def _path_to_tuple(name: str, value: Path, inout: FileInOut) -> tuple[str, ...]:
2517
t = [name.upper()]
2618
if name.endswith("_file"):
2719
t[0] = name.replace("_file", "").upper()
@@ -31,7 +23,7 @@ def path_to_tuple(name: str, value: Path, inout: FileInOut) -> tuple[str, ...]:
3123
return tuple(t)
3224

3325

34-
def make_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple[str, ...]]]]:
26+
def _make_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple[str, ...]]]]:
3527
if not isinstance(value, Context):
3628
return {}
3729

@@ -104,7 +96,7 @@ def unstructure_component(value: Component) -> dict[str, Any]:
10496
data = xattree.asdict(value)
10597

10698
# create child component binding blocks
107-
blocks.update(make_binding_blocks(value))
99+
blocks.update(_make_binding_blocks(value))
108100

109101
# process blocks in order, unstructuring fields as needed,
110102
# then slice period data into separate kper-indexed blocks
@@ -132,6 +124,7 @@ def unstructure_component(value: Component) -> dict[str, Any]:
132124
# - 'auxiliary' fields to tuples
133125
# - xarray DataArrays with 'nper' dim to dict of kper-sliced datasets
134126
# - other values to their original form
127+
# TODO: use cattrs converters for field unstructuring?
135128
match field_value := data[field_name]:
136129
case None:
137130
continue
@@ -141,7 +134,7 @@ def unstructure_component(value: Component) -> dict[str, Any]:
141134
case Path():
142135
field_spec = xatspec.attrs[field_name]
143136
field_meta = getattr(field_spec, "metadata", {})
144-
t = path_to_tuple(
137+
t = _path_to_tuple(
145138
field_name, field_value, inout=field_meta.get("inout", "fileout")
146139
)
147140
# name may have changed e.g dropping '_file' suffix
@@ -197,98 +190,3 @@ def unstructure_component(value: Component) -> dict[str, Any]:
197190
del blocks["solutiongroup"]
198191

199192
return {name: block for name, block in blocks.items() if name != period_block_name}
200-
201-
202-
def _make_converter() -> Converter:
203-
converter = Converter()
204-
converter.register_unstructure_hook_factory(xattree.has, lambda _: xattree.asdict)
205-
converter.register_unstructure_hook(Component, unstructure_component)
206-
return converter
207-
208-
209-
COMPONENT_CONVERTER = _make_converter()
210-
211-
212-
def dict_to_array(value, self_, field) -> NDArray:
213-
"""
214-
Convert a sparse dictionary representation of an array to a
215-
dense numpy array or a sparse COO array.
216-
217-
TODO: generalize this not only to dictionaries but to any
218-
form that can be converted to an array (e.g. nested list)
219-
"""
220-
221-
if not isinstance(value, dict):
222-
# if not a dict, assume it's a numpy array
223-
# and let xarray deal with it if it isn't
224-
return value
225-
226-
spec = get_xatspec(type(self_)).flat
227-
field = spec[field.name]
228-
if not field.dims:
229-
raise ValueError(f"Field {field} missing dims")
230-
231-
# resolve dims
232-
explicit_dims = self_.__dict__.get("dims", {})
233-
inherited_dims = dict(self_.parent.data.dims) if self_.parent else {}
234-
dims = inherited_dims | explicit_dims
235-
shape = [dims.get(d, d) for d in field.dims]
236-
unresolved = [d for d in shape if isinstance(d, str)]
237-
if any(unresolved):
238-
raise ValueError(f"Couldn't resolve dims: {unresolved}")
239-
240-
if np.prod(shape) > SPARSE_THRESHOLD:
241-
a: dict[tuple[Any, ...], Any] = dict()
242-
243-
def set_(arr, val, *ind):
244-
arr[tuple(ind)] = val
245-
246-
def final(arr):
247-
coords = np.array(list(map(list, zip(*arr.keys()))))
248-
return sparse.COO(
249-
coords,
250-
list(arr.values()),
251-
shape=shape,
252-
fill_value=field.default or FILL_DNODATA,
253-
)
254-
else:
255-
a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore
256-
257-
def set_(arr, val, *ind):
258-
arr[ind] = val
259-
260-
def final(arr):
261-
arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA
262-
return arr
263-
264-
if "nper" in dims:
265-
for kper, period in value.items():
266-
if kper == "*":
267-
kper = 0
268-
match len(shape):
269-
case 1:
270-
set_(a, period, kper)
271-
case _:
272-
for cellid, v in period.items():
273-
nn = get_nn(cellid, **dims)
274-
set_(a, v, kper, nn)
275-
if kper == "*":
276-
break
277-
else:
278-
for cellid, v in value.items():
279-
nn = get_nn(cellid, **dims)
280-
set_(a, v, nn)
281-
282-
return final(a)
283-
284-
285-
def structure(data: dict[str, Any], path: Path) -> Component:
286-
component = COMPONENT_CONVERTER.structure(data, Component)
287-
if isinstance(component, Context):
288-
component.workspace = path.parent
289-
component.filename = path.name
290-
return component
291-
292-
293-
def unstructure(component: Component) -> dict[str, Any]:
294-
return COMPONENT_CONVERTER.unstructure(component)

flopy4/mf6/gwf/chd.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from xattree import xattree
88

99
from flopy4.mf6.constants import LENBOUNDNAME
10-
from flopy4.mf6.converter import dict_to_array
10+
from flopy4.mf6.converter import structure_array
1111
from flopy4.mf6.package import Package
1212
from flopy4.mf6.spec import array, field, path
1313
from flopy4.mf6.utils.grid_utils import update_maxbound
@@ -38,7 +38,7 @@ class Chd(Package):
3838
"nodes",
3939
),
4040
default=None,
41-
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
41+
converter=Converter(structure_array, takes_self=True, takes_field=True),
4242
on_setattr=update_maxbound,
4343
)
4444
aux: Optional[NDArray[np.float64]] = array(
@@ -48,7 +48,7 @@ class Chd(Package):
4848
"nodes",
4949
),
5050
default=None,
51-
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
51+
converter=Converter(structure_array, takes_self=True, takes_field=True),
5252
on_setattr=update_maxbound,
5353
)
5454
boundname: Optional[NDArray[np.str_]] = array(
@@ -59,6 +59,6 @@ class Chd(Package):
5959
"nodes",
6060
),
6161
default=None,
62-
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
62+
converter=Converter(structure_array, takes_self=True, takes_field=True),
6363
on_setattr=update_maxbound,
6464
)

0 commit comments

Comments
 (0)