@@ -1695,9 +1695,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16951695 cond_model_input = (cond_model_input - latents_bn_mean ) / latents_bn_std
16961696
16971697 model_input_ids = Flux2Pipeline ._prepare_latent_ids (model_input ).to (device = model_input .device )
1698- cond_model_input_ids = Flux2Pipeline ._prepare_image_ids (cond_model_input ).to (
1698+ cond_model_input_list = [cond_model_input [i ].unsqueeze (0 ) for i in range (cond_model_input .shape [0 ])]
1699+ cond_model_input_ids = Flux2Pipeline ._prepare_image_ids (cond_model_input_list ).to (
16991700 device = cond_model_input .device
17001701 )
1702+ cond_model_input_ids = cond_model_input_ids .view (
1703+ cond_model_input .shape [0 ], - 1 , model_input_ids .shape [- 1 ]
1704+ )
17011705
17021706 # Sample noise that we'll add to the latents
17031707 noise = torch .randn_like (model_input )
@@ -1724,6 +1728,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17241728 packed_noisy_model_input = Flux2Pipeline ._pack_latents (noisy_model_input )
17251729 packed_cond_model_input = Flux2Pipeline ._pack_latents (cond_model_input )
17261730
1731+ orig_input_shape = packed_noisy_model_input .shape
1732+ orig_input_ids_shape = model_input_ids .shape
1733+
17271734 # concatenate the model inputs with the cond inputs
17281735 packed_noisy_model_input = torch .cat ([packed_noisy_model_input , packed_cond_model_input ], dim = 1 )
17291736 model_input_ids = torch .cat ([model_input_ids , cond_model_input_ids ], dim = 1 )
@@ -1742,7 +1749,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17421749 img_ids = model_input_ids , # B, image_seq_len, 4
17431750 return_dict = False ,
17441751 )[0 ]
1745- model_pred = model_pred [:, : packed_noisy_model_input .size (1 ) :]
1752+ model_pred = model_pred [:, : orig_input_shape [1 ], :]
1753+ model_input_ids = model_input_ids [:, : orig_input_ids_shape [1 ], :]
17461754
17471755 model_pred = Flux2Pipeline ._unpack_latents_with_ids (model_pred , model_input_ids )
17481756
0 commit comments