⚠️ The implementation of this issue is reserved for core developers; feedback is welcome from anyone ⚠️
This was lightly discussed internally. We suggest adding a new user-facing Scan API similar to JAX's, with signature:
```python
def step_fn(carry, x):
    return new_carry, None
    # with traced outputs (old outputs_info=[None])
    return new_carry, y
    # optional break_cond, that turns it into a while scan
    return new_carry, y, break_cond

final, ys, done = scan(step_fn, init, xs, n_steps)
# done is only returned if there was a break_cond; it denotes whether break_cond ever became true
# (as we still require a maximum number of n_steps in the inner scan, or n_steps could be zero)
```
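To make the proposed semantics concrete, here is a pure-Python reference sketch. This is illustrative only, not the real graph-building implementation: the list-based loop, the early-exit handling, and the example step function are all assumptions about how the proposal would behave.

```python
def scan(step_fn, init, xs, n_steps=None):
    """Pure-Python sketch of the proposed scan semantics."""
    if n_steps is None:
        n_steps = len(xs)
    carry, ys, done = init, [], False
    for i in range(n_steps):
        out = step_fn(carry, xs[i])
        carry, y = out[0], out[1]
        if y is not None:
            ys.append(y)  # traced output
        if len(out) == 3 and out[2]:  # optional break_cond -> while scan
            done = True
            break
    return carry, ys, done

# Example: a cumulative sum that stops early once the total exceeds 10
def step_fn(carry, x):
    new_carry = carry + x
    return new_carry, new_carry, new_carry > 10

final, ys, done = scan(step_fn, 0, [1, 2, 3, 4, 5])
# final == 15, ys == [1, 3, 6, 10, 15], done is True
```

Note that whether the step at which `break_cond` fires is still recorded in `ys` is a design choice this sketch makes one way; the proposal leaves it open.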
init/carry, sequences, and y can be simple pytrees (nested tuples/dictionaries), to avoid the flat signature mess users need to navigate these days.
carry/new_carry pytree must have the same structure and root types
Internally we keep using the existing ScanOp, which requires flattening everything and ordering things in a specific way, but the user only interacts directly with the structured inputs and outputs.
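The flatten/unflatten bridge between structured user inputs and the flat ScanOp signature could look something like the sketch below. The helper names `tree_flatten`/`tree_unflatten` are hypothetical (borrowed from JAX's terminology); PyTensor would need an equivalent.

```python
def tree_flatten(tree):
    """Return (leaves, structure) for a pytree of nested tuples/dicts."""
    if isinstance(tree, dict):
        keys = sorted(tree)
        leaves, structs = [], []
        for k in keys:
            sub_leaves, sub_struct = tree_flatten(tree[k])
            leaves += sub_leaves
            structs.append(sub_struct)
        return leaves, ("dict", keys, structs)
    if isinstance(tree, tuple):
        leaves, structs = [], []
        for item in tree:
            sub_leaves, sub_struct = tree_flatten(item)
            leaves += sub_leaves
            structs.append(sub_struct)
        return leaves, ("tuple", structs)
    return [tree], "leaf"  # anything else is a leaf (e.g. a variable)

def tree_unflatten(structure, leaves):
    """Rebuild the pytree from a flat list of leaves."""
    leaves = iter(leaves)
    def build(struct):
        if struct == "leaf":
            return next(leaves)
        if struct[0] == "dict":
            _, keys, subs = struct
            return {k: build(s) for k, s in zip(keys, subs)}
        _, subs = struct
        return tuple(build(s) for s in subs)
    return build(structure)
```

The wrapper would flatten the user's `init`/`xs`, call the existing ScanOp on the flat lists, and unflatten the outputs before returning them.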
The existing user-facing scan function (that helps build a ScanOp) would still exist with a deprecation warning. The new one will need a new name or a new module location; maybe call it pytensor.loop?
Taps
Taps are one of the largest sources of complexity in Scan, but arguably quite aligned with the memory layout we want in the end (i.e., a single outer buffer whose last n entries can be accessed in the inner function).
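The user could naively implement taps like the sketch below, by threading the last values through the carry by hand (the recurrence `f` here is a Fibonacci-like step chosen for concreteness; it is not from the original proposal):

```python
def f(xtm2, xtm1, x):
    # stand-in for the user's recurrence over the two previous values
    return xtm2 + xtm1 + x

def step_fn(carry, x):
    xtm2, xtm1 = carry          # unpack the two previous states
    xt = f(xtm2, xtm1, x)       # compute the new state
    return (xtm1, xt), xt       # shift the window manually, trace xt
```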
But internally we want to convert xtm2, xtm1 to a tap outputs_info in the regular scan. Graph analysis is perhaps too tricky for this, so I suggest implementing a helper.
```python
init = scan.taps(a, taps=[-2, -1])
# if taps=None, default to all the entries of the leading dimension, if length is statically known

def step_fn(carry, x):
    atm2, atm1 = carry
    at = f(atm2, atm1, x)
    next_carry = carry.push_tap(at)
```
scan.taps creates a special class that the new pytensor scan recognizes. They must be at the bottom of the pytree, and can be composed with the rest of it (that is, stuff like carry = (taps(a), b) is also fine).
We can raise informative errors if carry.push_tap wasn't called or was called multiple times.
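A minimal pure-Python sketch of what such a helper class might look like (the `Taps` name, constructor, and error behavior are all assumptions here; the real class would wrap symbolic variables rather than plain values):

```python
class Taps:
    def __init__(self, values, taps=None):
        # taps=None would mean "all entries of the leading dimension"
        self.taps = taps if taps is not None else list(range(-len(values), 0))
        if len(values) != len(self.taps):
            raise ValueError("need one initial value per tap")
        self.values = tuple(values)
        self._pushed = False

    def __iter__(self):
        # lets the user unpack inside step_fn: atm2, atm1 = carry
        return iter(self.values)

    def push_tap(self, new_value):
        # informative error if called more than once per step
        if self._pushed:
            raise RuntimeError("push_tap called more than once in this step")
        self._pushed = True
        # drop the oldest tap, append the new value
        return Taps(self.values[1:] + (new_value,), taps=self.taps)
```

The "push_tap wasn't called" check would live in the scan wrapper, which can inspect the returned carry after tracing the step function once.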
Scan sequences also allow taps (with positive integers). They can also be composed at the bottom of the xs pytree.
We forbid carry.push_tap in these
If we find graph analysis is sufficiently robust, we can always deprecate later down the road
User benefits
Simpler syntax
It's currently very hard to get the internal function signature right when you have many sequences and outputs with taps. Codegen scans with arbitrary functions are particularly hard, and already require some sort of ad-hoc flattening/unflattening logic to keep things manageable.
Easier adoption
Especially if you already know JAX. I feel devs are much more at ease writing JAX scans than they are writing pytensor scans, even though they are conceptually quite similar.
Having a python helper / class with proper docstrings / introspectability for scan.taps is a huge plus. I always have to go and check what the currently required dictionary looks like (what the keys are).
Consistency for updates of variables that can't be traced, like RNGs. The deprecated updates signature was mostly for this. The newly proposed API with return_updates=False puts these outputs together with regular recurring outputs, with the gotcha that you only get back the last state, which is inconsistent: for regular outputs you still need to do [-1] if you only want the last one.
For types that can't be traced, Scan will be happy to handle them in the init/carry/final side of the Scan, and will raise an informative error if you try to trace it as a y
Valid state for n_steps=0. Because users are only given a "trace" in the current implementation, there's no clean way to get last_state if n_steps=0. With the init/final design, this is valid: last_state = init_state. No risky trace[init:][-1] business. (There's some resolved discussion in Implement unconstraining transform for LKJCorr pymc#7380 (comment))
It's still invalid for traced outputs (aka the y in the example above), but you usually want the whole sequence for those, and an empty sequence is well defined
(except for the little gotcha of not always knowing the exact zero shape it should have, shape=(0, 5) is different than shape=(0, 2, 3))
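The n_steps=0 behavior falls out naturally from the init/final semantics, as this self-contained sketch shows (the minimal `scan` here is an illustrative stand-in, not the real implementation):

```python
def scan(step_fn, init, xs, n_steps):
    # zero iterations: the body never runs, so final == init
    carry, ys = init, []
    for i in range(n_steps):
        carry, y = step_fn(carry, xs[i])
        ys.append(y)
    return carry, ys

step = lambda carry, x: (carry + x, carry + x)
final, ys = scan(step, 7, [], n_steps=0)
# final == 7 (last_state == init_state, no trace[-1] needed), ys == []
```

The empty trace `ys` is well defined as a sequence, but as noted above its per-step shape is not recoverable from the empty list alone.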
Internal benefits
Most cases of scan_save_mem become trivial to analyze. We mostly need to distinguish between traced scan and final_state scan. The cases where a user wants the last n states, with n > 1, are exceedingly rare. The cases where the user wants the first n < N steps are also rare; sometimes in autodiff, when the gradient doesn't depend on the output quantity, the very last step can be cut. I reckon it's fine to lose this fine-grained optimization while focusing on the main one (whole scan or only last state).
scan_save_mem is one of the most complicated rewrites we have, and it's mostly due to bad initial semantics for Scan.
It should also simplify local_subtensor_merge somewhat, which is a source of graph explosion and rewrite slowdown: local_subtensor_merge can complicate graphs #112