Commit 8da3252

Plumbing correct bias dims from TE to cudnn
Signed-off-by: Kshitij Lakhani <[email protected]>
1 parent eb8e792 commit 8da3252
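
This commit threads the bias tensor's own dimensions, bias_sq and bias_skv, from TE down to the cuDNN frontend graph instead of hard-coding the attention sequence lengths s_q and s_kv. The bias is a 4-D [bias_b, bias_h, bias_sq, bias_skv] tensor whose leading dims may already be broadcast (e.g. bias_b = 1); after this change its trailing dims are likewise taken from the tensor itself. A minimal sketch of the packed-stride convention the diffs below encode; the helper name is hypothetical and not part of this commit:

#include <array>
#include <cstdint>

// Packed row-major strides for a 4-D bias of shape
// [bias_b, bias_h, bias_sq, bias_skv], matching the set_stride() calls below.
std::array<std::int64_t, 4> bias_strides(std::int64_t bias_h, std::int64_t bias_sq,
                                         std::int64_t bias_skv) {
  return {bias_h * bias_sq * bias_skv,  // step between batches
          bias_sq * bias_skv,           // step between heads
          bias_skv,                     // step between query positions
          1};                           // step between key positions
}

For bias_h = 4, bias_sq = 128, bias_skv = 256 this yields {131072, 32768, 256, 1}.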

File tree

3 files changed: +41 -19 lines changed


transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Lines changed: 23 additions & 11 deletions
@@ -52,7 +52,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
 int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
 int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v,
 int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
-int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
+int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, bool is_training,
 bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
 NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
 int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
@@ -120,6 +120,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
 max_pages_per_seq_v,
 bias_b,
 bias_h,
+bias_sq,
+bias_skv,
 scaling_factor,
 is_training,
 dropout_probability,
@@ -263,8 +265,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
 if (is_bias) {
 bias = mha_graph->tensor(fe::graph::Tensor_attributes()
 .set_name("bias")
-.set_dim({bias_b, bias_h, s_q, s_kv})
-.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+.set_dim({bias_b, bias_h, bias_sq, bias_skv})
+.set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
 sdpa_options.set_bias(bias);
 }

@@ -539,7 +541,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(

 void fused_attn_arbitrary_seqlen_bwd_impl(
 int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
-int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
+int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv,
 float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
 NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
 int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
@@ -612,6 +614,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
 0,
 bias_b,
 bias_h,
+bias_sq,
+bias_skv,
 scaling_factor,
 true,
 dropout_probability,
@@ -794,12 +798,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
 if (is_bias) {
 bias = mha_graph->tensor(fe::graph::Tensor_attributes()
 .set_name("bias")
-.set_dim({bias_b, bias_h, s_q, s_kv})
-.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+.set_dim({bias_b, bias_h, bias_sq, bias_skv})
+.set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
 dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
 .set_name("dBias")
-.set_dim({bias_b, bias_h, s_q, s_kv})
-.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+.set_dim({bias_b, bias_h, bias_sq, bias_skv})
+.set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
 sdpa_backward_options.set_bias(bias);
 // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
 // are not supported for dbias calculation but they are
@@ -1064,10 +1068,14 @@ void fused_attn_arbitrary_seqlen_fwd(
 void *devPtrBias = nullptr;
 size_t bias_b = 0;
 size_t bias_h = 0;
+size_t bias_sq = 0;
+size_t bias_skv = 0;
 if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
 devPtrBias = input_Bias->data.dptr;
 bias_b = input_Bias->data.shape[0];
 bias_h = input_Bias->data.shape[1];
+bias_sq = input_Bias->data.shape[2];
+bias_skv = input_Bias->data.shape[3];
 }
 void *devPtrSoftmaxOffset = nullptr;
 if (softmax_type != NVTE_VANILLA_SOFTMAX) {
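
For reference, the forward entry point now reads all four bias dims straight off the bias tensor's shape, as the hunk above shows. A hedged sketch of that pattern with an explicit rank guard added for illustration (the guard is not part of this commit):

// Sketch, assuming a 4-D [b, h, sq, skv] bias as in the diff above.
size_t bias_b = 0, bias_h = 0, bias_sq = 0, bias_skv = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
  NVTE_CHECK(input_Bias->data.shape.size() == 4,
             "Bias is expected to be 4-D [b, h, sq, skv].");
  bias_b = input_Bias->data.shape[0];
  bias_h = input_Bias->data.shape[1];
  bias_sq = input_Bias->data.shape[2];
  bias_skv = input_Bias->data.shape[3];
}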
@@ -1133,7 +1141,7 @@ void fused_attn_arbitrary_seqlen_fwd(
 if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
 Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
 output_bias->data.dptr = nullptr;
-output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
+output_bias->data.shape = {bias_b, bias_h, bias_sq, bias_skv};
 output_bias->data.dtype = QKV_type;
 }

@@ -1178,7 +1186,7 @@ void fused_attn_arbitrary_seqlen_fwd(
 fused_attn_arbitrary_seqlen_fwd_impl(
 batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
 max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
-page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
+page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, is_training,
 return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
 window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
 devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
@@ -1224,11 +1232,15 @@ void fused_attn_arbitrary_seqlen_bwd(
 void *devPtrdBias = nullptr;
 size_t bias_b = 0;
 size_t bias_h = 0;
+size_t bias_sq = 0;
+size_t bias_skv = 0;
 if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
 devPtrBias = input_Bias->data.dptr;
 devPtrdBias = output_dBias->data.dptr;
 bias_b = output_dBias->data.shape[0];
 bias_h = output_dBias->data.shape[1];
+bias_sq = input_Bias->data.shape[2];
+bias_skv = input_Bias->data.shape[3];
 }

 size_t max_batch_size = 0;
@@ -1271,7 +1283,7 @@ void fused_attn_arbitrary_seqlen_bwd(

 fused_attn_arbitrary_seqlen_bwd_impl(
 batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
-max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
+max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, p_dropout,
 qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
 deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
 devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
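
Note the asymmetry in the backward path above: bias_b and bias_h come from output_dBias while bias_sq and bias_skv come from input_Bias. The two tensors are expected to agree on shape; an explicit consistency check (illustrative only, not in this commit) could look like:

// Illustrative: verify Bias and dBias agree before plumbing dims to cuDNN.
NVTE_CHECK(input_Bias->data.shape.size() == 4 && output_dBias->data.shape.size() == 4,
           "Bias and dBias are expected to be 4-D.");
for (size_t i = 0; i < 4; ++i) {
  NVTE_CHECK(input_Bias->data.shape[i] == output_dBias->data.shape[i],
             "Bias/dBias shape mismatch.");
}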

transformer_engine/common/fused_attn/fused_attn_fp8.cu

Lines changed: 14 additions & 6 deletions
@@ -1671,6 +1671,8 @@ void fused_attn_fp8_fwd_impl_v1(
 bool is_dropout = (is_training && dropout_probability != 0.0f);
 auto bias_b = b;
 auto bias_h = h;
+auto bias_sq = s_q;
+auto bias_skv = s_kv;
 NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
 NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
 bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF ||
@@ -1697,6 +1699,8 @@ void fused_attn_fp8_fwd_impl_v1(
 0,
 bias_b,
 bias_h,
+bias_sq,
+bias_skv,
 scaling_factor,
 is_training,
 dropout_probability,
@@ -1817,8 +1821,8 @@ void fused_attn_fp8_fwd_impl_v1(
 // if (is_bias) {
 // bias = mha_graph->tensor(fe::graph::Tensor_attributes()
 // .set_name("bias")
-// .set_dim({bias_b, bias_h, s_q, s_kv})
-// .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+// .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+// .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
 // sdpa_options.set_bias(bias);
 // }

@@ -1998,6 +2002,8 @@ void fused_attn_fp8_bwd_impl_v1(
 bool is_dropout = (dropout_probability != 0.0f);
 auto bias_b = b;
 auto bias_h = h;
+auto bias_sq = s_q;
+auto bias_skv = s_kv;
 NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
 NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
 bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF ||
@@ -2026,6 +2032,8 @@ void fused_attn_fp8_bwd_impl_v1(
 0,
 bias_b,
 bias_h,
+bias_sq,
+bias_skv,
 scaling_factor,
 true,
 dropout_probability,
@@ -2192,12 +2200,12 @@ void fused_attn_fp8_bwd_impl_v1(
 // if (is_bias) {
 // bias = mha_graph->tensor(fe::graph::Tensor_attributes()
 // .set_name("bias")
-// .set_dim({bias_b, bias_h, s_q, s_kv})
-// .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+// .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+// .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
 // dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
 // .set_name("dBias")
-// .set_dim({bias_b, bias_h, s_q, s_kv})
-// .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+// .set_dim({bias_b, bias_h, bias_sq, bias_skv})
+// .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
 // sdpa_backward_options.set_bias(bias);
 // // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
 // // are not supported for dbias calculation but they are

transformer_engine/common/fused_attn/utils.h

Lines changed: 4 additions & 2 deletions
@@ -101,6 +101,8 @@ struct FADescriptor_v1 {
 std::int64_t max_pages_per_seq_v;
 std::int64_t bias_b;
 std::int64_t bias_h;
+std::int64_t bias_sq;
+std::int64_t bias_skv;
 float attnScale;
 bool isTraining;
 float dropoutProbability;
@@ -119,13 +121,13 @@ struct FADescriptor_v1 {

 bool operator<(const FADescriptor_v1 &rhs) const {
 return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
-page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
+page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv,
 attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
 window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
 o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
 std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
 rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
-rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
+rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, rhs.attnScale, rhs.isTraining,
 rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
 rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
 rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
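
The new fields also join FADescriptor_v1::operator<, which backs the cuDNN graph cache; leaving them out would let two plans that differ only in bias shape collide under one key. A self-contained sketch of the std::tie pattern (struct and field set reduced for illustration):

#include <cstdint>
#include <map>
#include <tuple>

// Reduced stand-in for the descriptor-as-map-key pattern in utils.h.
struct BiasKey {
  std::int64_t bias_b, bias_h, bias_sq, bias_skv;
  bool operator<(const BiasKey &rhs) const {
    // std::tie compares all listed fields lexicographically, so keys that
    // differ only in bias_sq or bias_skv sort into distinct entries.
    return std::tie(bias_b, bias_h, bias_sq, bias_skv) <
           std::tie(rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv);
  }
};

int main() {
  std::map<BiasKey, int> plan_cache;
  plan_cache[{1, 16, 128, 128}] = 1;
  plan_cache[{1, 16, 128, 256}] = 2;  // differs only in bias_skv: new entry
  return plan_cache.size() == 2 ? 0 : 1;
}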
