2222DEALINGS IN THE SOFTWARE.
2323"""
2424
25+ from typing import (
26+ Any ,
27+ Dict ,
28+ ForwardRef ,
29+ Iterable ,
30+ Literal ,
31+ Tuple ,
32+ Union ,
33+ get_args as get_typing_args ,
34+ get_origin as get_typing_origin ,
35+ )
2536import asyncio
2637import functools
2738import inspect
28- import typing
2939import datetime
3040import sys
3141
6474 'bot_has_guild_permissions'
6575)
6676
77+ PY_310 = sys .version_info >= (3 , 10 )
78+
79+ def flatten_literal_params (parameters : Iterable [Any ]) -> Tuple [Any , ...]:
80+ params = []
81+ literal_cls = type (Literal [0 ])
82+ for p in parameters :
83+ if isinstance (p , literal_cls ):
84+ params .extend (p .__args__ )
85+ else :
86+ params .append (p )
87+ return tuple (params )
88+
89+ def _evaluate_annotation (tp : Any , globals : Dict [str , Any ], cache : Dict [str , Any ] = {}, * , implicit_str = True ):
90+ if isinstance (tp , ForwardRef ):
91+ tp = tp .__forward_arg__
92+ # ForwardRefs always evaluate their internals
93+ implicit_str = True
94+
95+ if implicit_str and isinstance (tp , str ):
96+ if tp in cache :
97+ return cache [tp ]
98+ evaluated = eval (tp , globals )
99+ cache [tp ] = evaluated
100+ return _evaluate_annotation (evaluated , globals , cache )
101+
102+ if hasattr (tp , '__args__' ):
103+ implicit_str = True
104+ args = tp .__args__
105+ if tp .__origin__ is Literal :
106+ if not PY_310 :
107+ args = flatten_literal_params (tp .__args__ )
108+ implicit_str = False
109+
110+ evaluated_args = tuple (
111+ _evaluate_annotation (arg , globals , cache , implicit_str = implicit_str ) for arg in args
112+ )
113+
114+ if evaluated_args == args :
115+ return tp
116+
117+ try :
118+ return tp .copy_with (evaluated_args )
119+ except AttributeError :
120+ return tp .__origin__ [evaluated_args ]
121+
122+ return tp
123+
124+ def resolve_annotation (annotation : Any , globalns : Dict [str , Any ], cache : Dict [str , Any ] = {}) -> Any :
125+ if annotation is None :
126+ return type (None )
127+ if isinstance (annotation , str ):
128+ annotation = ForwardRef (annotation )
129+ return _evaluate_annotation (annotation , globalns , cache )
130+
131+ def get_signature_parameters (function ) -> Dict [str , inspect .Parameter ]:
132+ globalns = function .__globals__
133+ signature = inspect .signature (function )
134+ params = {}
135+ cache : Dict [str , Any ] = {}
136+ for name , parameter in signature .parameters .items ():
137+ annotation = parameter .annotation
138+ if annotation is parameter .empty :
139+ params [name ] = parameter
140+ continue
141+ if annotation is None :
142+ params [name ] = parameter .replace (annotation = type (None ))
143+ continue
144+
145+ annotation = _evaluate_annotation (annotation , globalns , cache )
146+ if annotation is converters .Greedy :
147+ raise TypeError ('Unparameterized Greedy[...] is disallowed in signature.' )
148+
149+ params [name ] = parameter .replace (annotation = annotation )
150+
151+ return params
152+
153+
67154def wrap_callback (coro ):
68155 @functools .wraps (coro )
69156 async def wrapped (* args , ** kwargs ):
@@ -300,40 +387,7 @@ def callback(self):
300387 def callback (self , function ):
301388 self ._callback = function
302389 self .module = function .__module__
303-
304- signature = inspect .signature (function )
305- self .params = signature .parameters .copy ()
306-
307- # see: https://bugs.python.org/issue41341
308- resolve = self ._recursive_resolve if sys .version_info < (3 , 9 ) else self ._return_resolved
309-
310- try :
311- type_hints = {k : resolve (v ) for k , v in typing .get_type_hints (function ).items ()}
312- except NameError as e :
313- raise NameError (f'unresolved forward reference: { e .args [0 ]} ' ) from None
314-
315- for key , value in self .params .items ():
316- # coalesce the forward references
317- if key in type_hints :
318- self .params [key ] = value = value .replace (annotation = type_hints [key ])
319-
320- # fail early for when someone passes an unparameterized Greedy type
321- if value .annotation is converters .Greedy :
322- raise TypeError ('Unparameterized Greedy[...] is disallowed in signature.' )
323-
324- def _return_resolved (self , type , ** kwargs ):
325- return type
326-
327- def _recursive_resolve (self , type , * , globals = None ):
328- if not isinstance (type , typing .ForwardRef ):
329- return type
330-
331- resolved = eval (type .__forward_arg__ , globals )
332- args = typing .get_args (resolved )
333- for index , arg in enumerate (args ):
334- inner_resolve_result = self ._recursive_resolve (arg , globals = globals )
335- resolved [index ] = inner_resolve_result
336- return resolved
390+ self .params = get_signature_parameters (function )
337391
338392 def add_check (self , func ):
339393 """Adds a check to the command.
@@ -493,12 +547,12 @@ async def _actual_conversion(self, ctx, converter, argument, param):
493547 raise BadArgument (f'Converting to "{ name } " failed for parameter "{ param .name } ".' ) from exc
494548
495549 async def do_conversion (self , ctx , converter , argument , param ):
496- origin = typing . get_origin (converter )
550+ origin = get_typing_origin (converter )
497551
498- if origin is typing . Union :
552+ if origin is Union :
499553 errors = []
500554 _NoneType = type (None )
501- for conv in typing . get_args (converter ):
555+ for conv in get_typing_args (converter ):
502556 # if we got to this part in the code, then the previous conversions have failed
503557 # so we should just undo the view, return the default, and allow parsing to continue
504558 # with the other parameters
@@ -514,13 +568,12 @@ async def do_conversion(self, ctx, converter, argument, param):
514568 return value
515569
516570 # if we're here, then we failed all the converters
517- raise BadUnionArgument (param , typing . get_args (converter ), errors )
571+ raise BadUnionArgument (param , get_typing_args (converter ), errors )
518572
519- if origin is typing . Literal :
573+ if origin is Literal :
520574 errors = []
521575 conversions = {}
522- literal_args = tuple (self ._flattened_typing_literal_args (converter ))
523- for literal in literal_args :
576+ for literal in converter .__args__ :
524577 literal_type = type (literal )
525578 try :
526579 value = conversions [literal_type ]
@@ -538,7 +591,7 @@ async def do_conversion(self, ctx, converter, argument, param):
538591 return value
539592
540593 # if we're here, then we failed to match all the literals
541- raise BadLiteralArgument (param , literal_args , errors )
594+ raise BadLiteralArgument (param , converter . __args__ , errors )
542595
543596 return await self ._actual_conversion (ctx , converter , argument , param )
544597
@@ -1021,14 +1074,7 @@ def short_doc(self):
10211074 return ''
10221075
10231076 def _is_typing_optional (self , annotation ):
1024- return typing .get_origin (annotation ) is typing .Union and typing .get_args (annotation )[- 1 ] is type (None )
1025-
1026- def _flattened_typing_literal_args (self , annotation ):
1027- for literal in typing .get_args (annotation ):
1028- if typing .get_origin (literal ) is typing .Literal :
1029- yield from self ._flattened_typing_literal_args (literal )
1030- else :
1031- yield literal
1077+ return get_typing_origin (annotation ) is Union and get_typing_args (annotation )[- 1 ] is type (None )
10321078
10331079 @property
10341080 def signature (self ):
@@ -1048,17 +1094,16 @@ def signature(self):
10481094 # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
10491095 # parameter signature is a literal list of it's values
10501096 annotation = param .annotation .converter if greedy else param .annotation
1051- origin = typing . get_origin (annotation )
1052- if not greedy and origin is typing . Union :
1053- union_args = typing . get_args (annotation )
1097+ origin = get_typing_origin (annotation )
1098+ if not greedy and origin is Union :
1099+ union_args = get_typing_args (annotation )
10541100 optional = union_args [- 1 ] is type (None )
10551101 if optional :
10561102 annotation = union_args [0 ]
1057- origin = typing . get_origin (annotation )
1103+ origin = get_typing_origin (annotation )
10581104
1059- if origin is typing .Literal :
1060- name = '|' .join (f'"{ v } "' if isinstance (v , str ) else str (v )
1061- for v in self ._flattened_typing_literal_args (annotation ))
1105+ if origin is Literal :
1106+ name = '|' .join (f'"{ v } "' if isinstance (v , str ) else str (v ) for v in annotation .__args__ )
10621107 if param .default is not param .empty :
10631108 # We don't want None or '' to trigger the [name=value] case and instead it should
10641109 # do [name] since [name=None] or [name=] are not exactly useful for the user.
0 commit comments