
Commit 6d04683

more miscellaneous writer cleanup/testing (#222)
* tidy up converter logic
* fix inset in array macro
* default oc period data arrays to None
* clarify temporary oc special casing
* no need for Solution slnfname field
* add strict arg to Component.to_dict
* drop duplicate filename in Simulation
* move maxbound fn to utils module
* add some docstrings
* add test assertions
1 parent d1300c2 commit 6d04683

16 files changed: +200 −160 lines changed

flopy4/mf6/codec/writer/filters.py

Lines changed: 9 additions & 5 deletions
@@ -201,11 +201,15 @@ def dataset2list(value: xr.Dataset):
     if value is None or not any(value.data_vars):
         return

-    first = next(iter(value.data_vars.values()))
-    is_union = first.dtype.type is np.str_
+    # special case OC for now.
+    is_oc = all(
+        str(v.name).startswith("save_") or str(v.name).startswith("print_")
+        for v in value.data_vars.values()
+    )

-    if first.ndim == 0:  # handle scalar
-        if is_union:
+    # handle scalar
+    if (first := next(iter(value.data_vars.values()))).ndim == 0:
+        if is_oc:
             for name in value.data_vars.keys():
                 val = value[name]
                 val = val.item() if val.shape == () else val
@@ -230,7 +234,7 @@ def dataset2list(value: xr.Dataset):
         has_spatial_dims = len(spatial_dims) > 0
         indices = np.where(combined_mask)
         for i in range(len(indices[0])):
-            if is_union:
+            if is_oc:
                 for name in value.data_vars.keys():
                     val = value[name][tuple(idx[i] for idx in indices)]
                     val = val.item() if val.shape == () else val
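The filter now keys its special casing off variable names rather than dtype. A minimal standalone sketch of that predicate applied to a toy dataset (names and values here are illustrative, not taken from the test suite):

```python
# Sketch (not part of the commit): the is_oc-style check from dataset2list
# applied to a toy OC-like dataset vs. an ordinary one.
import numpy as np
import xarray as xr

def looks_like_oc(ds: xr.Dataset) -> bool:
    # True only if every data variable is a save_*/print_* OC setting
    return all(
        str(v.name).startswith("save_") or str(v.name).startswith("print_")
        for v in ds.data_vars.values()
    )

oc = xr.Dataset(
    {
        "save_head": ("nper", np.array(["all", ""])),
        "print_budget": ("nper", np.array(["", "all"])),
    }
)
other = xr.Dataset({"strt": ("nper", np.array([1.0, 2.0]))})

print(looks_like_oc(oc))     # True
print(looks_like_oc(other))  # False
```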

flopy4/mf6/codec/writer/templates/macros.jinja

Lines changed: 2 additions & 2 deletions
@@ -32,10 +32,10 @@
 {{ inset ~ name.upper() }}{% if "layered" in how %} LAYERED{% endif %}

 {% if how == "constant" %}
-CONSTANT {{ value|array2const }}
+{{ inset }}CONSTANT {{ value|array2const }}
 {% elif how == "layered constant" %}
 {% for layer in value -%}
-CONSTANT {{ layer|array2const }}
+{{ inset }}CONSTANT {{ layer|array2const }}
 {%- endfor %}
 {% elif how == "internal" %}
 INTERNAL
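The fix prefixes the CONSTANT lines with the block's inset so they indent consistently with the array name above them. A rough standalone rendering sketch, assuming a stub array2const filter and a hypothetical two-space inset:

```python
# Sketch only: a stub array2const filter and a simplified fragment of the
# macro's "constant" branch, to show the effect of the {{ inset }} prefix.
from jinja2 import Environment

env = Environment()
env.filters["array2const"] = lambda a: str(a)  # stand-in for the real filter

template = env.from_string(
    "{{ inset ~ name.upper() }}\n"
    "{% if how == 'constant' %}{{ inset }}CONSTANT {{ value|array2const }}{% endif %}"
)

print(template.render(inset="  ", name="k", how="constant", value=1.0))
# Output:
#   K
#   CONSTANT 1.0
```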

flopy4/mf6/component.py

Lines changed: 41 additions & 63 deletions
@@ -3,69 +3,17 @@
 from pathlib import Path
 from typing import Any, ClassVar

-import numpy as np
 from attrs import fields
 from modflow_devtools.dfn import Dfn, Field
 from packaging.version import Version
 from xattree import asdict as xattree_asdict
 from xattree import xattree

-from flopy4.mf6.constants import FILL_DNODATA, MF6
+from flopy4.mf6.constants import MF6
 from flopy4.mf6.spec import field, fields_dict, to_field
+from flopy4.mf6.utils.grid_utils import update_maxbound
 from flopy4.uio import IO, Loader, Writer

-
-def update_maxbound(instance, attribute, new_value):
-    """
-    Generalized function to update maxbound when period block arrays change.
-
-    This function automatically finds all period block arrays in the instance
-    and calculates maxbound based on the maximum number of non-default values
-    across all arrays.
-
-    Args:
-        instance: The package instance
-        attribute: The attribute being set (from attrs on_setattr)
-        new_value: The new value being set
-
-    Returns:
-        The new_value (unchanged)
-    """
-
-    period_arrays = []
-    instance_fields = fields(instance.__class__)
-    for f in instance_fields:
-        if (
-            f.metadata
-            and f.metadata.get("block") == "period"
-            and f.metadata.get("xattree", {}).get("dims")
-        ):
-            period_arrays.append(f.name)
-
-    maxbound_values = []
-    for array_name in period_arrays:
-        if attribute and attribute.name == array_name:
-            array_val = new_value
-        else:
-            array_val = getattr(instance, array_name, None)
-
-        if array_val is not None:
-            array_data = (
-                array_val if array_val.data.shape == array_val.shape else array_val.todense()
-            )
-
-            if array_data.dtype.kind in ["U", "S"]:  # String arrays
-                non_default_count = len(np.where(array_data != "")[0])
-            else:  # Numeric arrays
-                non_default_count = len(np.where(array_data != FILL_DNODATA)[0])
-
-            maxbound_values.append(non_default_count)
-    if maxbound_values:
-        instance.maxbound = max(maxbound_values)
-
-    return new_value
-
-
 COMPONENTS = {}
 """MF6 component registry."""

@@ -86,11 +34,14 @@ class Component(ABC, MutableMapping):
     _write = IO(Writer)  # type: ignore

     dfn: ClassVar[Dfn]
+    """The component's definition (i.e. specification)."""
+
     filename: str | None = field(default=None)
+    """The name of the component's input file."""

     @property
     def path(self) -> Path:
-        """Get the path to the component's input file."""
+        """The path to the component's input file."""
         self.filename = self.filename or self.default_filename()
         return Path.cwd() / self.filename

@@ -202,18 +153,45 @@ def write(self, format: str = MF6) -> None:
         for child in self.children.values():  # type: ignore
             child.write(format=format)

-    def to_dict(self, blocks: bool = False) -> dict[str, Any]:
-        """Convert the component to a dictionary representation."""
+    def to_dict(self, blocks: bool = False, strict: bool = False) -> dict[str, Any]:
+        """
+        Convert the component to a dictionary representation.
+
+        Parameters
+        ----------
+        blocks : bool, optional
+            If True, return a nested dict keyed by block name
+            with values as dicts of fields. Default is False.
+        strict : bool, optional
+            If True, include only fields in the DFN specification.
+
+        Returns
+        -------
+        dict[str, Any]
+            Dictionary containing component data, either
+            in terms of fields (flat) or blocks (nested).
+        """
         data = xattree_asdict(self)
-        data.pop("filename")
-        data.pop("workspace", None)  # might be a Context
-        data.pop("nodes", None)  # TODO: find a better way to omit
+        spec = self.dfn.fields
+
+        if strict:
+            data.pop("filename")
+            data.pop("workspace", None)  # might be a Context
+
         if blocks:
             blocks_ = {}  # type: ignore
-            for field_name, field_value in data.items():
-                block_name = self.dfn.fields[field_name].block
+            for field_name in spec.keys():
+                field_value = data[field_name]
+                block_name = spec[field_name].block
+                if strict and block_name is None:
+                    continue
                 if block_name not in blocks_:
                     blocks_[block_name] = {}
                 blocks_[block_name][field_name] = field_value
             return blocks_
-        return data
+        else:
+            return {
+                field_name: data[field_name]
+                for field_name in spec.keys()
+                if spec[field_name].block or not strict
+            }

flopy4/mf6/converter.py

Lines changed: 60 additions & 58 deletions
@@ -31,7 +31,7 @@ def path_to_tuple(name: str, value: Path, inout: FileInOut) -> tuple[str, ...]:
     return tuple(t)


-def get_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple[str, ...]]]]:
+def make_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple[str, ...]]]]:
     if not isinstance(value, Context):
         return {}

@@ -103,13 +103,19 @@ def unstructure_component(value: Component) -> dict[str, Any]:
     xatspec = xattree.get_xatspec(type(value))
     data = xattree.asdict(value)

-    blocks.update(binding_blocks := get_binding_blocks(value))
+    # create child component binding blocks
+    blocks.update(make_binding_blocks(value))

+    # process blocks in order, unstructuring fields as needed,
+    # then slice period data into separate kper-indexed blocks
+    # each of which contains a dataset indexed for that period.
     for block_name, block in blockspec.items():
+        period_data = {}  # type: ignore
+        period_blocks = {}  # type: ignore
+        period_block_name = None
+
         if block_name not in blocks:
             blocks[block_name] = {}
-        period_data = {}
-        period_blocks = {}  # type: ignore

         for field_name in block.keys():
             # Skip child components that have been processed as bindings
@@ -119,82 +125,78 @@
                 if child_spec.metadata["block"] == block_name:  # type: ignore
                     continue

-            field_value = data[field_name]
-            # convert:
+            # filter out empty values and false keywords, and convert:
             # - paths to records
-            # - datetime to ISO format
-            # - auxiliary fields to tuples
-            # - xarray DataArrays with 'nper' dimension to kper-sliced datasets
-            #   (and split the period data into separate kper-indexed blocks)
+            # - datetimes to ISO format
+            # - 'auxiliary' fields to tuples
+            # - xarray DataArrays with 'nper' dim to dict of kper-sliced datasets
             # - other values to their original form
-            if isinstance(field_value, Path):
-                field_spec = xatspec.attrs[field_name]
-                field_meta = getattr(field_spec, "metadata", {})
-                t = path_to_tuple(field_name, field_value, inout=field_meta.get("inout", "fileout"))
-                # name may have changed e.g dropping '_file' suffix
-                blocks[block_name][t[0]] = t
-            elif isinstance(field_value, datetime):
-                blocks[block_name][field_name] = field_value.isoformat()
-            elif (
-                field_name == "auxiliary"
-                and hasattr(field_value, "values")
-                and field_value is not None
-            ):
-                blocks[block_name][field_name] = tuple(field_value.values.tolist())
-            elif isinstance(field_value, xr.DataArray) and "nper" in field_value.dims:
-                has_spatial_dims = any(
-                    dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nodes"]
-                )
-                if has_spatial_dims:
-                    field_value = _hack_structured_grid_dims(
-                        field_value,
-                        structured_grid_dims=value.parent.data.dims,  # type: ignore
+            match field_value := data[field_name]:
+                case None:
+                    continue
+                case bool():
+                    if field_value:
+                        blocks[block_name][field_name] = field_value
+                case Path():
+                    field_spec = xatspec.attrs[field_name]
+                    field_meta = getattr(field_spec, "metadata", {})
+                    t = path_to_tuple(
+                        field_name, field_value, inout=field_meta.get("inout", "fileout")
                     )
-
-                    period_data[field_name] = {
-                        kper: field_value.isel(nper=kper)
-                        for kper in range(field_value.sizes["nper"])
-                    }
-                else:
-                    # TODO why not putting in block here but doing below? how does this even work
-                    if np.issubdtype(field_value.dtype, np.str_):
+                    # name may have changed e.g dropping '_file' suffix
+                    blocks[block_name][t[0]] = t
+                case datetime():
+                    blocks[block_name][field_name] = field_value.isoformat()
+                case t if (
+                    field_name == "auxiliary"
+                    and hasattr(field_value, "values")
+                    and field_value is not None
+                ):
+                    blocks[block_name][field_name] = tuple(field_value.values.tolist())
+                case xr.DataArray() if "nper" in field_value.dims:
+                    has_spatial_dims = any(
+                        dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nodes"]
+                    )
+                    if has_spatial_dims:
+                        field_value = _hack_structured_grid_dims(
+                            field_value,
+                            structured_grid_dims=value.parent.data.dims,  # type: ignore
+                        )
+                    if "period" in block_name:
+                        period_block_name = block_name
                         period_data[field_name] = {
-                            kper: field_value[kper] for kper in range(field_value.sizes["nper"])
+                            kper: field_value.isel(nper=kper)
+                            for kper in range(field_value.sizes["nper"])
                         }
-                    else:
-                        if block_name not in period_data:
-                            period_data[block_name] = {}
-                        period_data[block_name][field_name] = field_value  # type: ignore
-            else:
-                if field_value is not None:
-                    if isinstance(field_value, bool):
-                        if field_value:
-                            blocks[block_name][field_name] = field_value
                     else:
                         blocks[block_name][field_name] = field_value

-            if block_name in period_data and isinstance(period_data[block_name], dict):
-                dataset = xr.Dataset(period_data[block_name])
-                blocks[block_name] = {block_name: dataset}
-                del period_data[block_name]
+                case _:
+                    blocks[block_name][field_name] = field_value

+        # invert key order, (arr_name, kper) -> (kper, arr_name)
         for arr_name, periods in period_data.items():
             for kper, arr in periods.items():
                 if kper not in period_blocks:
                     period_blocks[kper] = {}
                 period_blocks[kper][arr_name] = arr

+        # setup indexed period blocks, combine arrays into datasets
         for kper, block in period_blocks.items():
-            dataset = xr.Dataset(block)
-            blocks[f"{block_name} {kper + 1}"] = {block_name: dataset}
+            assert isinstance(period_block_name, str)
+            blocks[f"{period_block_name} {kper + 1}"] = {
+                period_block_name: xr.Dataset(block, coords=block[arr_name].coords)
+            }

-    # total temporary hack! manually set solutiongroup 1. still need to support multiple..
+    # total temporary hack! manually set solutiongroup 1.
+    # TODO still need to support multiple..
     if "solutiongroup" in blocks:
         sg = blocks["solutiongroup"]
         blocks["solutiongroup 1"] = sg
         del blocks["solutiongroup"]

-    return {name: block for name, block in blocks.items() if name != "period"}
+    return {name: block for name, block in blocks.items() if name != period_block_name}


 def _make_converter() -> Converter:
flopy4/mf6/gwf/chd.py

Lines changed: 1 addition & 1 deletion
@@ -6,11 +6,11 @@
 from numpy.typing import NDArray
 from xattree import xattree

-from flopy4.mf6.component import update_maxbound
 from flopy4.mf6.constants import LENBOUNDNAME
 from flopy4.mf6.converter import dict_to_array
 from flopy4.mf6.package import Package
 from flopy4.mf6.spec import array, field, path
+from flopy4.mf6.utils.grid_utils import update_maxbound
 from flopy4.utils import to_path
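Both CHD and DRN now pull update_maxbound from flopy4.mf6.utils.grid_utils instead of flopy4.mf6.component. The counting idea behind it, per the docstring removed from component.py, is roughly as sketched below (the fill value and arrays here are hypothetical stand-ins):

```python
# Illustration of the counting behind update_maxbound: maxbound is the
# largest number of non-fill entries across a package's period arrays.
# FILL_DNODATA's value and the arrays below are stand-ins for this sketch.
import numpy as np

FILL_DNODATA = 1e30  # stand-in fill value

head = np.full((2, 4), FILL_DNODATA)
head[0, :2] = [10.0, 9.5]                      # 2 active bounds (numeric array)
boundname = np.array(["well1", "", "", ""])    # 1 named bound (string array)

counts = [
    np.count_nonzero(head != FILL_DNODATA),    # numeric: entries != fill
    np.count_nonzero(boundname != ""),         # strings: non-empty entries
]
maxbound = max(counts)
print(maxbound)  # 2
```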

flopy4/mf6/gwf/drn.py

Lines changed: 1 addition & 1 deletion
@@ -6,11 +6,11 @@
 from numpy.typing import NDArray
 from xattree import xattree

-from flopy4.mf6.component import update_maxbound
 from flopy4.mf6.constants import LENBOUNDNAME
 from flopy4.mf6.converter import dict_to_array
 from flopy4.mf6.package import Package
 from flopy4.mf6.spec import array, field, path
+from flopy4.mf6.utils.grid_utils import update_maxbound
 from flopy4.utils import to_path
