Skip to content

ENH: JIT? #359

@34j

Description

@34j

I am not sure how to define a JIT decorator that supports multiple array libraries. Maybe something like this?

from collections.abc import Mapping, Sequence
from functools import cache, wraps
from types import ModuleType
from typing import Any, Callable, ParamSpec, TypeVar

from array_api_compat import array_namespace
from frozendict import frozendict

# Parameter specification of a decorated function; used as ``Callable[P, T]``
# and ``P.args`` / ``P.kwargs`` so wrappers keep the wrapped signature.
P = ParamSpec("P")
# Return type of the decorated function.
T = TypeVar("T")


def get_jit_decorator(
    module: ModuleType,
    /,
    *,
    args: Mapping[ModuleType, Sequence[Any]] | None = None,
    kwargs: Mapping[ModuleType, Mapping[str, Any]] | None = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Return a decorator that JIT-compiles a function for ``module``'s backend.

    The compiler is chosen from the dotted components of ``module.__name__``:

    * a ``jax`` component (e.g. ``jax.numpy``) -> ``jax.jit``
    * a ``numpy`` component (e.g. ``array_api_compat.numpy``) -> ``numba.jit``
    * a ``torch`` component (e.g. ``array_api_compat.torch``) -> ``torch.jit.script``
    * anything else -> ``module.jit`` if it exists, otherwise an identity
      decorator that returns the function unchanged.

    Parameters
    ----------
    module
        Array namespace module whose backend should compile the function.
    args
        Extra positional arguments for the compiler, keyed by namespace
        module. Only the entry for ``module`` is used.
    kwargs
        Extra keyword arguments for the compiler, keyed by namespace module.
        Only the entry for ``module`` is used.

    Returns
    -------
    Callable
        A decorator mapping ``f`` to
        ``compiler(f, *args[module], **kwargs[module])``.
    """
    args = args or {}
    kwargs = kwargs or {}
    # Match whole dotted components, not substrings: ``"numpy" in "jax.numpy"``
    # is True, which previously routed JAX arrays to numba. JAX is checked
    # first for the same reason.
    parts = module.__name__.split(".")
    jit: Callable[..., Any]
    if "jax" in parts:
        import jax

        jit = jax.jit
    elif "numpy" in parts:
        import numba

        jit = numba.jit
    elif "torch" in parts:
        import torch

        jit = torch.jit.script
    else:

        def _identity(f, *_args, **_kwargs):
            # No compiler available: ignore any compiler arguments and run
            # ``f`` as-is instead of failing on unexpected parameters.
            return f

        jit = getattr(module, "jit", _identity)

    @wraps(jit)
    def inner(f: Callable[P, T]) -> Callable[P, T]:
        return jit(f, *args.get(module, []), **kwargs.get(module, {}))

    return inner  # type: ignore[return-value]


def jit(
    f: Callable[P, T],
    /,
    *,
    args: Mapping[ModuleType, Sequence[Any]] | None = None,
    kwargs: Mapping[ModuleType, Mapping[str, Any]] | None = None,
) -> Callable[P, T]:
    """Wrap ``f`` so it is JIT-compiled for the array namespace of its inputs.

    On each call the namespace is inferred from the positional arguments with
    ``array_namespace``. ``f`` is compiled once per namespace using the
    compiler selected by ``get_jit_decorator``, and that compiled version is
    reused on later calls with the same namespace. When ``array_namespace``
    raises ``TypeError("Unrecognized array input")``, ``f`` runs uncompiled.

    Parameters
    ----------
    f
        Function to compile.
    args
        Extra positional compiler arguments, keyed by namespace module.
    kwargs
        Extra keyword compiler arguments, keyed by namespace module.

    Returns
    -------
    Callable
        A wrapper with the same signature as ``f``.

    Raises
    ------
    TypeError
        Any ``TypeError`` from ``array_namespace`` other than
        "Unrecognized array input" is re-raised.
    """
    # Read-only snapshots so later mutation by the caller has no effect.
    jit_args = frozendict(args or {})
    jit_kwargs = frozendict(kwargs or {})
    # One compiled version of ``f`` per namespace module. Caching only the
    # decorator would still recompile ``f`` on every call, and hashing
    # ``args``/``kwargs`` as a cache key fails when they contain lists.
    compiled: dict[ModuleType, Callable[P, T]] = {}

    @wraps(f)
    def inner(*args_inner: P.args, **kwargs_inner: P.kwargs) -> T:
        try:
            xp = array_namespace(*args_inner)
        except TypeError as e:
            if e.args and e.args[0] == "Unrecognized array input":
                return f(*args_inner, **kwargs_inner)
            raise
        compiled_f = compiled.get(xp)
        if compiled_f is None:
            compiled_f = get_jit_decorator(xp, args=jit_args, kwargs=jit_kwargs)(f)
            compiled[xp] = compiled_f
        return compiled_f(*args_inner, **kwargs_inner)

    return inner

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions