
Support HiDream-O1-Image-Dev-2604 #1129

Merged
llmc-reviewer merged 1 commit into main from hd on Jun 8, 2026


@helloyongyang

Contributor

No description provided.

@llmc-reviewer llmc-reviewer merged commit 8bdaa69 into main Jun 8, 2026
2 checks passed
@llmc-reviewer llmc-reviewer deleted the hd branch June 8, 2026 06:10

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces the FlashFlowMatchEulerDiscreteScheduler and updates the runner and configuration files to support the new "flash" and "flow_match" schedulers for the HiDream-O1-Image model. Key feedback on these changes includes addressing a potential bypass of the shift parameter when timesteps_list is provided, replacing exact float equality with a distance-based search to avoid precision mismatches when finding timestep indices, and removing a CPU-GPU synchronization barrier (.item()) inside the step loop to improve performance. Additionally, it is recommended to refactor duplicated code in step_post using a dynamic dictionary and to avoid hardcoding absolute user-specific paths in the shell scripts.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +16 to +34
    def __init__(self, num_train_timesteps=1000, shift=1.0, use_dynamic_shifting=False):
        if use_dynamic_shifting:
            raise ValueError("HiDream flash scheduler does not use dynamic shifting in LightX2V.")
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.timesteps = None
        self.sigmas = None
        self.num_inference_steps = None
        self._step_index = None

    def set_timesteps(self, num_inference_steps, device=None):
        timesteps = np.linspace(self.num_train_timesteps, 1, num_inference_steps, dtype=np.float32)
        sigmas = timesteps / self.num_train_timesteps
        sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        self.timesteps = sigmas * self.num_train_timesteps
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self.num_inference_steps = num_inference_steps
        self._step_index = None


high

When timesteps_list is provided in build_scheduler, the scheduler's sigmas are overridden directly with unshifted values (t / 1000.0). This completely bypasses the shift parameter configured by the user, leading to incorrect generation behavior when shift != 1.0.

To fix this elegantly without breaking external overrides, we can implement sigmas as a property with a setter in FlashFlowMatchEulerDiscreteScheduler. The setter will automatically apply the shift to any externally assigned sigmas, while using an internal flag to avoid double-shifting during set_timesteps.

    def __init__(self, num_train_timesteps=1000, shift=1.0, use_dynamic_shifting=False):
        if use_dynamic_shifting:
            raise ValueError("HiDream flash scheduler does not use dynamic shifting in LightX2V.")
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.timesteps = None
        self._sigmas = None
        self.num_inference_steps = None
        self._step_index = None
        self._is_setting_internally = False

    @property
    def sigmas(self):
        return self._sigmas

    @sigmas.setter
    def sigmas(self, value):
        if value is not None and self.shift != 1.0 and not getattr(self, "_is_setting_internally", False):
            value = self.shift * value / (1 + (self.shift - 1) * value)
        self._sigmas = value

    def set_timesteps(self, num_inference_steps, device=None):
        self._is_setting_internally = True
        timesteps = np.linspace(self.num_train_timesteps, 1, num_inference_steps, dtype=np.float32)
        sigmas = timesteps / self.num_train_timesteps
        sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        self.timesteps = sigmas * self.num_train_timesteps
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self.num_inference_steps = num_inference_steps
        self._step_index = None
        self._is_setting_internally = False
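
As a rough, torch-free illustration of the setter idea above, the sketch below uses plain Python lists in place of tensors. `MiniScheduler` and `apply_shift` are hypothetical names introduced only for this example, not part of LightX2V. It also wraps the internal flag in `try`/`finally`, which is an assumption on top of the suggestion, so the flag is reset even if `set_timesteps` raises.

```python
def apply_shift(sigma, shift):
    """Time-shift a sigma in [0, 1]; shift == 1.0 leaves it unchanged."""
    return shift * sigma / (1 + (shift - 1) * sigma)


class MiniScheduler:
    """Hypothetical list-based stand-in for the scheduler, for illustration only."""

    def __init__(self, num_train_timesteps=1000, shift=1.0):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self._sigmas = None
        self._is_setting_internally = False

    @property
    def sigmas(self):
        return self._sigmas

    @sigmas.setter
    def sigmas(self, value):
        # External assignments get shifted; internal ones are already shifted.
        if value is not None and not self._is_setting_internally:
            value = [apply_shift(s, self.shift) for s in value]
        self._sigmas = value

    def set_timesteps(self, num_inference_steps):
        self._is_setting_internally = True
        try:
            n, hi = num_inference_steps, self.num_train_timesteps
            # Evenly spaced timesteps from hi down to 1, like np.linspace(hi, 1, n).
            steps = [hi - i * (hi - 1) / (n - 1) for i in range(n)] if n > 1 else [float(hi)]
            shifted = [apply_shift(t / hi, self.shift) for t in steps]
            self.sigmas = shifted + [0.0]
        finally:
            # Reset the flag even if something above raises.
            self._is_setting_internally = False


sched = MiniScheduler(shift=3.0)
sched.sigmas = [0.5, 0.0]   # external override with unshifted values
print(sched.sigmas)         # [0.75, 0.0]
```

With `shift=3.0`, an externally assigned sigma of 0.5 becomes 0.75, while values produced inside `set_timesteps` are shifted exactly once.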

Comment on lines +36 to +39
    def index_for_timestep(self, timestep):
        indices = (self.timesteps == timestep.to(self.timesteps.device)).nonzero()
        pos = 1 if len(indices) > 1 else 0
        return indices[pos].item()


high

Using exact float equality (==) to find the index of a timestep is highly prone to precision mismatches and device-specific rounding differences, which can cause runtime crashes if no exact match is found.

Using torch.argmin(torch.abs(...)) is a much more robust approach that always finds the closest matching timestep.

Suggested change

    -    def index_for_timestep(self, timestep):
    -        indices = (self.timesteps == timestep.to(self.timesteps.device)).nonzero()
    -        pos = 1 if len(indices) > 1 else 0
    -        return indices[pos].item()
    +    def index_for_timestep(self, timestep):
    +        return torch.argmin(torch.abs(self.timesteps - timestep.to(self.timesteps.device))).item()
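
The precision concern can be reproduced without torch. The sketch below is a plain-Python analogue, not the scheduler's code: `to_float32`, `exact_index`, and `nearest_index` are hypothetical helpers, and the timestep values are made up. It rounds one timestep through float32 and shows that the exact-equality lookup finds nothing while the nearest-value lookup still returns the intended index.

```python
import struct


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def exact_index(timesteps, t):
    """Exact-equality lookup; returns None when nothing matches."""
    matches = [i for i, v in enumerate(timesteps) if v == t]
    return matches[0] if matches else None


def nearest_index(timesteps, t):
    """Closest-value lookup; always returns an index for a non-empty list."""
    return min(range(len(timesteps)), key=lambda i: abs(timesteps[i] - t))


timesteps = [1000.0, 750.3, 500.1, 250.7]
query = to_float32(750.3)  # same timestep after a float32 round trip

print(exact_index(timesteps, query))    # None
print(nearest_index(timesteps, query))  # 1
```

One trade-off worth keeping in mind: a nearest-value lookup never fails, so a timestep that is genuinely not in the schedule is silently mapped to its closest neighbour. Adding a tolerance check on the distance would be one way to keep that case visible.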

Comment on lines +60 to +62
    if noise_clip_std > 0:
        clip_val = noise_clip_std * noise.std().item()
        noise = noise.clamp(min=-clip_val, max=clip_val)


medium

Calling .item() on a CUDA tensor inside the main loop forces a CPU-GPU synchronization barrier, which can significantly degrade inference performance.

Since torch.clamp supports tensor bounds, we can keep clip_val as a tensor and avoid the .item() call entirely.

Suggested change

    -    if noise_clip_std > 0:
    -        clip_val = noise_clip_std * noise.std().item()
    -        noise = noise.clamp(min=-clip_val, max=clip_val)
    +    if noise_clip_std > 0:
    +        clip_val = noise_clip_std * noise.std()
    +        noise = torch.clamp(noise, min=-clip_val, max=clip_val)

Comment on lines 157 to +173
         def step_post(self):
    -        self.latents = self.sched.step(
    -            self.noise_pred.float(),
    -            self.current_timestep.to(dtype=torch.float32),
    -            self.latents.float(),
    -            return_dict=False,
    -        )[0].to(self.dtype)
    +        if self.generation_config["scheduler_name"] == "flash":
    +            self.latents = self.sched.step(
    +                self.noise_pred.float(),
    +                self.current_timestep.to(dtype=torch.float32),
    +                self.latents.float(),
    +                s_noise=self.noise_scale_schedule[self.step_index],
    +                noise_clip_std=self.generation_config["noise_clip_std"],
    +                return_dict=False,
    +            )[0].to(self.dtype)
    +        else:
    +            self.latents = self.sched.step(
    +                self.noise_pred.float(),
    +                self.current_timestep.to(dtype=torch.float32),
    +                self.latents.float(),
    +                return_dict=False,
    +            )[0].to(self.dtype)


medium

The step_post method contains duplicated calls to self.sched.step with almost identical arguments. We can simplify this by dynamically building a kwargs dictionary, which improves readability and maintainability.

    def step_post(self):
        kwargs = {"return_dict": False}
        if self.generation_config["scheduler_name"] == "flash":
            kwargs["s_noise"] = self.noise_scale_schedule[self.step_index]
            kwargs["noise_clip_std"] = self.generation_config["noise_clip_std"]
        self.latents = self.sched.step(
            self.noise_pred.float(),
            self.current_timestep.to(dtype=torch.float32),
            self.latents.float(),
            **kwargs,
        )[0].to(self.dtype)

Comment on lines +4 to +5
    lightx2v_path=/data/nvme1/yongyang/nb/LightX2V
    model_path=/data/nvme1/yongyang/nb/models/HiDream-ai/HiDream-O1-Image-Dev-2604


medium

Hardcoding absolute user-specific paths (e.g., /data/nvme1/yongyang/nb/...) makes the script non-portable: it fails on any machine that lacks those directories. Consider using relative paths or environment variables with sensible defaults.

Suggested change

    - lightx2v_path=/data/nvme1/yongyang/nb/LightX2V
    - model_path=/data/nvme1/yongyang/nb/models/HiDream-ai/HiDream-O1-Image-Dev-2604
    + lightx2v_path=${LIGHTX2V_PATH:-$(pwd)}
    + model_path=${MODEL_PATH:-/path/to/HiDream-O1-Image-Dev-2604}


2 participants