Description
State of the art
Dask struggles with masked updates because `x[mask]` has unknown shape, and Dask today is not smart enough to track that in `x[mask] = y[mask]` the lhs and rhs have the same shape (dask/dask#11831).
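For reference, this is the eager NumPy version of the pattern in question (an illustration, not code from this issue). Both sides of the assignment share a data-dependent shape, which NumPy resolves at runtime but Dask cannot prove at graph-construction time:

```python
import numpy as np

x = np.array([1.0, -2.0, 3.0, -4.0])
y = np.zeros_like(x)
mask = x < 0

# lhs and rhs both have shape (mask.sum(),), which is only known at
# runtime; NumPy evaluates eagerly so this is fine, but Dask cannot
# verify that the two shapes match when building the graph.
x[mask] = y[mask]
```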
To cope with this, `xpx.apply_where` calls `da.map_blocks` and applies `f1` and `f2` to the individual chunks. While this works, it means the end user needs to be aware of the meta namespace, i.e. the namespace of the Dask chunks.
This is currently solved internally with a private function, `meta_namespace`:
```python
from array_api_compat import array_namespace
import array_api_extra as xpx
from array_api_extra._lib._utils._helpers import meta_namespace

xp = array_namespace(x)
mxp = meta_namespace(x, xp=xp)  # Same as xp unless xp is Dask

y = xpx.apply_where(
    x > 0,
    x,
    lambda x: mxp.sin(x),
    lambda x: mxp.cos(x),
    xp=xp,
)
```

If you forget about the meta-namespace and just use `xp` in the lambdas, at the moment most things will keep working.
This is because, by accident, several functions in the `dask.array`, `numpy`, and `cupy` namespaces are interoperable or even the same function. However, you will find cases where this doesn't hold true and you need the correct namespace.
This will become a much bigger source of headaches in the future, when Dask wrapping around generic Array API-compatible namespaces becomes commonplace (note: Dask does NOT support them today).
This pattern repeats itself many, many times in scipy. At the moment there are only a handful of cases that are array API-aware, and they all use `xp.divide`, so the problem can be worked around by replacing it with `operator.truediv`. But if you look at `scipy.stats` in scipy/scipy#22557 you'll find a myriad of calls to `np.` functions inside the lambdas.
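The `operator.truediv` workaround can be illustrated as follows (a sketch, with NumPy standing in for the meta-namespace). The operator defers to `__truediv__`, so the same line works for numpy, cupy, or Dask chunks alike without ever naming `mxp`:

```python
import operator

import numpy as np

a = np.array([1.0, 2.0])
b = np.array([4.0, 8.0])

# operator.truediv dispatches to a.__truediv__(b), so no namespace
# lookup is needed: it is interoperable across array libraries.
out = operator.truediv(a, b)
```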
Proposed solutions
In the long term, I see several possible ways forward:
- make `meta_namespace` public API.
  - like: explicit is better than implicit
  - dislike: very verbose
- add signature magic to `apply_where`: if `f1` or `f2` accepts a keyword argument called `xp`, pass the meta-namespace to it:

  ```python
  out = apply_where(
      ...,  # cond
      lambda a, b, xp: xp.isinf(a) & xp.isinf(b) & (xp.sign(a) == xp.sign(b)),
      lambda a, b, xp: xp.abs(a - b) <= (atol + rtol * xp.abs(b)),
      (a, b),
  )
  ```

  - like: succinct; no need for helper functions
  - dislike: obscure functionality which needs to be commented every time; otherwise unwary maintainers will break it by trying to simplify it (this negates its compactness benefit)
- as above, but call the special parameter `mxp`.
  - like: unlikely to shadow another local variable, so new readers are forced to stop and think about how it's populated
  - dislike: the pattern is not used anywhere else. It's immediately clear to all what `xp` means; not so much with `mxp`.
- just use `xp` from the outer context in the lambdas. Expect that, by the time Dask starts supporting arbitrary Array API-compliant meta-namespaces, it will also have fixed its issues with NaN shapes. When that happens, we'll remove all special-case handling for Dask in array-api-extra and the lambdas will just run on filtered Dask arrays.
  - like: cleanest; no need to explicitly test alternative backends
  - dislike: not going to happen without substantial effort.
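To make the signature-magic option concrete, here is a minimal, hypothetical sketch of the inspection it would require. `apply_where_sketch` and its eager `xp.where` evaluation are illustrative stand-ins, not the real `apply_where` (which only evaluates each function on the masked subset and resolves the actual meta-namespace):

```python
import inspect

import numpy as np


def apply_where_sketch(cond, f1, f2, args, xp):
    """Pass the (meta-)namespace to f1/f2 iff they accept an 'xp' kwarg."""
    mxp = xp  # the real implementation would call meta_namespace(...) here

    def call(f, *a):
        if "xp" in inspect.signature(f).parameters:
            return f(*a, xp=mxp)
        return f(*a)

    # eager stand-in: evaluates both branches everywhere
    return xp.where(cond, call(f1, *args), call(f2, *args))


a = np.array([1.0, -2.0, 3.0])
# f1 opts in to receiving the namespace; f2 does not
out = apply_where_sketch(a > 0, lambda v, xp: xp.abs(v), lambda v: v * 0, (a,), np)
```

The opt-in is invisible at the call site, which is exactly the "obscure functionality" drawback noted above.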
What about lazy_apply?
`lazy_apply(as_numpy=False)` has the same issue. However, one would expect most functions applied there not to be lambdas, so one can expect them all to start with the pattern `xp = array_namespace(x, ...)` on their first line.
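That expected pattern can be sketched like this. Note that `array_namespace` below is a minimal NumPy-only stand-in for `array_api_compat.array_namespace`, so the snippet stays self-contained:

```python
import numpy as np


def array_namespace(*arrays):
    # stand-in for array_api_compat.array_namespace: assumes numpy inputs
    return np


def f(x):
    # a named function passed to lazy_apply resolves its own namespace
    # from the arrays it receives, so no meta-namespace leaks into its
    # signature or the caller's code
    xp = array_namespace(x)
    return xp.sin(x) ** 2 + xp.cos(x) ** 2


out = f(np.array([0.0, 1.0, 2.0]))
```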