@@ -977,17 +977,17 @@ def adapt_typehints(
977977
978978 # Subclass
979979 elif not hasattr (typehint , "__origin__" ) and inspect .isclass (typehint ):
980- if isinstance (val , typehint ):
980+ if is_instance_or_supports_protocol (val , typehint ):
981981 if serialize :
982982 val = serialize_class_instance (val )
983983 return val
984984 if serialize and isinstance (val , str ):
985985 return val
986986
987987 val_input = val
988- if prev_val is None and not inspect .isabstract (typehint ):
988+ if prev_val is None and not inspect .isabstract (typehint ) and not is_protocol ( typehint ) :
989989 with suppress (ValueError ):
990- prev_val = Namespace (class_path = get_import_path (typehint ))
990+ prev_val = Namespace (class_path = get_import_path (typehint )) # implicit class_path
991991 val = subclass_spec_as_namespace (val , prev_val )
992992 if not is_subclass_spec (val ):
993993 raise_unexpected_value (
@@ -1000,20 +1000,20 @@ def adapt_typehints(
10001000
10011001 try :
10021002 val_class = import_object (resolve_class_path_by_name (typehint , val ["class_path" ]))
1003- if isinstance (val_class , typehint ):
1004- return val_class
1003+ if is_instance_or_supports_protocol (val_class , typehint ):
1004+ return val_class # importable instance
10051005 not_subclass = False
1006- if not is_subclass (val_class , typehint ):
1006+ if not is_subclass_or_implements_protocol (val_class , typehint ):
10071007 not_subclass = True
10081008 if not inspect .isclass (val_class ) and callable (val_class ):
10091009 from ._postponed_annotations import get_return_type
10101010
10111011 return_type = get_return_type (val_class , logger )
1012- if is_subclass (return_type , typehint ):
1012+ if is_subclass_or_implements_protocol (return_type , typehint ):
10131013 not_subclass = False
10141014 if not_subclass :
10151015 raise_unexpected_value (
1016- f' Import path { val [" class_path" ]} does not correspond to a subclass of { typehint } '
1016+ f" Import path { val [' class_path' ]} does not correspond to a subclass of { typehint . __name__ } "
10171017 )
10181018 val ["class_path" ] = get_import_path (val_class )
10191019 val = adapt_class_type (val , serialize , instantiate_classes , sub_add_kwargs , prev_val = prev_val )
@@ -1029,6 +1029,46 @@ def adapt_typehints(
10291029 return val
10301030
10311031
1032+ def implements_protocol (value , protocol ) -> bool :
1033+ from jsonargparse ._parameter_resolvers import get_signature_parameters
1034+ from jsonargparse ._postponed_annotations import get_return_type
1035+
1036+ if not inspect .isclass (value ):
1037+ return False
1038+ members = 0
1039+ for name , _ in inspect .getmembers (protocol , predicate = inspect .isfunction ):
1040+ if name .startswith ("_" ):
1041+ continue
1042+ if not hasattr (value , name ):
1043+ return False
1044+ members += 1
1045+ proto_params = get_signature_parameters (protocol , name )
1046+ value_params = get_signature_parameters (value , name )
1047+ if [(p .name , p .annotation ) for p in proto_params ] != [(p .name , p .annotation ) for p in value_params ]:
1048+ return False
1049+ proto_return = get_return_type (inspect .getattr_static (protocol , name ))
1050+ value_return = get_return_type (inspect .getattr_static (value , name ))
1051+ if proto_return != value_return :
1052+ return False
1053+ return True if members else False
1054+
1055+
1056+ def is_protocol (class_type ) -> bool :
1057+ return getattr (class_type , "_is_protocol" , False )
1058+
1059+
1060+ def is_subclass_or_implements_protocol (value , class_type ) -> bool :
1061+ if is_protocol (class_type ):
1062+ return implements_protocol (value , class_type )
1063+ return is_subclass (value , class_type )
1064+
1065+
1066+ def is_instance_or_supports_protocol (value , class_type ):
1067+ if is_protocol (class_type ):
1068+ return is_subclass_or_implements_protocol (value .__class__ , class_type )
1069+ return isinstance (value , class_type )
1070+
1071+
10321072def is_subclass_spec (val ):
10331073 is_class = isinstance (val , (dict , Namespace )) and "class_path" in val
10341074 if is_class :
0 commit comments