77from contextlib import suppress
88from typing import Any , Callable , List , Optional , Set , Tuple , Type , Union
99
10+ from ._common import is_dataclass_like , is_subclass
1011from .actions import _ActionConfigLoad
11- from .optionals import (
12- attrs_support ,
13- get_doc_short_description ,
14- import_attrs ,
15- import_pydantic ,
16- pydantic_support ,
17- )
12+ from .optionals import get_doc_short_description , import_pydantic , pydantic_support
1813from .parameter_resolvers import (
1914 ParamData ,
2015 get_parameter_origins ,
2116 get_signature_parameters ,
2217)
2318from .typehints import ActionTypeHint , LazyInitBaseClass , is_optional
24- from .typing import is_final_class
25- from .util import LoggerProperty , get_import_path , is_subclass , iter_to_set_str
19+ from .util import LoggerProperty , get_import_path , iter_to_set_str
2620
2721__all__ = [
2822 'compose_dataclasses' ,
@@ -93,7 +87,7 @@ def add_class_arguments(
9387 if default :
9488 skip = skip or set ()
9589 prefix = nested_key + '.' if nested_key else ''
96- defaults = default .lazy_get_init_data (). as_dict ()
90+ defaults = default .lazy_get_init_args ()
9791 if defaults :
9892 defaults = {prefix + k : v for k , v in defaults .items () if k not in skip }
9993 self .set_defaults (** defaults ) # type: ignore
@@ -317,8 +311,7 @@ def _add_signature_parameter(
317311 elif not as_positional :
318312 kwargs ['required' ] = True
319313 is_subclass_typehint = False
320- is_final_class_typehint = is_final_class (annotation )
321- is_pure_dataclass_typehint = is_pure_dataclass (annotation )
314+ is_dataclass_like_typehint = is_dataclass_like (annotation )
322315 dest = (nested_key + '.' if nested_key else '' ) + name
323316 args = [dest if is_required and as_positional else '--' + dest ]
324317 if param .origin :
@@ -332,8 +325,7 @@ def _add_signature_parameter(
332325 )
333326 if annotation in {str , int , float , bool } or \
334327 is_subclass (annotation , (str , int , float )) or \
335- is_final_class_typehint or \
336- is_pure_dataclass_typehint :
328+ is_dataclass_like_typehint :
337329 kwargs ['type' ] = annotation
338330 elif annotation != inspect_empty :
339331 try :
@@ -360,7 +352,7 @@ def _add_signature_parameter(
360352 'sub_configs' : sub_configs ,
361353 'instantiate' : instantiate ,
362354 }
363- if is_final_class_typehint or is_pure_dataclass_typehint :
355+ if is_dataclass_like_typehint :
364356 kwargs .update (sub_add_kwargs )
365357 action = group .add_argument (* args , ** kwargs )
366358 action .sub_add_kwargs = sub_add_kwargs
@@ -401,8 +393,8 @@ def add_dataclass_arguments(
401393 ValueError: When not given a dataclass.
402394 ValueError: When default is not instance of or kwargs for theclass.
403395 """
404- if not is_pure_dataclass (theclass ):
405- raise ValueError (f'Expected "theclass" argument to be a pure dataclass, given { theclass } ' )
396+ if not is_dataclass_like (theclass ):
397+ raise ValueError (f'Expected "theclass" argument to be a dataclass-like , given { theclass } ' )
406398
407399 doc_group = get_doc_short_description (theclass , logger = self .logger )
408400 for key in ['help' , 'title' ]:
@@ -420,6 +412,7 @@ def add_dataclass_arguments(
420412 defaults = dataclass_to_dict (default )
421413
422414 added_args : List [str ] = []
415+ param_kwargs = {k : v for k , v in kwargs .items () if k == 'sub_configs' }
423416 for param in get_signature_parameters (theclass , None , logger = self .logger ):
424417 self ._add_signature_parameter (
425418 group ,
@@ -428,6 +421,7 @@ def add_dataclass_arguments(
428421 added_args ,
429422 fail_untyped = fail_untyped ,
430423 default = defaults .get (param .name , inspect_empty ),
424+ ** param_kwargs ,
431425 )
432426
433427 return added_args
@@ -467,8 +461,8 @@ def add_subclass_arguments(
467461 Raises:
468462 ValueError: When given an invalid base class.
469463 """
470- if is_final_class (baseclass ):
471- raise ValueError ("Not allowed for classes that are final ." )
464+ if is_dataclass_like (baseclass ):
465+ raise ValueError ("Not allowed for dataclass-like classes ." )
472466 if type (baseclass ) is not tuple :
473467 baseclass = (baseclass ,) # type: ignore
474468 if not all (inspect .isclass (c ) for c in baseclass ):
@@ -550,32 +544,18 @@ def is_factory_class(value):
550544 return value .__class__ == dataclasses ._HAS_DEFAULT_FACTORY_CLASS
551545
552546
553- def is_pure_dataclass (value ):
554- if not inspect .isclass (value ):
555- return False
556- classes = [c for c in inspect .getmro (value ) if c != object ]
557- all_dataclasses = all (dataclasses .is_dataclass (c ) for c in classes )
558- if not all_dataclasses and pydantic_support :
559- pydantic = import_pydantic ('is_pure_dataclass' )
560- classes = [c for c in classes if c != pydantic .utils .Representation ]
561- all_dataclasses = all (is_subclass (c , pydantic .BaseModel ) for c in classes )
562- if not all_dataclasses and attrs_support :
563- attrs = import_attrs ('is_pure_dataclass' )
564- if attrs .has (value ):
565- return True
566- return all_dataclasses
567-
568-
569- def dataclass_to_dict (value ):
547+ def dataclass_to_dict (value ) -> dict :
570548 if pydantic_support :
571549 pydantic = import_pydantic ('dataclass_to_dict' )
572550 if isinstance (value , pydantic .BaseModel ):
573551 return value .dict ()
552+ if isinstance (value , LazyInitBaseClass ):
553+ return value .lazy_get_init_data ().as_dict ()
574554 return dataclasses .asdict (value )
575555
576556
577557def compose_dataclasses (* args ):
578- """Returns a pure dataclass inheriting all given dataclasses and properly handling __post_init__."""
558+ """Returns a dataclass inheriting all given dataclasses and properly handling __post_init__."""
579559
580560 @dataclasses .dataclass
581561 class ComposedDataclass (* args ):
0 commit comments