Fix NaNs when kpad_mask is all zeros over a tile#15
Draft
AntonOresten wants to merge 4 commits intoFluxML:masterfrom
Draft
Fix NaNs when kpad_mask is all zeros over a tile#15AntonOresten wants to merge 4 commits intoFluxML:masterfrom
AntonOresten wants to merge 4 commits intoFluxML:masterfrom
Conversation
Contributor
Author
|
Not sure about this one tbh. Preferably, there'd be a sequence lengths vector of length B passed instead, where we'd skip trailing tiles entirely. Fairly straight-forward for keys, not sure about queries. |
Contributor
Author
|
We saw NaNs again specifically on a branch without these changes, just for the record. |
Member
|
Can you share MWE? |
Contributor
Author
Show groupsizejulia> @eval NNop begin
function flash_attention_groupsize(::Type{T}; emb_dim::Int, target_shmem::UInt64) where T
# TODO
# - return `qk_fp16` to configure kernel
# - optional qk_fp16
# qk_fp16s = (false, true)
# TODO prefer bigger groupsize?
qk_fp16s = (true,)
for qk_fp16 in qk_fp16s, groupsize in (256, 128, 64, 32, 16)
shmem = flash_attention_shmem_bwd(T; emb_dim, groupsize, qk_fp16)
shmem ≤ target_shmem && begin @show groupsize; return groupsize end
end
error("Failed to find groupsize for Flash Attention that satisfies Shared Memory constraint.")
end
end
flash_attention_groupsize (generic function with 1 method)julia> begin
H, L = 64, 64
pad = 32
x = CUDA.rand(H, L, 1, 1);
kpad_mask = CuArray([trues(L-pad); falses(pad);;]);
any(isnan, NNop.flash_attention(x, x, x; causal=false, kpad_mask))
end
groupsize = 32
true
julia> begin
H, L = 64, 64
pad = 31
x = CUDA.rand(H, L, 1, 1);
kpad_mask = CuArray([trues(L-pad); falses(pad);;]);
any(isnan, NNop.flash_attention(x, x, x; causal=false, kpad_mask))
end
groupsize = 32
falseWhen L is not a multiple of groupsize: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
We encountered NaNs when kpad_mask had a lot of zeros. The solution in this PR was somewhat vibe-coded, so may or may not be the simplest solution, but it seems to work.