diff --git a/giveme/injector.py b/giveme/injector.py index 618238b..3aacf6f 100644 --- a/giveme/injector.py +++ b/giveme/injector.py @@ -2,9 +2,15 @@ import warnings from functools import partial, wraps from inspect import iscoroutinefunction, signature +from typing import Callable, TypeVar, Union, Dict, Any, Tuple, Coroutine from .deferredproperty import DeferredProperty +RetType = TypeVar("RetType") +OriginalFunc = Callable[..., RetType] +InjectDecorator = Callable[[OriginalFunc], OriginalFunc] +RegisterDecorator = Callable[[OriginalFunc], Callable[..., RetType]] + class DependencyNotFoundError(Exception): pass @@ -19,33 +25,37 @@ class DependencyNotFoundWarning(RuntimeWarning): ambigious_not_found_msg = ( - 'An ambigious DependencyNotFound error occured. ' - 'Giveme could not find a dependency ' + "An ambigious DependencyNotFound error occured. " + "Giveme could not find a dependency " 'named "{}" but a matching argument' - 'was not passed. ' - 'Unable to tell whether you meant to inject ' - 'a dependency or simply forgot to pass ' - 'the correct arguments.' + "was not passed. " + "Unable to tell whether you meant to inject " + "a dependency or simply forgot to pass " + "the correct arguments." ) class Dependency: - - __slots__ = ('name', 'factory', 'singleton', 'threadlocal') - - def __init__(self, name, factory, singleton=False, threadlocal=False): + __slots__ = ("name", "factory", "singleton", "threadlocal") + + def __init__( + self, + name: str, + factory: OriginalFunc, + singleton: bool = False, + threadlocal: bool = False, + ): self.name = name - self.factory = factory + self.factory: OriginalFunc = factory self.singleton = singleton self.threadlocal = threadlocal class Injector: - - def __init__(self): + def __init__(self) -> None: self._reset() - def cache(self, dependency: Dependency, value): + def cache(self, dependency: Dependency, value: RetType) -> None: """ Store an instance of dependency in the cache. Does nothing if dependency is NOT a threadlocal @@ -61,7 +71,7 @@ def cache(self, dependency: Dependency, value): elif dependency.singleton: self._singleton[dependency.name] = value - def cached(self, dependency): + def cached(self, dependency: Dependency) -> Union[RetType, None]: """ Get a cached instance of dependency. @@ -74,7 +84,15 @@ def cached(self, dependency): elif dependency.singleton: return self._singleton.get(dependency.name) - def _set(self, name, factory, singleton=False, threadlocal=False): + return None + + def _set( + self, + name: Union[str, None], + factory: OriginalFunc, + singleton: bool = False, + threadlocal: bool = False, + ) -> None: """ Add a dependency factory to the registry @@ -90,40 +108,40 @@ def _set(self, name, factory, singleton=False, threadlocal=False): if iscoroutinefunction(factory): raise AsyncDependencyForbiddenError(name) name = name or factory.__name__ - factory._giveme_registered_name = name + factory._giveme_registered_name = name # type: ignore dep = Dependency(name, factory, singleton, threadlocal) self._registry[name] = dep - def get(self, name: str): + def get(self, name: str) -> RetType: """ Get an instance of dependency, this can be either a cached instance or a new one (in which case the factory is called) """ - dep = None try: dep = self._registry[name] except KeyError: raise DependencyNotFoundError(name) from None - value = self.cached(dep) + + value: Union[RetType, None] = self.cached(dep) if value is None: value = dep.factory() self.cache(dep, value) return value - def _reset(self): - self._local = threading.local() - self._singleton = {} - self._registry = {} + def _reset(self) -> None: + self._local: threading.local = threading.local() + self._singleton: Dict[str, RetType] = {} + self._registry: Dict[str, Dependency] = {} - def clear(self): + def clear(self) -> None: """ Clear (unregister) all dependencies. Useful in tests, where you need clean setup on every test. """ self._reset() - def delete(self, name): + def delete(self, name: str) -> None: """ Delete (unregister) a dependency by name. """ @@ -133,7 +151,14 @@ def delete(self, name): delattr(self._local, name) del self._registry[name] - def register(self, function=None, *, singleton=False, threadlocal=False, name=None): + def register( + self, + function: Union[OriginalFunc, None] = None, + *, + singleton: bool = False, + threadlocal: bool = False, + name: Union[str, None] = None + ) -> Union[OriginalFunc, InjectDecorator]: """ Add an object to the injector's registry. @@ -158,21 +183,25 @@ def register(self, function=None, *, singleton=False, threadlocal=False, name=No :type threadlocal: bool :type name: string """ - def decorator(function=None): + + def decorator(function: OriginalFunc) -> OriginalFunc: self._set(name, function, singleton, threadlocal) return function + if function: return decorator(function) return decorator - def _resolve_arguments(self, function, names, args, kwargs): + def _resolve_arguments( + self, function: OriginalFunc, names: Dict[str, str], args: Any, kwargs: Any + ) -> Tuple[Tuple, Dict[str, RetType]]: sig = signature(function) params = sig.parameters bound = sig.bind_partial(*args, **kwargs) bound.apply_defaults() - injected_kwargs = {} + injected_kwargs: Dict[str, RetType] = {} for key, value in params.items(): if key not in bound.arguments: name = names.get(key) @@ -186,13 +215,19 @@ def _resolve_arguments(self, function, names, args, kwargs): except DependencyNotFoundError: warnings.warn( ambigious_not_found_msg.format(key), - DependencyNotFoundWarning + DependencyNotFoundWarning, ) injected_kwargs.update(bound.kwargs) return bound.args, injected_kwargs - def inject(self, function=None, **names): + def inject( + self, function: Union[OriginalFunc, None] = None, **names: str + ) -> Union[ + Callable[..., RetType], + Callable[..., Coroutine[Any, Any, RetType]], + RegisterDecorator, + ]: """ Inject dependencies into `funtion`'s arguments when called. @@ -211,14 +246,17 @@ def inject(self, function=None, **names): the default behavior which matches dependency names with argument names. """ - def decorator(function): + + def decorator( + function: OriginalFunc, + ) -> Union[Callable[..., RetType], Callable[..., Coroutine[Any, Any, RetType]]]: @wraps(function) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> RetType: args, kwargs = self._resolve_arguments(function, names, args, kwargs) return function(*args, **kwargs) @wraps(function) - async def awrapper(*args, **kwargs): + async def awrapper(*args: Any, **kwargs: Any) -> RetType: args, kwargs = self._resolve_arguments(function, names, args, kwargs) return await function(*args, **kwargs) @@ -252,6 +290,4 @@ def resolve(self, dependency): else: name = dependency._giveme_registered_name - return DeferredProperty( - partial(self.get, name) - ) + return DeferredProperty(partial(self.get, name))