Bug in reverse ODE sampling #89

@MetaZuo

The current code uses v(x, 1-t) as the reverse drift and then attempts to integrate in decreasing time (t0 > t1):

    if reverse:
        drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
    else:
        drift = self.drift
    t0, t1 = self.transport.check_interval(
        self.transport.train_eps,
        self.transport.sample_eps,
        sde=False,
        eval=True,
        reverse=reverse,
        last_step_size=0.0,
    )
    _ode = ode(
        drift=drift,
        t0=t0,
        t1=t1,
        sampler_type=sampling_method,
        num_steps=num_steps,
        atol=atol,
        rtol=rtol,
        timestep_shift=timestep_shift,
    )

    if reverse:
        t0, t1 = 1 - t0, 1 - t1
    return t0, t1

This first of all triggers the following assertion in class ode:

    assert t0 < t1, "ODE sampler has to be in forward time"

Beyond that error, the current combination of drift and time direction is incorrect. Either 1) use the drift -v(x, 1-t) and let t increase from 0 to 1, or 2) keep the original drift v(x, t), keep the reversed time interval, and delete the assertion, since torchdiffeq.odeint already supports decreasing time. The two fixes are listed as Type 1 and Type 2; apply only one of them.
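The reasoning behind both options can be written out with a change of variable. This is a sketch that assumes the forward sampler integrates dx/dt = v(x, t) from t = 0 to t = 1 and ignores the small eps offsets applied by check_interval; the symbols s and y are introduced here only for illustration.

```latex
\begin{aligned}
\text{forward:}\quad
  & \frac{dx}{dt} = v(x, t), && t: 0 \to 1 \\
\text{option 2 (same drift, reversed time):}\quad
  & \frac{dx}{dt} = v(x, t), && t: 1 \to 0 \\
\text{option 1 } \bigl(s = 1 - t,\; y(s) = x(1 - s)\bigr):\quad
  & \frac{dy}{ds} = -v(y, 1 - s), && s: 0 \to 1 \\
\text{current code:}\quad
  & \frac{dx}{dt} = v(x, 1 - t), && t: 1 \to 0 \\
\text{current code, rewritten with } s = 1 - t:\quad
  & \frac{dx}{ds} = -v(x, s), && s: 0 \to 1
\end{aligned}
```

Options 1 and 2 describe the same trajectory traversed backwards. The last line shows that the current code evaluates the velocity at time s where option 1 evaluates it at time 1 - s, so for a time-dependent v it follows a different trajectory.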

Type 1 fix:

  • delete lines 108-109 in transport.py
  • negate the drift at line 419: drift = lambda x, t, model, **kwargs: -self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)

Type 2 fix:

  • delete the if-else at lines 418-421 in transport.py and use only drift = self.drift
  • delete the assertion at line 91 in integrators.py
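Both fixes can be sanity-checked numerically without the model. The sketch below is not the repository's code: it assumes a toy scalar velocity field v (a hypothetical stand-in for self.drift) and a hand-written fixed-step RK4 integrator rk4 in place of torchdiffeq.odeint. It integrates forward from x0, then tries to recover x0 with the Type 1 drift, the Type 2 drift, and the current drift/time combination.

```python
import math


def v(x, t):
    """Toy time-dependent velocity field (hypothetical stand-in for self.drift)."""
    return x * math.cos(t) + t


def rk4(drift, x, t0, t1, num_steps=200):
    """Fixed-step RK4 from t0 to t1; the step is negative when t0 > t1."""
    h = (t1 - t0) / num_steps
    t = t0
    for _ in range(num_steps):
        k1 = drift(x, t)
        k2 = drift(x + 0.5 * h * k1, t + 0.5 * h)
        k3 = drift(x + 0.5 * h * k2, t + 0.5 * h)
        k4 = drift(x + h * k3, t + h)
        x = x + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        t = t + h
    return x


x0 = 0.7
x1 = rk4(v, x0, 0.0, 1.0)  # forward pass: t goes 0 -> 1

# Type 1: negated, time-flipped drift with increasing time.
type1 = rk4(lambda x, t: -v(x, 1.0 - t), x1, 0.0, 1.0)

# Type 2: original drift with decreasing time.
type2 = rk4(v, x1, 1.0, 0.0)

# Current combination: time-flipped drift AND decreasing time.
current = rk4(lambda x, t: v(x, 1.0 - t), x1, 1.0, 0.0)

print("type 1 error: ", abs(type1 - x0))
print("type 2 error: ", abs(type2 - x0))
print("current error:", abs(current - x0))
```

With this toy field, Type 1 and Type 2 both return to x0 up to integration error, while the current combination ends at a visibly different value. The check only works because v depends on t; with a time-independent field all three variants coincide.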
