Skip to content

Commit 37f3b2f

Browse files
fix a bug when loading a config file
The HDF5 file stores config values like log_level_global as a numeric (e.g., numpy.int64). When loading, that integer is passed into VariPEPS_Config. In setattr, the field type for log_level_global is the Enum LogLevel, so integers must be coerced to LogLevel. Your version’s coercion path doesn’t catch your value, so it falls through and raises: Type mismatch for option 'log_level_global', got '<class numpy.int64>', expected '<enum 'LogLevel'>'. Why it falls through The loader passes a numpy integer (or a 0-d/1-d array) instead of a Python int. The Enum branch in setattr is too strict about the numeric checks, so it doesn’t convert that value into LogLevel.
1 parent fedab0d commit 37f3b2f

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

varipeps/config.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,12 +367,33 @@ def __setattr__(self, name: str, value: Any) -> NoReturn:
367367
elif (
368368
field.type is bool
369369
and hasattr(value, "dtype")
370-
and np.isdtype(value.dtype, np.bool)
370+
and np.issubdtype(value.dtype, np.bool_)
371371
and value.size == 1
372372
):
373373
if value.ndim > 0:
374374
value = value.reshape(-1)[0]
375375
value = bool(value)
376+
elif isinstance(field.type, type) and issubclass(field.type, Enum):
377+
# Accept ints/np.int64 or enum names for Enum fields
378+
if isinstance(value, field.type):
379+
pass
380+
elif isinstance(value, (int,)) or (
381+
hasattr(value, "dtype")
382+
and np.issubdtype(value.dtype, np.integer)
383+
and value.size == 1
384+
):
385+
if hasattr(value, "ndim") and value.ndim > 0:
386+
value = value.reshape(-1)[0]
387+
value = field.type(int(value))
388+
elif isinstance(value, str):
389+
try:
390+
value = field.type[value]
391+
except KeyError:
392+
value = field.type(int(value))
393+
else:
394+
raise TypeError(
395+
f"Type mismatch for option '{name}', got '{type(value)}', expected '{field.type}'."
396+
)
376397
else:
377398
raise TypeError(
378399
f"Type mismatch for option '{name}', got '{type(value)}', expected '{field.type}'."

0 commit comments

Comments
 (0)