The current code uses `v(x, 1-t)` as the reverse drift and attempts to integrate in decreasing time (`t0 > t1`):
`LightningDiT/transport/transport.py`, lines 418 to 441 in 2725fed:

```python
if reverse:
    drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
else:
    drift = self.drift

t0, t1 = self.transport.check_interval(
    self.transport.train_eps,
    self.transport.sample_eps,
    sde=False,
    eval=True,
    reverse=reverse,
    last_step_size=0.0,
)

_ode = ode(
    drift=drift,
    t0=t0,
    t1=t1,
    sampler_type=sampling_method,
    num_steps=num_steps,
    atol=atol,
    rtol=rtol,
    timestep_shift=timestep_shift,
)
```
When `reverse` is set, the integration interval is also mirrored (`LightningDiT/transport/transport.py`, lines 108 to 111 in 2725fed):

```python
if reverse:
    t0, t1 = 1 - t0, 1 - t1

return t0, t1
```
This first of all triggers an assertion error in class `ode` (`LightningDiT/transport/integrators.py`, line 91 in 2725fed):

```python
assert t0 < t1, "ODE sampler has to be in forward time"
```
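Beyond that error, the current combination of drift and time direction is incorrect. Substituting `s = 1 - t` turns `dx/dt = v(x, t)` into `dx/ds = -v(x, 1 - s)`, so a time-mirrored drift needs a minus sign and an increasing integration variable; the current code mirrors the time argument but keeps the positive sign and integrates in decreasing time. Either 1) use the drift `-v(x, 1-t)` and let `t` increase from 0 to 1, or 2) use the original drift, keep the decreasing time and delete the assertion, since `torchdiffeq.odeint` already supports decreasing time. In detail, choose one of the following: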
Type 1 fix:

- delete lines 108-109 in `transport.py`, so the interval is no longer mirrored
- negate the drift at line 419: `drift = lambda x, t, model, **kwargs: -self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)`

Type 2 fix:

- delete the if-else at lines 418-421 in `transport.py` and only use `drift = self.drift`
- delete the assertion at line 91 in `integrators.py`
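A minimal sketch (not code from the repository) that checks both fixes on a scalar toy problem. Here `v` is a hypothetical time-dependent velocity standing in for the model, and `rk4` is a hand-rolled fixed-step solver used instead of the repo's `ode` class or `torchdiffeq.odeint`. Integrating forward over [0, 1] and then back with either fix should return to the starting point, while the current combination (mirrored drift, positive sign, decreasing time) should not.

```python
import math


def rk4(f, x, t0, t1, num_steps=1000):
    """Fixed-step RK4 from t0 to t1; t1 < t0 simply gives a negative step size."""
    h = (t1 - t0) / num_steps
    for i in range(num_steps):
        t = t0 + i * h
        k1 = f(x, t)
        k2 = f(x + 0.5 * h * k1, t + 0.5 * h)
        k3 = f(x + 0.5 * h * k2, t + 0.5 * h)
        k4 = f(x + h * k3, t + h)
        x = x + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return x


def v(x, t):
    """Hypothetical velocity field standing in for self.drift / the model."""
    return math.cos(3.0 * t) - x * t


x_start = 0.5
# Forward pass: dx/dt = v(x, t) with t increasing from 0 to 1.
x_end = rk4(v, x_start, 0.0, 1.0)

# Type 1: negated, time-mirrored drift; the integration variable s increases 0 -> 1.
x_type1 = rk4(lambda x, s: -v(x, 1.0 - s), x_end, 0.0, 1.0)

# Type 2: original drift; integration time decreases 1 -> 0.
x_type2 = rk4(v, x_end, 1.0, 0.0)

# Current code: mirrored drift without the minus sign, integrated in decreasing time.
x_current = rk4(lambda x, t: v(x, 1.0 - t), x_end, 1.0, 0.0)

print(f"start={x_start:.6f} type1={x_type1:.6f} type2={x_type2:.6f} current={x_current:.6f}")
```

In this sketch the two fixes take numerically the same steps: each evaluates the velocity at the same points, with the sign moved from the step size into the drift, which is why either one is sufficient.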