@@ -52,7 +52,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
5252 int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
5353 int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v,
5454 int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
55- int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
55+ int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, bool is_training,
5656 bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
5757 NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
5858 int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
@@ -120,6 +120,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
120120 max_pages_per_seq_v,
121121 bias_b,
122122 bias_h,
123+ bias_sq,
124+ bias_skv,
123125 scaling_factor,
124126 is_training,
125127 dropout_probability,
@@ -263,8 +265,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
263265 if (is_bias) {
264266 bias = mha_graph->tensor (fe::graph::Tensor_attributes ()
265267 .set_name (" bias" )
266- .set_dim ({bias_b, bias_h, s_q, s_kv })
267- .set_stride ({bias_h * s_q * s_kv, s_q * s_kv, s_kv , 1 }));
268+ .set_dim ({bias_b, bias_h, bias_sq, bias_skv })
269+ .set_stride ({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv , 1 }));
268270 sdpa_options.set_bias (bias);
269271 }
270272
@@ -539,7 +541,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
539541
540542void fused_attn_arbitrary_seqlen_bwd_impl (
541543 int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
542- int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
544+ int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv,
543545 float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
544546 NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
545547 int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
@@ -612,6 +614,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
612614 0 ,
613615 bias_b,
614616 bias_h,
617+ bias_sq,
618+ bias_skv,
615619 scaling_factor,
616620 true ,
617621 dropout_probability,
@@ -794,12 +798,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
794798 if (is_bias) {
795799 bias = mha_graph->tensor (fe::graph::Tensor_attributes ()
796800 .set_name (" bias" )
797- .set_dim ({bias_b, bias_h, s_q, s_kv })
798- .set_stride ({bias_h * s_q * s_kv, s_q * s_kv, s_kv , 1 }));
801+ .set_dim ({bias_b, bias_h, bias_sq, bias_skv })
802+ .set_stride ({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv , 1 }));
799803 dBias = mha_graph->tensor (fe::graph::Tensor_attributes ()
800804 .set_name (" dBias" )
801- .set_dim ({bias_b, bias_h, s_q, s_kv })
802- .set_stride ({bias_h * s_q * s_kv, s_q * s_kv, s_kv , 1 }));
805+ .set_dim ({bias_b, bias_h, bias_sq, bias_skv })
806+ .set_stride ({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv , 1 }));
803807 sdpa_backward_options.set_bias (bias);
804808 // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
805809 // are not supported for dbias calculation but they are
@@ -1064,10 +1068,14 @@ void fused_attn_arbitrary_seqlen_fwd(
10641068 void *devPtrBias = nullptr ;
10651069 size_t bias_b = 0 ;
10661070 size_t bias_h = 0 ;
1071+ size_t bias_sq = 0 ;
1072+ size_t bias_skv = 0 ;
10671073 if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
10681074 devPtrBias = input_Bias->data .dptr ;
10691075 bias_b = input_Bias->data .shape [0 ];
10701076 bias_h = input_Bias->data .shape [1 ];
1077+ bias_sq = input_Bias->data .shape [2 ];
1078+ bias_skv = input_Bias->data .shape [3 ];
10711079 }
10721080 void *devPtrSoftmaxOffset = nullptr ;
10731081 if (softmax_type != NVTE_VANILLA_SOFTMAX) {
@@ -1133,7 +1141,7 @@ void fused_attn_arbitrary_seqlen_fwd(
11331141 if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
11341142 Tensor *output_bias = convertNVTETensorCheck (Aux_CTX_Tensors->tensors [i++]);
11351143 output_bias->data .dptr = nullptr ;
1136- output_bias->data .shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv };
1144+ output_bias->data .shape = {bias_b, bias_h, bias_sq, bias_skv };
11371145 output_bias->data .dtype = QKV_type;
11381146 }
11391147
@@ -1178,7 +1186,7 @@ void fused_attn_arbitrary_seqlen_fwd(
11781186 fused_attn_arbitrary_seqlen_fwd_impl (
11791187 batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
11801188 max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
1181- page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
1189+ page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, is_training,
11821190 return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
11831191 window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
11841192 devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
@@ -1224,11 +1232,15 @@ void fused_attn_arbitrary_seqlen_bwd(
12241232 void *devPtrdBias = nullptr ;
12251233 size_t bias_b = 0 ;
12261234 size_t bias_h = 0 ;
1235+ size_t bias_sq = 0 ;
1236+ size_t bias_skv = 0 ;
12271237 if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
12281238 devPtrBias = input_Bias->data .dptr ;
12291239 devPtrdBias = output_dBias->data .dptr ;
12301240 bias_b = output_dBias->data .shape [0 ];
12311241 bias_h = output_dBias->data .shape [1 ];
1242+ bias_sq = input_Bias->data .shape [2 ];
1243+ bias_skv = input_Bias->data .shape [3 ];
12321244 }
12331245
12341246 size_t max_batch_size = 0 ;
@@ -1271,7 +1283,7 @@ void fused_attn_arbitrary_seqlen_bwd(
12711283
12721284 fused_attn_arbitrary_seqlen_bwd_impl (
12731285 batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
1274- max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
1286+ max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, p_dropout,
12751287 qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
12761288 deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
12771289 devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
0 commit comments