@@ -191,11 +191,16 @@ namespace Qwen {
191191 };
192192
193193 class QwenImageTransformerBlock : public GGMLBlock {
194+ protected:
195+ bool zero_cond_t ;
196+
194197 public:
195198 QwenImageTransformerBlock (int64_t dim,
196199 int64_t num_attention_heads,
197200 int64_t attention_head_dim,
198- float eps = 1e-6 ) {
201+ float eps = 1e-6 ,
202+ bool zero_cond_t = false )
203+ : zero_cond_t (zero_cond_t ) {
199204 // img_mod.0 is nn.SiLU()
200205 blocks[" img_mod.1" ] = std::shared_ptr<GGMLBlock>(new Linear (dim, 6 * dim, true ));
201206
@@ -220,11 +225,37 @@ namespace Qwen {
220225 eps));
221226 }
222227
// Build the six modulation tensors consumed by a transformer block.
//
// ctx        - ggml context used to create the graph nodes.
// mod_params - output of the modulation Linear layer.
// index      - optional per-token selector. When null, mod_params is simply
//              split into 6 chunks via ggml_ext_chunk(..., 6, 0) and returned.
//
// When index is given, mod_params is flattened to 1-D and split into 12
// chunks; chunk i is paired with chunk i + 6 and the two are blended per
// token as (1 - index) * chunk[i] + index * chunk[i + 6]. The result is a
// vector of 6 tensors, each broadcast to the shape of the expanded index.
//
// NOTE(review): the blend equals a hard select only if index holds exactly
// 0.0 or 1.0 -- the runner code visible in this diff fills it with 0.f / 1.f.
// NOTE(review): because mod_params is flattened before the 12-way split, the
// i / i + 6 pairing presumably only lines up with the two concatenated
// timestep embeddings when the batch size is 1 -- TODO confirm with callers.
228+ std::vector<ggml_tensor*> get_mod_params_vec (ggml_context* ctx, ggml_tensor* mod_params, ggml_tensor* index = nullptr ) {
229+ // index: [N, n_img_token] (shape notes in this function list the slowest dimension first; ggml's ne[] order is the reverse)
230+ // mod_params: [N, hidden_size * 12] per the original note; it is flattened below, so only the total element count drives the 12-way split
231+ if (index == nullptr ) {
// No selector: plain 6-way split along dimension 0.
232+ return ggml_ext_chunk (ctx, mod_params, 6 , 0 );
233+ }
// Flatten so the 12 chunks are taken over the whole buffer rather than per row.
234+ mod_params = ggml_reshape_1d (ctx, mod_params, ggml_nelements (mod_params));
235+ auto mod_params_vec = ggml_ext_chunk (ctx, mod_params, 12 , 0 );
// Insert a leading length-1 dimension (ne[0] == 1) so the selector can be expanded along it.
236+ index = ggml_reshape_3d (ctx, index, 1 , index->ne [0 ], index->ne [1 ]); // [N, n_img_token, 1]
// Expand ne[0] to the chunk length; the index->ne[...] arguments read the already-reshaped tensor from the previous line.
237+ index = ggml_repeat_4d (ctx, index, mod_params_vec[0 ]->ne [0 ], index->ne [1 ], index->ne [2 ], index->ne [3 ]); // [N, n_img_token, hidden_size]
238+ std::vector<ggml_tensor*> mod_results;
239+ for (int i = 0 ; i < 6 ; i++) {
// Chunk i and chunk i + 6 are the two candidates for the same modulation slot.
240+ auto mod_0 = mod_params_vec[i];
241+ auto mod_1 = mod_params_vec[i + 6 ];
242+
243+ // mod_result = torch.where(index == 0, mod_0, mod_1)
244+ // mod_result = (1 - index)*mod_0 + index*mod_1
// (1 - index) * mod_0 is formed as repeat(mod_0) - index * mod_0, so no separate "1 - index" tensor is built.
245+ mod_0 = ggml_sub (ctx, ggml_repeat (ctx, mod_0, index), ggml_mul (ctx, index, mod_0)); // [N, n_img_token, hidden_size]
246+ mod_1 = ggml_mul (ctx, index, mod_1); // [N, n_img_token, hidden_size]
247+ auto mod_result = ggml_add (ctx, mod_0, mod_1);
248+ mod_results.push_back (mod_result);
249+ }
250+ return mod_results;
251+ }
252+
223253 virtual std::pair<ggml_tensor*, ggml_tensor*> forward (GGMLRunnerContext* ctx,
224254 struct ggml_tensor * img,
225255 struct ggml_tensor * txt,
226256 struct ggml_tensor * t_emb,
227- struct ggml_tensor * pe) {
257+ struct ggml_tensor * pe,
258+ struct ggml_tensor * modulate_index = nullptr ) {
228259 // img: [N, n_img_token, hidden_size]
229260 // txt: [N, n_txt_token, hidden_size]
230261 // pe: [n_img_token + n_txt_token, d_head/2, 2, 2]
@@ -244,14 +275,18 @@ namespace Qwen {
244275
245276 auto img_mod_params = ggml_silu (ctx->ggml_ctx , t_emb);
246277 img_mod_params = img_mod_1->forward (ctx, img_mod_params);
247- auto img_mod_param_vec = ggml_ext_chunk (ctx->ggml_ctx , img_mod_params, 6 , 0 );
278+ auto img_mod_param_vec = get_mod_params_vec (ctx->ggml_ctx , img_mod_params, modulate_index);
279+
280+ if (zero_cond_t ) {
281+ t_emb = ggml_ext_chunk (ctx->ggml_ctx , t_emb, 2 , 1 )[0 ];
282+ }
248283
249284 auto txt_mod_params = ggml_silu (ctx->ggml_ctx , t_emb);
250285 txt_mod_params = txt_mod_1->forward (ctx, txt_mod_params);
251- auto txt_mod_param_vec = ggml_ext_chunk (ctx->ggml_ctx , txt_mod_params, 6 , 0 );
286+ auto txt_mod_param_vec = get_mod_params_vec (ctx->ggml_ctx , txt_mod_params);
252287
253288 auto img_normed = img_norm1->forward (ctx, img);
254- auto img_modulated = Flux::modulate (ctx->ggml_ctx , img_normed, img_mod_param_vec[0 ], img_mod_param_vec[1 ]);
289+ auto img_modulated = Flux::modulate (ctx->ggml_ctx , img_normed, img_mod_param_vec[0 ], img_mod_param_vec[1 ], modulate_index != nullptr );
255290 auto img_gate1 = img_mod_param_vec[2 ];
256291
257292 auto txt_normed = txt_norm1->forward (ctx, txt);
@@ -264,7 +299,7 @@ namespace Qwen {
264299 txt = ggml_add (ctx->ggml_ctx , txt, ggml_mul (ctx->ggml_ctx , txt_attn_output, txt_gate1));
265300
266301 auto img_normed2 = img_norm2->forward (ctx, img);
267- auto img_modulated2 = Flux::modulate (ctx->ggml_ctx , img_normed2, img_mod_param_vec[3 ], img_mod_param_vec[4 ]);
302+ auto img_modulated2 = Flux::modulate (ctx->ggml_ctx , img_normed2, img_mod_param_vec[3 ], img_mod_param_vec[4 ], modulate_index != nullptr );
268303 auto img_gate2 = img_mod_param_vec[5 ];
269304
270305 auto txt_normed2 = txt_norm2->forward (ctx, txt);
@@ -325,6 +360,7 @@ namespace Qwen {
325360 float theta = 10000 ;
326361 std::vector<int > axes_dim = {16 , 56 , 56 };
327362 int64_t axes_dim_sum = 128 ;
363+ bool zero_cond_t = false ;
328364 };
329365
330366 class QwenImageModel : public GGMLBlock {
@@ -346,7 +382,8 @@ namespace Qwen {
346382 auto block = std::shared_ptr<GGMLBlock>(new QwenImageTransformerBlock (inner_dim,
347383 params.num_attention_heads ,
348384 params.attention_head_dim ,
349- 1e-6f ));
385+ 1e-6f ,
386+ params.zero_cond_t ));
350387 blocks[" transformer_blocks." + std::to_string (i)] = block;
351388 }
352389
@@ -421,7 +458,8 @@ namespace Qwen {
421458 struct ggml_tensor * x,
422459 struct ggml_tensor * timestep,
423460 struct ggml_tensor * context,
424- struct ggml_tensor * pe) {
461+ struct ggml_tensor * pe,
462+ struct ggml_tensor * modulate_index = nullptr ) {
425463 auto time_text_embed = std::dynamic_pointer_cast<QwenTimestepProjEmbeddings>(blocks[" time_text_embed" ]);
426464 auto txt_norm = std::dynamic_pointer_cast<RMSNorm>(blocks[" txt_norm" ]);
427465 auto img_in = std::dynamic_pointer_cast<Linear>(blocks[" img_in" ]);
@@ -430,18 +468,26 @@ namespace Qwen {
430468 auto proj_out = std::dynamic_pointer_cast<Linear>(blocks[" proj_out" ]);
431469
432470 auto t_emb = time_text_embed->forward (ctx, timestep);
433- auto img = img_in->forward (ctx, x);
434- auto txt = txt_norm->forward (ctx, context);
435- txt = txt_in->forward (ctx, txt);
471+ if (params.zero_cond_t ) {
472+ auto t_emb_0 = time_text_embed->forward (ctx, ggml_ext_zeros (ctx->ggml_ctx , timestep->ne [0 ], timestep->ne [1 ], timestep->ne [2 ], timestep->ne [3 ]));
473+ t_emb = ggml_concat (ctx->ggml_ctx , t_emb, t_emb_0, 1 );
474+ }
475+ auto img = img_in->forward (ctx, x);
476+ auto txt = txt_norm->forward (ctx, context);
477+ txt = txt_in->forward (ctx, txt);
436478
437479 for (int i = 0 ; i < params.num_layers ; i++) {
438480 auto block = std::dynamic_pointer_cast<QwenImageTransformerBlock>(blocks[" transformer_blocks." + std::to_string (i)]);
439481
440- auto result = block->forward (ctx, img, txt, t_emb, pe);
482+ auto result = block->forward (ctx, img, txt, t_emb, pe, modulate_index );
441483 img = result.first ;
442484 txt = result.second ;
443485 }
444486
487+ if (params.zero_cond_t ) {
488+ t_emb = ggml_ext_chunk (ctx->ggml_ctx , t_emb, 2 , 1 )[0 ];
489+ }
490+
445491 img = norm_out->forward (ctx, img, t_emb);
446492 img = proj_out->forward (ctx, img);
447493
@@ -453,7 +499,8 @@ namespace Qwen {
453499 struct ggml_tensor * timestep,
454500 struct ggml_tensor * context,
455501 struct ggml_tensor * pe,
456- std::vector<ggml_tensor*> ref_latents = {}) {
502+ std::vector<ggml_tensor*> ref_latents = {},
503+ struct ggml_tensor * modulate_index = nullptr ) {
457504 // Forward pass of DiT.
458505 // x: [N, C, H, W]
459506 // timestep: [N,]
@@ -479,7 +526,7 @@ namespace Qwen {
479526 int64_t h_len = ((H + (params.patch_size / 2 )) / params.patch_size );
480527 int64_t w_len = ((W + (params.patch_size / 2 )) / params.patch_size );
481528
482- auto out = forward_orig (ctx, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C]
529+ auto out = forward_orig (ctx, img, timestep, context, pe, modulate_index ); // [N, h_len*w_len, ph*pw*C]
483530
484531 if (out->ne [1 ] > img_tokens) {
485532 out = ggml_cont (ctx->ggml_ctx , ggml_permute (ctx->ggml_ctx , out, 0 , 2 , 1 , 3 )); // [num_tokens, N, C * patch_size * patch_size]
@@ -502,15 +549,19 @@ namespace Qwen {
502549 QwenImageParams qwen_image_params;
503550 QwenImageModel qwen_image;
504551 std::vector<float > pe_vec;
552+ std::vector<float > modulate_index_vec;
505553 SDVersion version;
506554
507555 QwenImageRunner (ggml_backend_t backend,
508556 bool offload_params_to_cpu,
509557 const String2TensorStorage& tensor_storage_map = {},
510558 const std::string prefix = " " ,
511- SDVersion version = VERSION_QWEN_IMAGE)
559+ SDVersion version = VERSION_QWEN_IMAGE,
560+ bool zero_cond_t = false )
512561 : GGMLRunner(backend, offload_params_to_cpu) {
513- qwen_image_params.num_layers = 0 ;
562+ qwen_image_params.num_layers = 0 ;
563+ qwen_image_params.zero_cond_t = zero_cond_t ;
564+ LOG_DEBUG (" zero_cond_t: %d" , zero_cond_t );
514565 for (auto pair : tensor_storage_map) {
515566 std::string tensor_name = pair.first ;
516567 if (tensor_name.find (prefix) == std::string::npos)
@@ -576,14 +627,40 @@ namespace Qwen {
576627 // pe->data = nullptr;
577628 set_backend_tensor_data (pe, pe_vec.data ());
578629
630+ ggml_tensor* modulate_index = nullptr ;
631+ if (qwen_image_params.zero_cond_t ) {
632+ modulate_index_vec.clear ();
633+
634+ int64_t h_len = ((x->ne [1 ] + (qwen_image_params.patch_size / 2 )) / qwen_image_params.patch_size );
635+ int64_t w_len = ((x->ne [0 ] + (qwen_image_params.patch_size / 2 )) / qwen_image_params.patch_size );
636+ int64_t num_img_tokens = h_len * w_len;
637+
638+ modulate_index_vec.insert (modulate_index_vec.end (), num_img_tokens, 0 .f );
639+ int64_t num_ref_img_tokens = 0 ;
640+ for (ggml_tensor* ref : ref_latents) {
641+ int64_t h_len = ((ref->ne [1 ] + (qwen_image_params.patch_size / 2 )) / qwen_image_params.patch_size );
642+ int64_t w_len = ((ref->ne [0 ] + (qwen_image_params.patch_size / 2 )) / qwen_image_params.patch_size );
643+
644+ num_ref_img_tokens += h_len * w_len;
645+ }
646+
647+ if (num_ref_img_tokens > 0 ) {
648+ modulate_index_vec.insert (modulate_index_vec.end (), num_ref_img_tokens, 1 .f );
649+ }
650+
651+ modulate_index = ggml_new_tensor_1d (compute_ctx, GGML_TYPE_F32, modulate_index_vec.size ());
652+ set_backend_tensor_data (modulate_index, modulate_index_vec.data ());
653+ }
654+
579655 auto runner_ctx = get_context ();
580656
581657 struct ggml_tensor * out = qwen_image.forward (&runner_ctx,
582658 x,
583659 timesteps,
584660 context,
585661 pe,
586- ref_latents);
662+ ref_latents,
663+ modulate_index);
587664
588665 ggml_build_forward_expand (gf, out);
589666
0 commit comments