Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 74 additions & 38 deletions giveme/injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@
import warnings
from functools import partial, wraps
from inspect import iscoroutinefunction, signature
from typing import Callable, TypeVar, Union, Dict, Any, Tuple, Coroutine

from .deferredproperty import DeferredProperty

RetType = TypeVar("RetType")
OriginalFunc = Callable[..., RetType]
InjectDecorator = Callable[[OriginalFunc], OriginalFunc]
RegisterDecorator = Callable[[OriginalFunc], Callable[..., RetType]]


class DependencyNotFoundError(Exception):
pass
Expand All @@ -19,33 +25,37 @@ class DependencyNotFoundWarning(RuntimeWarning):


ambigious_not_found_msg = (
'An ambigious DependencyNotFound error occured. '
'Giveme could not find a dependency '
"An ambigious DependencyNotFound error occured. "
"Giveme could not find a dependency "
'named "{}" but a matching argument'
'was not passed. '
'Unable to tell whether you meant to inject '
'a dependency or simply forgot to pass '
'the correct arguments.'
"was not passed. "
"Unable to tell whether you meant to inject "
"a dependency or simply forgot to pass "
"the correct arguments."
)


class Dependency:

__slots__ = ('name', 'factory', 'singleton', 'threadlocal')

def __init__(self, name, factory, singleton=False, threadlocal=False):
__slots__ = ("name", "factory", "singleton", "threadlocal")

def __init__(
self,
name: str,
factory: OriginalFunc,
singleton: bool = False,
threadlocal: bool = False,
):
self.name = name
self.factory = factory
self.factory: OriginalFunc = factory
self.singleton = singleton
self.threadlocal = threadlocal


class Injector:

def __init__(self):
def __init__(self) -> None:
self._reset()

def cache(self, dependency: Dependency, value):
def cache(self, dependency: Dependency, value: RetType) -> None:
"""
Store an instance of dependency in the cache.
Does nothing if dependency is NOT a threadlocal
Expand All @@ -61,7 +71,7 @@ def cache(self, dependency: Dependency, value):
elif dependency.singleton:
self._singleton[dependency.name] = value

def cached(self, dependency):
def cached(self, dependency: Dependency) -> Union[RetType, None]:
"""
Get a cached instance of dependency.

Expand All @@ -74,7 +84,15 @@ def cached(self, dependency):
elif dependency.singleton:
return self._singleton.get(dependency.name)

def _set(self, name, factory, singleton=False, threadlocal=False):
return None

def _set(
self,
name: Union[str, None],
factory: OriginalFunc,
singleton: bool = False,
threadlocal: bool = False,
) -> None:
"""
Add a dependency factory to the registry

Expand All @@ -90,40 +108,40 @@ def _set(self, name, factory, singleton=False, threadlocal=False):
if iscoroutinefunction(factory):
raise AsyncDependencyForbiddenError(name)
name = name or factory.__name__
factory._giveme_registered_name = name
factory._giveme_registered_name = name # type: ignore
dep = Dependency(name, factory, singleton, threadlocal)
self._registry[name] = dep

def get(self, name: str):
def get(self, name: str) -> RetType:
"""
Get an instance of dependency,
this can be either a cached instance
or a new one (in which case the factory is called)
"""
dep = None
try:
dep = self._registry[name]
except KeyError:
raise DependencyNotFoundError(name) from None
value = self.cached(dep)

value: Union[RetType, None] = self.cached(dep)
if value is None:
value = dep.factory()
self.cache(dep, value)
return value

def _reset(self):
self._local = threading.local()
self._singleton = {}
self._registry = {}
def _reset(self) -> None:
self._local: threading.local = threading.local()
self._singleton: Dict[str, RetType] = {}
self._registry: Dict[str, Dependency] = {}

def clear(self):
def clear(self) -> None:
"""
Clear (unregister) all dependencies. Useful in tests, where you need
clean setup on every test.
"""
self._reset()

def delete(self, name):
def delete(self, name: str) -> None:
"""
Delete (unregister) a dependency by name.
"""
Expand All @@ -133,7 +151,14 @@ def delete(self, name):
delattr(self._local, name)
del self._registry[name]

def register(self, function=None, *, singleton=False, threadlocal=False, name=None):
def register(
self,
function: Union[OriginalFunc, None] = None,
*,
singleton: bool = False,
threadlocal: bool = False,
name: Union[str, None] = None
) -> Union[OriginalFunc, InjectDecorator]:
"""
Add an object to the injector's registry.

Expand All @@ -158,21 +183,25 @@ def register(self, function=None, *, singleton=False, threadlocal=False, name=No
:type threadlocal: bool
:type name: string
"""
def decorator(function=None):

def decorator(function: OriginalFunc) -> OriginalFunc:
self._set(name, function, singleton, threadlocal)
return function

if function:
return decorator(function)
return decorator

def _resolve_arguments(self, function, names, args, kwargs):
def _resolve_arguments(
self, function: OriginalFunc, names: Dict[str, str], args: Any, kwargs: Any
) -> Tuple[Tuple, Dict[str, RetType]]:
sig = signature(function)
params = sig.parameters

bound = sig.bind_partial(*args, **kwargs)
bound.apply_defaults()

injected_kwargs = {}
injected_kwargs: Dict[str, RetType] = {}
for key, value in params.items():
if key not in bound.arguments:
name = names.get(key)
Expand All @@ -186,13 +215,19 @@ def _resolve_arguments(self, function, names, args, kwargs):
except DependencyNotFoundError:
warnings.warn(
ambigious_not_found_msg.format(key),
DependencyNotFoundWarning
DependencyNotFoundWarning,
)

injected_kwargs.update(bound.kwargs)
return bound.args, injected_kwargs

def inject(self, function=None, **names):
def inject(
self, function: Union[OriginalFunc, None] = None, **names: str
) -> Union[
Callable[..., RetType],
Callable[..., Coroutine[Any, Any, RetType]],
RegisterDecorator,
]:
"""
Inject dependencies into `funtion`'s arguments when called.

Expand All @@ -211,14 +246,17 @@ def inject(self, function=None, **names):
the default behavior which matches dependency names with argument
names.
"""
def decorator(function):

def decorator(
function: OriginalFunc,
) -> Union[Callable[..., RetType], Callable[..., Coroutine[Any, Any, RetType]]]:
@wraps(function)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> RetType:
args, kwargs = self._resolve_arguments(function, names, args, kwargs)
return function(*args, **kwargs)

@wraps(function)
async def awrapper(*args, **kwargs):
async def awrapper(*args: Any, **kwargs: Any) -> RetType:
args, kwargs = self._resolve_arguments(function, names, args, kwargs)
return await function(*args, **kwargs)

Expand Down Expand Up @@ -252,6 +290,4 @@ def resolve(self, dependency):
else:
name = dependency._giveme_registered_name

return DeferredProperty(
partial(self.get, name)
)
return DeferredProperty(partial(self.get, name))