Skip to content

Commit 3516b69

Browse files
authored
fix(write): next step non-numeric period output (#253)
1 parent ddc1cbe commit 3516b69

File tree

2 files changed

+131
-14
lines changed

2 files changed

+131
-14
lines changed

flopy4/mf6/converter/unstructure.py

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
from typing import Any
55

6+
import numpy as np
67
import xarray as xr
78
import xattree
89
from modflow_devtools.dfn.schema.block import block_sort_key
@@ -89,6 +90,51 @@ def _hack_structured_grid_dims(
8990
)
9091

9192

93+
def _hack_period_non_numeric(name: str, value: xr.DataArray) -> dict[str, dict[int, Any]]:
94+
from flopy4.mf6.gwf import Oc
95+
96+
def oc_setting_data(rec):
97+
dat = {}
98+
if rec.steps.first:
99+
dat = {kper: "first" for kper in range(value.sizes["nper"])}
100+
elif rec.steps.last:
101+
dat = {kper: "last" for kper in range(value.sizes["nper"])}
102+
elif rec.steps.steps:
103+
steps = " ".join(str(x - 1) for x in rec.steps.steps)
104+
dat = {kper: f"steps {steps}" for kper in range(value.sizes["nper"])}
105+
elif rec.steps.all:
106+
# check last as this defaults to True
107+
dat = {kper: "all" for kper in range(value.sizes["nper"])}
108+
109+
return dat
110+
111+
data = {}
112+
match value.dtype:
113+
case np.bool:
114+
dat = {kper: "" for kper in range(value.sizes["nper"]) if value.values[kper]} # type: ignore
115+
data[name] = dat
116+
case np.dtypes.StringDType():
117+
fname = name.replace("_", " ")
118+
dat = {kper: value.values[kper] for kper in range(value.sizes["nper"])}
119+
data[fname] = dat
120+
case object():
121+
if isinstance(value.values[0], Oc.PrintSaveSetting):
122+
if hasattr(value.values[0], "printrecord") and isinstance(
123+
value.values[0].printrecord, list
124+
):
125+
for rec in value.values[0].printrecord:
126+
key = f"{rec.print} {rec.rtype}"
127+
data[key] = oc_setting_data(rec)
128+
if hasattr(value.values[0], "saverecord") and isinstance(
129+
value.values[0].saverecord, list
130+
):
131+
for rec in value.values[0].saverecord: # type: ignore
132+
key = f"{rec.save} {rec.rtype}" # type: ignore
133+
data[key] = oc_setting_data(rec)
134+
135+
return data
136+
137+
92138
def unstructure_component(value: Component) -> dict[str, Any]:
93139
blockspec = dict(sorted(value.dfn.blocks.items(), key=block_sort_key)) # type: ignore
94140
blocks: dict[str, dict[str, Any]] = {}
@@ -157,10 +203,15 @@ def unstructure_component(value: Component) -> dict[str, Any]:
157203
structured_grid_dims=value.parent.data.dims, # type: ignore
158204
)
159205
if block_name == "period":
160-
period_data[field_name] = {
161-
kper: field_value.isel(nper=kper)
162-
for kper in range(field_value.sizes["nper"])
163-
}
206+
if not np.issubdtype(field_value.dtype, np.number):
207+
dat = _hack_period_non_numeric(field_name, field_value)
208+
for n, v in dat.items():
209+
period_data[n] = v
210+
else:
211+
period_data[field_name] = {
212+
kper: field_value.isel(nper=kper) # type: ignore
213+
for kper in range(field_value.sizes["nper"])
214+
}
164215
else:
165216
blocks[block_name][field_name] = field_value
166217

@@ -174,11 +225,20 @@ def unstructure_component(value: Component) -> dict[str, Any]:
174225
period_blocks[kper] = {}
175226
period_blocks[kper][arr_name] = arr
176227

228+
# sort kper order
229+
period_blocks = dict(sorted(period_blocks.items()))
230+
177231
# setup indexed period blocks, combine arrays into datasets
178232
for kper, block in period_blocks.items():
179-
blocks[f"period {kper + 1}"] = {
180-
"period": xr.Dataset(block, coords=block[arr_name].coords)
181-
}
233+
blocks[f"period {kper + 1}"] = {}
234+
for arr_name, val in block.items():
235+
match block[arr_name]:
236+
case str():
237+
blocks[f"period {kper + 1}"][arr_name] = val
238+
case xr.DataArray():
239+
blocks[f"period {kper + 1}"]["period"] = xr.Dataset(
240+
block, coords=block[arr_name].coords
241+
)
182242

183243
# combine "perioddata" block arrays (tdis, ats) into datasets
184244
# so they render as lists. temp hack TODO do this generically

test/test_mf6_codec.py

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
from pprint import pprint
44

5-
import pytest
6-
75
from flopy4.mf6.codec import dumps, loads
86
from flopy4.mf6.converter import COMPONENT_CONVERTER
97

@@ -57,7 +55,31 @@ def test_dumps_ic():
5755
pprint(loaded)
5856

5957

60-
@pytest.mark.xfail(reason="nested type unstructuring not yet supported")
58+
def test_dumps_sto():
59+
from flopy4.mf6.gwf import Dis, Gwf, Sto
60+
61+
dis = Dis()
62+
gwf = Gwf(dis=dis)
63+
sto = Sto(
64+
dims={"nper": 3},
65+
parent=gwf,
66+
steady_state=[False, True, False],
67+
transient=[True, False, True],
68+
)
69+
70+
dumped = dumps(COMPONENT_CONVERTER.unstructure(sto))
71+
print("STO dump:")
72+
print(dumped)
73+
assert "BEGIN PERIOD 1\n TRANSIENT" in dumped
74+
assert "BEGIN PERIOD 2\n STEADY_STATE" in dumped
75+
assert "BEGIN PERIOD 3\n TRANSIENT" in dumped
76+
assert dumped
77+
78+
loaded = loads(dumped)
79+
print("STO load:")
80+
pprint(loaded)
81+
82+
6183
def test_dumps_oc():
6284
from flopy4.mf6.gwf import Oc
6385

@@ -80,10 +102,45 @@ def test_dumps_oc():
80102
dumped = dumps(COMPONENT_CONVERTER.unstructure(oc))
81103
print("OC dump:")
82104
print(dumped)
83-
assert "save head all" in dumped
84-
assert "save budget all" in dumped
85-
assert "print head all" in dumped
86-
assert "print budget all" in dumped
105+
assert "SAVE HEAD all" in dumped
106+
assert "SAVE BUDGET all" in dumped
107+
assert "PRINT HEAD all" in dumped
108+
assert "PRINT BUDGET all" in dumped
109+
assert dumped
110+
111+
loaded = loads(dumped)
112+
print("OC load:")
113+
pprint(loaded)
114+
115+
116+
def test_dumps_oc2():
117+
from flopy4.mf6.gwf import Oc
118+
119+
oc = Oc(
120+
dims={"nper": 1},
121+
budget_file="test.bud",
122+
head_file="test.hds",
123+
perioddata={
124+
0: Oc.PrintSaveSetting(
125+
printrecord=[
126+
Oc.PrintRecord("head", Oc.Steps(first=True)),
127+
Oc.PrintRecord("budget", Oc.Steps(steps=(2, 3, 5))),
128+
],
129+
saverecord=[
130+
Oc.SaveRecord("head", Oc.Steps(last=True)),
131+
Oc.SaveRecord("budget", Oc.Steps(first=True)),
132+
],
133+
)
134+
},
135+
)
136+
137+
dumped = dumps(COMPONENT_CONVERTER.unstructure(oc))
138+
print("OC dump:")
139+
print(dumped)
140+
assert "SAVE HEAD last" in dumped
141+
assert "SAVE BUDGET first" in dumped
142+
assert "PRINT HEAD first" in dumped
143+
assert "PRINT BUDGET steps 1 2 4" in dumped
87144
assert dumped
88145

89146
loaded = loads(dumped)

0 commit comments

Comments
 (0)