1616 is_dataclass_like ,
1717 is_subclass ,
1818)
19+ from ._namespace import Namespace
1920from ._optionals import get_doc_short_description , is_pydantic_model , pydantic_support
2021from ._parameter_resolvers import (
2122 ParamData ,
2930 get_subclasses_from_type ,
3031 is_optional ,
3132)
32- from ._util import get_private_kwargs , iter_to_set_str
33+ from ._util import NoneType , get_private_kwargs , iter_to_set_str
3334from .typing import register_pydantic_type
3435
3536__all__ = [
@@ -51,7 +52,7 @@ def add_class_arguments(
5152 nested_key : Optional [str ] = None ,
5253 as_group : bool = True ,
5354 as_positional : bool = False ,
54- default : Optional [LazyInitBaseClass ] = None ,
55+ default : Optional [Union [ dict , Namespace , LazyInitBaseClass ] ] = None ,
5556 skip : Optional [Set [Union [str , int ]]] = None ,
5657 instantiate : bool = True ,
5758 fail_untyped : bool = True ,
@@ -67,7 +68,7 @@ def add_class_arguments(
6768 nested_key: Key for nested namespace.
6869 as_group: Whether arguments should be added to a new argument group.
6970 as_positional: Whether to add required parameters as positional arguments.
70- default: Default value used to override parameter defaults. Must be lazy_instance.
71+ default: Default value used to override parameter defaults.
7172 skip: Names of parameters or number of positionals that should be skipped.
7273 instantiate: Whether the class group should be instantiated by :code:`instantiate_classes`.
7374 fail_untyped: Whether to raise exception if a required parameter does not have a type.
@@ -81,9 +82,14 @@ def add_class_arguments(
8182 ValueError: When there are required parameters without at least one valid type.
8283 """
8384 if not inspect .isclass (get_generic_origin (get_unaliased_type (theclass ))):
84- raise ValueError (f'Expected "theclass" parameter to be a class type, got: { theclass } .' )
85- if default and not (isinstance (default , LazyInitBaseClass ) and isinstance (default , theclass )):
86- raise ValueError (f'Expected "default" parameter to be a lazy instance of the class, got: { default } .' )
85+ raise ValueError (f"Expected 'theclass' parameter to be a class type, got: { theclass } " )
86+ if not (
87+ isinstance (default , (NoneType , dict , Namespace ))
88+ or (isinstance (default , LazyInitBaseClass ) and isinstance (default , theclass ))
89+ ):
90+ raise ValueError (
91+ f"Expected 'default' parameter to be a dict, Namespace or lazy instance of the class, got: { default } "
92+ )
8793 linked_targets = get_private_kwargs (kwargs , linked_targets = None )
8894
8995 added_args = self ._add_signature_arguments (
@@ -102,9 +108,13 @@ def add_class_arguments(
102108 if default :
103109 skip = skip or set ()
104110 prefix = nested_key + "." if nested_key else ""
105- defaults = default .lazy_get_init_args ()
111+ defaults = default
112+ if isinstance (default , LazyInitBaseClass ):
113+ defaults = default .lazy_get_init_args ().as_dict ()
114+ elif isinstance (default , Namespace ):
115+ defaults = default .as_dict ()
106116 if defaults :
107- defaults = {prefix + k : v for k , v in defaults .__dict__ . items () if k not in skip }
117+ defaults = {prefix + k : v for k , v in defaults .items () if k not in skip }
108118 self .set_defaults (** defaults ) # type: ignore[attr-defined]
109119
110120 return added_args
0 commit comments