diff --git a/injector/__init__.py b/injector/__init__.py index 9b52729..7d372dd 100644 --- a/injector/__init__.py +++ b/injector/__init__.py @@ -1367,6 +1367,23 @@ def provider(function: CallableT) -> CallableT: return function +def named_provider(name: str) -> CallableT: + """Decorator for :class:`Module` methods, creating annoted type a of a type. + + >>> class MyModule(Module): + ... @named_provider("first") + ... def provide_name(self) -> str: + ... return 'Bob' + + """ + + def decorator(function: CallableT): + _mark_named_provider_function(function, name, allow_multi=False) + return function + + return decorator + + def multiprovider(function: CallableT) -> CallableT: """Like :func:`provider`, but for multibindings. Example usage:: @@ -1390,7 +1407,7 @@ def provide_strs_also(self) -> List[str]: def _mark_provider_function(function: Callable, *, allow_multi: bool) -> None: scope_ = getattr(function, '__scope__', None) try: - annotations = get_type_hints(function) + annotations = get_type_hints(function, include_extras=True) except NameError: return_type = '__deferred__' else: @@ -1399,6 +1416,20 @@ def _mark_provider_function(function: Callable, *, allow_multi: bool) -> None: function.__binding__ = Binding(return_type, inject(function), scope_) # type: ignore +def _mark_named_provider_function(function: Callable, name: str, *, allow_multi: bool) -> None: + scope_ = getattr(function, '__scope__', None) + try: + annotations = get_type_hints(function, include_extras=True) + except NameError: + return_type = '__deferred__' + else: + raw_return_type = annotations['return'] + return_type = Annotated[raw_return_type, name] + + _validate_provider_return_type(function, cast(type, return_type), allow_multi) + function.__binding__ = Binding(return_type, inject(function), scope_) + + def _validate_provider_return_type(function: Callable, return_type: type, allow_multi: bool) -> None: origin = _get_origin(_punch_through_alias(return_type)) if origin in {dict, list} and not allow_multi: diff --git a/injector_test.py b/injector_test.py index 6260033..ac64f5c 100644 --- a/injector_test.py +++ b/injector_test.py @@ -55,6 +55,7 @@ Error, UnknownArgument, InvalidInterface, + named_provider, ) @@ -2032,3 +2033,35 @@ class MyClass: injector = Injector([configure]) instance = injector.get(MyClass) assert instance.foo == 123 + + +def test_module_provider_with_annotated(): + class MyModule(Module): + @provider + def provide_first(self) -> Annotated[str, 'first']: + return 'Bob' + + @provider + def provide_second(self) -> Annotated[str, 'second']: + return 'Iger' + + module = MyModule() + injector = Injector(module) + assert injector.get(Annotated[str, 'first']) == 'Bob' + assert injector.get(Annotated[str, 'second']) == 'Iger' + + +def test_module_named_provider(): + class MyModule(Module): + @named_provider('first') + def provide_first(self) -> str: + return 'Bob' + + @named_provider('second') + def provide_second(self) -> str: + return 'Iger' + + module = MyModule() + injector = Injector(module) + assert injector.get(Annotated[str, 'first']) == 'Bob' + assert injector.get(Annotated[str, 'second']) == 'Iger'