@@ -45,23 +45,20 @@ def is_valid_type(obj: Any) -> bool:
4545 return False
4646
4747
48- def is_valid_callable (obj : Any ) -> bool :
48+ def is_sqlstr_callable (obj : Any ) -> bool :
4949 """Check if the object is a valid callable for a parameter type."""
5050 if not callable (obj ):
5151 return False
5252
5353 returns = utils .get_annotations (obj ).get ('return' , None )
5454
55- if inspect .isclass (returns ) and issubclass (returns , str ):
55+ if inspect .isclass (returns ) and issubclass (returns , SQLString ):
5656 return True
5757
58- raise TypeError (
59- f'callable { obj } must return a str, '
60- f'but got { returns } ' ,
61- )
58+ return False
6259
6360
64- def expand_types (args : Any ) -> Optional [Union [ List [str ], Type [ Any ] ]]:
61+ def expand_types (args : Any ) -> Optional [List [Any ]]:
6562 """Expand the types for the function arguments / return values."""
6663 if args is None :
6764 return None
@@ -70,28 +67,32 @@ def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
7067 if isinstance (args , str ):
7168 return [args ]
7269
73- # General way of accepting pydantic.BaseModel, NamedTuple, TypedDict
74- elif is_valid_type (args ):
75- return args
76-
7770 # List of SQL strings or callables
7871 elif isinstance (args , list ):
79- new_args = []
72+ new_args : List [ Any ] = []
8073 for arg in args :
8174 if isinstance (arg , str ):
8275 new_args .append (arg )
83- elif callable (arg ):
76+ elif is_sqlstr_callable (arg ):
8477 new_args .append (arg ())
78+ elif type (arg ) is type :
79+ new_args .append (arg )
80+ elif is_valid_type (arg ):
81+ new_args .append (arg )
8582 else :
8683 raise TypeError (f'unrecognized type for parameter: { arg } ' )
8784 return new_args
8885
8986 # Callable that returns a SQL string
90- elif is_valid_callable (args ):
91- out = args ()
92- if not isinstance (out , str ):
93- raise TypeError (f'unrecognized type for parameter: { args } ' )
94- return [out ]
87+ elif is_sqlstr_callable (args ):
88+ return [args ()]
89+
90+ # General way of accepting pydantic.BaseModel, NamedTuple, TypedDict
91+ elif is_valid_type (args ):
92+ return [args ]
93+
94+ elif type (args ) is type :
95+ return [args ]
9596
9697 raise TypeError (f'unrecognized type for parameter: { args } ' )
9798
0 commit comments