Support HiDream-O1-Image-Dev-2604#1129
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces the FlashFlowMatchEulerDiscreteScheduler and updates the runner and configuration files to support the new "flash" and "flow_match" schedulers for the HiDream-O1-Image model. Key feedback on these changes includes addressing a potential bypass of the shift parameter when timesteps_list is provided, replacing exact float equality with a distance-based search to avoid precision mismatches when finding timestep indices, and removing a CPU-GPU synchronization barrier (.item()) inside the step loop to improve performance. Additionally, it is recommended to refactor duplicated code in step_post using a dynamic dictionary and to avoid hardcoding absolute user-specific paths in the shell scripts.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def __init__(self, num_train_timesteps=1000, shift=1.0, use_dynamic_shifting=False): | ||
| if use_dynamic_shifting: | ||
| raise ValueError("HiDream flash scheduler does not use dynamic shifting in LightX2V.") | ||
| self.num_train_timesteps = num_train_timesteps | ||
| self.shift = shift | ||
| self.timesteps = None | ||
| self.sigmas = None | ||
| self.num_inference_steps = None | ||
| self._step_index = None | ||
|
|
||
| def set_timesteps(self, num_inference_steps, device=None): | ||
| timesteps = np.linspace(self.num_train_timesteps, 1, num_inference_steps, dtype=np.float32) | ||
| sigmas = timesteps / self.num_train_timesteps | ||
| sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) | ||
| sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) | ||
| self.timesteps = sigmas * self.num_train_timesteps | ||
| self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) | ||
| self.num_inference_steps = num_inference_steps | ||
| self._step_index = None |
There was a problem hiding this comment.
When timesteps_list is provided in build_scheduler, the scheduler's sigmas are overridden directly with unshifted values (t / 1000.0). This completely bypasses the shift parameter configured by the user, leading to incorrect generation behavior when shift != 1.0.
To fix this elegantly without breaking external overrides, we can implement sigmas as a property with a setter in FlashFlowMatchEulerDiscreteScheduler. The setter will automatically apply the shift to any externally assigned sigmas, while using an internal flag to avoid double-shifting during set_timesteps.
def __init__(self, num_train_timesteps=1000, shift=1.0, use_dynamic_shifting=False):
if use_dynamic_shifting:
raise ValueError("HiDream flash scheduler does not use dynamic shifting in LightX2V.")
self.num_train_timesteps = num_train_timesteps
self.shift = shift
self.timesteps = None
self._sigmas = None
self.num_inference_steps = None
self._step_index = None
self._is_setting_internally = False
@property
def sigmas(self):
return self._sigmas
@sigmas.setter
def sigmas(self, value):
if value is not None and self.shift != 1.0 and not getattr(self, "_is_setting_internally", False):
value = self.shift * value / (1 + (self.shift - 1) * value)
self._sigmas = value
def set_timesteps(self, num_inference_steps, device=None):
self._is_setting_internally = True
timesteps = np.linspace(self.num_train_timesteps, 1, num_inference_steps, dtype=np.float32)
sigmas = timesteps / self.num_train_timesteps
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
self.timesteps = sigmas * self.num_train_timesteps
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self.num_inference_steps = num_inference_steps
self._step_index = None
self._is_setting_internally = False| def index_for_timestep(self, timestep): | ||
| indices = (self.timesteps == timestep.to(self.timesteps.device)).nonzero() | ||
| pos = 1 if len(indices) > 1 else 0 | ||
| return indices[pos].item() |
There was a problem hiding this comment.
Using exact float equality (==) to find the index of a timestep is highly prone to precision mismatches and device-specific rounding differences, which can cause runtime crashes if no exact match is found.
Using torch.argmin(torch.abs(...)) is a much more robust approach that always finds the closest matching timestep.
| def index_for_timestep(self, timestep): | |
| indices = (self.timesteps == timestep.to(self.timesteps.device)).nonzero() | |
| pos = 1 if len(indices) > 1 else 0 | |
| return indices[pos].item() | |
| def index_for_timestep(self, timestep): | |
| return torch.argmin(torch.abs(self.timesteps - timestep.to(self.timesteps.device))).item() |
| if noise_clip_std > 0: | ||
| clip_val = noise_clip_std * noise.std().item() | ||
| noise = noise.clamp(min=-clip_val, max=clip_val) |
There was a problem hiding this comment.
Calling .item() on a CUDA tensor inside the main loop forces a CPU-GPU synchronization barrier, which can significantly degrade inference performance.
Since torch.clamp supports tensor bounds, we can keep clip_val as a tensor and avoid the .item() call entirely.
| if noise_clip_std > 0: | |
| clip_val = noise_clip_std * noise.std().item() | |
| noise = noise.clamp(min=-clip_val, max=clip_val) | |
| if noise_clip_std > 0: | |
| clip_val = noise_clip_std * noise.std() | |
| noise = torch.clamp(noise, min=-clip_val, max=clip_val) |
| def step_post(self): | ||
| self.latents = self.sched.step( | ||
| self.noise_pred.float(), | ||
| self.current_timestep.to(dtype=torch.float32), | ||
| self.latents.float(), | ||
| return_dict=False, | ||
| )[0].to(self.dtype) | ||
| if self.generation_config["scheduler_name"] == "flash": | ||
| self.latents = self.sched.step( | ||
| self.noise_pred.float(), | ||
| self.current_timestep.to(dtype=torch.float32), | ||
| self.latents.float(), | ||
| s_noise=self.noise_scale_schedule[self.step_index], | ||
| noise_clip_std=self.generation_config["noise_clip_std"], | ||
| return_dict=False, | ||
| )[0].to(self.dtype) | ||
| else: | ||
| self.latents = self.sched.step( | ||
| self.noise_pred.float(), | ||
| self.current_timestep.to(dtype=torch.float32), | ||
| self.latents.float(), | ||
| return_dict=False, | ||
| )[0].to(self.dtype) |
There was a problem hiding this comment.
The step_post method contains duplicated calls to self.sched.step with almost identical arguments. We can simplify this by dynamically building a kwargs dictionary, which improves readability and maintainability.
def step_post(self):
kwargs = {"return_dict": False}
if self.generation_config["scheduler_name"] == "flash":
kwargs["s_noise"] = self.noise_scale_schedule[self.step_index]
kwargs["noise_clip_std"] = self.generation_config["noise_clip_std"]
self.latents = self.sched.step(
self.noise_pred.float(),
self.current_timestep.to(dtype=torch.float32),
self.latents.float(),
**kwargs,
)[0].to(self.dtype)| lightx2v_path=/data/nvme1/yongyang/nb/LightX2V | ||
| model_path=/data/nvme1/yongyang/nb/models/HiDream-ai/HiDream-O1-Image-Dev-2604 |
There was a problem hiding this comment.
Hardcoding absolute user-specific paths (e.g., /data/nvme1/yongyang/nb/...) makes the script non-portable and broken for other environments. Consider using relative paths or environment variables with sensible defaults.
| lightx2v_path=/data/nvme1/yongyang/nb/LightX2V | |
| model_path=/data/nvme1/yongyang/nb/models/HiDream-ai/HiDream-O1-Image-Dev-2604 | |
| lightx2v_path=${LIGHTX2V_PATH:-$(pwd)} | |
| model_path=${MODEL_PATH:-/path/to/HiDream-O1-Image-Dev-2604} |
No description provided.