Commit f992ffb

Fix issue with import of subgroups from helpers
Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>
1 parent 1bd1535 commit f992ffb

File tree

4 files changed, +69 −39 lines


simple_parsing/helpers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@
 from .hparams import HyperParameters
 from .partial import Partial, config_for
 from .serialization import FrozenSerializable, Serializable, SimpleJsonEncoder, encode
+from .subgroups import subgroups

 try:
     from .serialization import YamlSerializable
@@ -45,4 +46,5 @@
     "subparsers",
     "flag",
     "flags",
+    "subgroups",
 ]
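
For context, below is a minimal usage sketch of the `subgroups` helper that this commit re-exports from `simple_parsing.helpers` (previously it was only importable from `simple_parsing.helpers.subgroups`). The `OptimizerConfig`/`AdamConfig`/`SGDConfig`/`TrainConfig` classes are hypothetical, and details such as the `default_factory` keyword and the auto-generated flag names may vary between simple_parsing versions:

from dataclasses import dataclass

from simple_parsing import ArgumentParser
from simple_parsing.helpers import subgroups  # the import touched by this commit


@dataclass
class OptimizerConfig:
    lr: float = 3e-4


@dataclass
class AdamConfig(OptimizerConfig):
    lr: float = 3e-4


@dataclass
class SGDConfig(OptimizerConfig):
    lr: float = 1e-2
    momentum: float = 0.9


@dataclass
class TrainConfig:
    # Pick one of the named optimizer configs with `--optimizer {adam,sgd}`.
    optimizer: OptimizerConfig = subgroups(
        {"adam": AdamConfig, "sgd": SGDConfig}, default_factory=AdamConfig
    )


parser = ArgumentParser()
parser.add_arguments(TrainConfig, dest="config")
args = parser.parse_args(["--optimizer", "sgd"])
print(args.config.optimizer)  # SGDConfig(lr=0.01, momentum=0.9)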

simple_parsing/parsing.py

Lines changed: 38 additions & 14 deletions
@@ -1,4 +1,5 @@
 """Simple, Elegant Argument parsing.
+
 @author: Fabrice Normandin
 """
 from __future__ import annotations
@@ -98,7 +99,6 @@ class ArgumentParser(argparse.ArgumentParser):
 
     - add_config_path_arg : bool, optional
       When set to `True`, adds a `--config_path` argument, of type Path, which is used to parse
-
     """
 
     def __init__(
@@ -531,8 +531,8 @@ def _preprocessing(self, args: Sequence[str] = (), namespace: Namespace | None =
        # Create one argument group per dataclass
        for wrapped_dataclass in wrapped_dataclasses:
            logger.debug(
-                f"Parser {id(self)} is Adding arguments for dataclass: {wrapped_dataclass.dataclass} "
-                f"at destinations {wrapped_dataclass.destinations}"
+                f"Parser {id(self)} is Adding arguments for dataclass: "
+                f"{wrapped_dataclass.dataclass} at destinations {wrapped_dataclass.destinations}"
            )
            wrapped_dataclass.add_arguments(parser=self)
 
@@ -636,7 +636,8 @@ def _resolve_subgroups(
            # Do rounds of parsing with just the subgroup arguments, until all the subgroups
            # are resolved to a dataclass type.
            logger.debug(
-                f"Starting subgroup parsing round {current_nesting_level}: {list(unresolved_subgroups.keys())}"
+                f"Starting subgroup parsing round {current_nesting_level}: "
+                f"{list(unresolved_subgroups.keys())}"
            )
            # Add all the unresolved subgroups arguments.
            for dest, subgroup_field in unresolved_subgroups.items():
@@ -761,6 +762,7 @@ def _resolve_subgroups(
 
    def _remove_subgroups_from_namespace(self, parsed_args: argparse.Namespace) -> None:
        """Removes the subgroup choice results from the namespace.
+
        Modifies the namespace in-place.
        """
        # find all subgroup fields
@@ -876,8 +878,9 @@ def _instantiate_dataclasses(
            existing = getattr(parsed_args, destination)
            if dc_wrapper.dest in self._defaults:
                logger.debug(
-                    f"Overwriting defaults in the namespace at destination '{destination}' "
-                    f"on the Namespace ({existing}) to a value of {value_for_dataclass_field}"
+                    f"Overwriting defaults in the namespace at destination "
+                    f"'{destination}' on the Namespace ({existing}) to a value of "
+                    f"{value_for_dataclass_field}"
                )
                setattr(parsed_args, destination, value_for_dataclass_field)
            else:
@@ -937,19 +940,37 @@ def _fill_constructor_arguments_with_fields(
        parsed_arg_values = vars(parsed_args)
        deleted_values: dict[str, Any] = {}
 
-        for wrapper in wrappers:
-            for field in wrapper.fields:
-                if argparse.SUPPRESS in wrapper.defaults and field.dest not in parsed_args:
+        # BUG: Need to check that the non-init fields DO have a FieldWrapper here, and that there
+        # isn't a value for that field in the constructor arguments!
+
+        for dc_wrapper in wrappers:
+            for non_init_field in [
+                f for f in dataclasses.fields(dc_wrapper.dataclass) if not f.init
+            ]:
+                field_dest = dc_wrapper.dest + "." + non_init_field.name
+                # We fetch the constructor arguments for the containing dataclass and check that it
+                # doesn't have a value set.
+                dc_constructor_args = constructor_arguments
+                for dest_part in dc_wrapper.dest.split("."):
+                    dc_constructor_args = dc_constructor_args[dest_part]
+                if non_init_field.name in dc_constructor_args:
+                    logger.warning(
+                        f"Field {field_dest} is a field with init=False, but a value is "
+                        f"present in the serialized config. This value will be ignored."
+                    )
+                    dc_constructor_args.pop(non_init_field.name)
+
+            for field in dc_wrapper.fields:
+                if argparse.SUPPRESS in dc_wrapper.defaults and field.dest not in parsed_args:
                    continue
 
                if field.is_subgroup:
                    # Skip the subgroup fields, since we added a child DataclassWrapper for them.
                    logger.debug(f"Not calling the subgroup FieldWrapper for dest {field.dest}")
                    continue
 
-                if not field.field.init:
-                    # The field isn't an argument of the dataclass constructor.
-                    continue
+                # We only create FieldWrappers for fields that have init=True.
+                assert field.field.init
 
                # NOTE: If the field is reused (when using the ConflictResolution.ALWAYS_MERGE
                # strategy), then we store the multiple values in the `dest` of the first field.
@@ -979,11 +1000,12 @@ def _fill_constructor_arguments_with_fields(
    @property
    def confilct_resolver_max_attempts(self) -> int:
        return self._conflict_resolver.max_attempts
-
+
    @confilct_resolver_max_attempts.setter
    def confilct_resolver_max_attempts(self, value: int):
        self._conflict_resolver.max_attempts = value
 
+
 # TODO: Change the order of arguments to put `args` as the second argument.
 def parse(
    config_class: type[DataclassT],
@@ -1068,7 +1090,9 @@ def parse_known_args(
        add_config_path_arg=add_config_path_arg,
    )
    parser.add_arguments(config_class, dest=dest, default=default)
-    parsed_args, unknown_args = parser.parse_known_args(args, attempt_to_reorder=attempt_to_reorder)
+    parsed_args, unknown_args = parser.parse_known_args(
+        args, attempt_to_reorder=attempt_to_reorder
+    )
    config: Dataclass = getattr(parsed_args, dest)
    return config, unknown_args
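
For context, a minimal sketch (with a hypothetical `Job` dataclass) of the case that the new non-init-field handling above is about: a field declared with `init=False` is not a constructor argument, so no command-line option is generated for it. The sketch only shows this basic behaviour; the new code additionally warns and drops a value for such a field if one does show up in the constructor arguments (for example when loading a serialized config):

from dataclasses import dataclass, field

from simple_parsing import ArgumentParser


@dataclass
class Job:
    name: str = "debug"
    # Computed in __post_init__; init=False, so no `--output_dir` option is generated.
    output_dir: str = field(init=False, default="")

    def __post_init__(self) -> None:
        self.output_dir = f"logs/{self.name}"


parser = ArgumentParser()
parser.add_arguments(Job, dest="job")
job = parser.parse_args(["--name", "exp1"]).job
print(job.output_dir)  # logs/exp1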

simple_parsing/utils.py

Lines changed: 25 additions & 23 deletions
@@ -108,9 +108,8 @@ def is_subparser_field(field: Field) -> bool:
 
 
 class InconsistentArgumentError(RuntimeError):
-    """
-    Error raised when the number of arguments provided is inconsistent when parsing multiple instances from command line.
-    """
+    """Error raised when the number of arguments provided is inconsistent when parsing multiple
+    instances from command line."""
 
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
@@ -126,9 +125,8 @@ def camel_case(name):
 
 
 def str2bool(raw_value: str | bool) -> bool:
-    """
-    Taken from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
-    """
+    """Taken from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-
+    argparse."""
    if isinstance(raw_value, bool):
        return raw_value
    v = raw_value.strip().lower()
@@ -208,12 +206,13 @@ def get_item_type(container_type: type[Container[T]]) -> T:
 def get_argparse_type_for_container(
    container_type: type[Container[T]],
 ) -> type[T] | Callable[[str], T]:
-    """Gets the argparse 'type' option to be used for a given container type.
-    When an annotation is present, the 'type' option of argparse is set to that type.
-    if not, then the default value of 'str' is returned.
+    """Gets the argparse 'type' option to be used for a given container type. When an annotation is
+    present, the 'type' option of argparse is set to that type. if not, then the default value of
+    'str' is returned.
 
    Arguments:
-        container_type {Type} -- A container type (ideally a typing.Type such as List, Tuple, along with an item annotation: List[str], Tuple[int, int], etc.)
+        container_type -- A container type (ideally a typing.Type such as List, Tuple, along
+            with an item annotation: List[str], Tuple[int, int], etc.)
 
    Returns:
        typing.Type -- the type that should be used in argparse 'type' argument option.
@@ -413,7 +412,9 @@ def is_dataclass_type_or_typevar(t: type) -> bool:
    Returns:
        bool: Whether its a dataclass type.
    """
-    return dataclasses.is_dataclass(t) or (is_typevar(t) and dataclasses.is_dataclass(get_bound(t)))
+    return dataclasses.is_dataclass(t) or (
+        is_typevar(t) and dataclasses.is_dataclass(get_bound(t))
+    )
 
 
 def is_enum(t: type) -> bool:
@@ -431,7 +432,7 @@ def is_tuple_or_list(t: type) -> bool:
 
 
 def is_union(t: type) -> bool:
-    """Returns whether or not the given Type annotation is a variant (or subclass) of typing.Union
+    """Returns whether or not the given Type annotation is a variant (or subclass) of typing.Union.
 
    Args:
        t (Type): some type annotation
@@ -453,8 +454,7 @@ def is_union(t: type) -> bool:
 
 
 def is_homogeneous_tuple_type(t: type[tuple]) -> bool:
-    """Returns whether the given Tuple type is homogeneous: if all items types are the
-    same.
+    """Returns whether the given Tuple type is homogeneous: if all items types are the same.
 
    This also includes Tuple[<some_type>, ...]
 
@@ -651,19 +651,22 @@ def _parse(value: str) -> list[Any]:
        # if it doesn't work, fall back to the parse_fn.
        values = _fallback_parse(value)
 
-        # we do the default 'argparse' action, which is to add the values to a bigger list of values.
+        # we do the default 'argparse' action, which is to add the values to a bigger list of
+        # values.
        # result.extend(values)
        logger.debug(f"returning values: {values}")
        return values
 
    def _parse_literal(value: str) -> list[Any] | Any:
        """try to parse the string to a python expression directly.
+
        (useful for nested lists or tuples.)
        """
        literal = ast.literal_eval(value)
        logger.debug(f"Parsed literal: {literal}")
        if not isinstance(literal, (list, tuple)):
-            # we were passed a single-element container, like "--some_list 1", which should give [1].
+            # we were passed a single-element container, like "--some_list 1", which should give
+            # [1].
            # We therefore return the literal itself, and argparse will append it.
            return T(literal)
        else:
@@ -723,8 +726,8 @@ def get_nesting_level(possibly_nested_list):
 
 
 def default_value(field: dataclasses.Field) -> T | _MISSING_TYPE:
-    """Returns the default value of a field in a dataclass, if available.
-    When not available, returns `dataclasses.MISSING`.
+    """Returns the default value of a field in a dataclass, if available. When not available,
+    returns `dataclasses.MISSING`.
 
    Args:
        field (dataclasses.Field): The dataclasses.Field to get the default value of.
@@ -781,7 +784,6 @@ def keep_keys(d: dict, keys_to_keep: Iterable[str]) -> tuple[dict, dict]:
        Tuple[Dict, Dict]
            The same dictionary (with all the unwanted keys removed) as well as a
            new dict containing only the removed item.
-
    """
    d_keys = set(d.keys())  # save a copy since we will modify the dict.
    removed = {}
@@ -792,7 +794,7 @@ def keep_keys(d: dict, keys_to_keep: Iterable[str]) -> tuple[dict, dict]:
 
 
 def compute_identity(size: int = 16, **sample) -> str:
-    """Compute a unique hash out of a dictionary
+    """Compute a unique hash out of a dictionary.
 
    Parameters
    ----------
@@ -801,7 +803,6 @@ def compute_identity(size: int = 16, **sample) -> str:
 
    **sample:
        Dictionary to compute the hash from
-
    """
    sample_hash = hashlib.sha256()
 
@@ -840,7 +841,7 @@ def zip_dicts(*dicts: dict[K, V]) -> Iterable[tuple[K, tuple[V | None, ...]]]:
 
 
 def dict_union(*dicts: dict[K, V], recurse: bool = True, dict_factory=dict) -> dict[K, V]:
-    """Simple dict union until we use python 3.9
+    """Simple dict union until we use python 3.9.
 
    If `recurse` is True, also does the union of nested dictionaries.
    NOTE: The returned dictionary has keys sorted alphabetically.
@@ -924,7 +925,8 @@ def unflatten(flattened: Mapping[tuple[K, ...], V]) -> PossiblyNestedDict[K, V]:
 
 
 def flatten_join(nested: PossiblyNestedMapping[str, V], sep: str = ".") -> dict[str, V]:
-    """Flatten a dictionary of dictionaries. Joins different nesting levels with `sep` as separator.
+    """Flatten a dictionary of dictionaries. Joins different nesting levels with `sep` as
+    separator.
 
    >>> flatten_join({'a': {'b': 2, 'c': 3}, 'c': {'d': 3, 'e': 4}})
    {'a.b': 2, 'a.c': 3, 'c.d': 3, 'c.e': 4}

test/test_examples.py

Lines changed: 4 additions & 2 deletions
@@ -1,6 +1,6 @@
 """A test to make sure that all the example files work without crashing.
-(Could be seen as a kind of integration test.)
 
+(Could be seen as a kind of integration test.)
 """
 from __future__ import annotations
 
@@ -134,7 +134,9 @@ def test_running_example_outputs_expected_without_arg(
    set_prog_name: Callable[[str, str | None], None],
    assert_equals_stdout: Callable[[str, str], None],
 ):
-    return test_running_example_outputs_expected(file_path, "", set_prog_name, assert_equals_stdout)
+    return test_running_example_outputs_expected(
+        file_path, "", set_prog_name, assert_equals_stdout
+    )
 
 
 @contextmanager
