Skip to content

Commit 29a930a

Browse files
leisuzz and tcaimm authored
Bugfix for flux2 img2img prediction (#12855)
* Bugfix for dreambooth flux2 img2img
* Bugfix for dreambooth flux2 img2img
* Bugfix for dreambooth flux2 img2img
* Bugfix for dreambooth flux2 img2img
* Bugfix for dreambooth flux2 img2img
* Bugfix for dreambooth flux2 img2img

Co-authored-by: tcaimm <[email protected]>

---------

Co-authored-by: tcaimm <[email protected]>
1 parent dad5cb5 commit 29a930a

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

examples/dreambooth/train_dreambooth_lora_flux2_img2img.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,9 +1695,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16951695
cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std
16961696

16971697
model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)
1698-
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to(
1698+
cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])]
1699+
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to(
16991700
device=cond_model_input.device
17001701
)
1702+
cond_model_input_ids = cond_model_input_ids.view(
1703+
cond_model_input.shape[0], -1, model_input_ids.shape[-1]
1704+
)
17011705

17021706
# Sample noise that we'll add to the latents
17031707
noise = torch.randn_like(model_input)
@@ -1724,6 +1728,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17241728
packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)
17251729
packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input)
17261730

1731+
orig_input_shape = packed_noisy_model_input.shape
1732+
orig_input_ids_shape = model_input_ids.shape
1733+
17271734
# concatenate the model inputs with the cond inputs
17281735
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)
17291736
model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
@@ -1742,7 +1749,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17421749
img_ids=model_input_ids, # B, image_seq_len, 4
17431750
return_dict=False,
17441751
)[0]
1745-
model_pred = model_pred[:, : packed_noisy_model_input.size(1) :]
1752+
model_pred = model_pred[:, : orig_input_shape[1], :]
1753+
model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :]
17461754

17471755
model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)
17481756

0 commit comments

Comments
 (0)