Draft
Forward pass implementation for THD format in AttnFuncWithCPAndKVAllGather:
- Remove THD assert, add format branching (skip Q reshape, K/V movedim)
- AllGather + reorder K/V using reorder_seq_chunks_after_a2a_before_attn_thd
- Virtual cu_seqlens approach: per-step cu_seqlens_q_padded and cu_seqlens_kv to select valid tokens without physical tensor slicing
- Step 0: padding_causal (first chunk), Step 1: padding_causal_bottom_right (second chunk, front-padded)
- Output assembly copies valid token ranges per sequence per step

Also includes test_fused_attn_split_q verifying front-padding mechanism and debug logging in fused_attn_fwd. Backward pass and multi-GPU test still TODO.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
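The per-step token selection described in this commit can be illustrated with a small sketch. It is not the code from this PR: the helper name `per_step_ranges` and the chunk layout are assumptions (each sequence split into `2 * cp_size` equal chunks, rank `r` holding chunks `r` and `2 * cp_size - 1 - r` back to back, and the gathered K/V restored to natural order). It only shows how offsets and lengths, rather than tensor slices, can pick out the valid tokens for each step.

```python
def per_step_ranges(seqlens, cp_size, rank, step):
    """Hypothetical helper: per-sequence (q_start, q_len, kv_start, kv_len).

    Assumes each sequence is split into 2 * cp_size equal chunks and that
    this rank stores chunk `rank` followed by chunk `2 * cp_size - 1 - rank`.
    """
    assert step in (0, 1)
    # Global index of the chunk that Q covers at this step.
    chunk_id = rank if step == 0 else 2 * cp_size - 1 - rank
    ranges = []
    q_offset = 0   # start of this sequence in the local (per-rank) Q tensor
    kv_offset = 0  # start of this sequence in the gathered, reordered K/V
    for seqlen in seqlens:
        assert seqlen % (2 * cp_size) == 0, "sequence must split evenly"
        chunk = seqlen // (2 * cp_size)
        # Step 1 skips the first local chunk, so its valid tokens sit behind
        # `chunk` tokens of front padding within the sequence's local slot.
        q_start = q_offset + step * chunk
        # Causal visibility: every K/V token up to the end of Q's chunk.
        kv_len = (chunk_id + 1) * chunk
        ranges.append((q_start, chunk, kv_offset, kv_len))
        q_offset += 2 * chunk
        kv_offset += seqlen
    return ranges


if __name__ == "__main__":
    for step in (0, 1):
        print(step, per_step_ranges([8, 16], cp_size=2, rank=0, step=step))
```

In this sketch the Q start offsets play the role of a padded cu_seqlens array and the K/V lengths play the role of cu_seqlens_kv; the exact arrays passed to the kernel in the PR may be built differently.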
Q always represents the last chunk in the visible KV range, so padding_causal_bottom_right is the correct mask for all steps (not just step 1). Also allow padding mask types for THD format and add multi-GPU forward-only test that passes with exact zero diff on all ranks.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
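The difference between the two causal alignments can be shown with a toy mask builder. This is a minimal sketch in plain Python, not the fused attention kernel; `causal_mask` is a hypothetical name. With bottom-right alignment, query `i` may attend to key `j` when `j <= i + (kv_len - q_len)`, which matches the case where Q is the last chunk of the visible K/V range.

```python
def causal_mask(q_len, kv_len, bottom_right):
    """Hypothetical helper: boolean mask, True where query i may see key j."""
    # Bottom-right alignment shifts the diagonal so the last query row
    # lines up with the last key column; top-left uses no shift.
    shift = kv_len - q_len if bottom_right else 0
    return [[j <= i + shift for j in range(kv_len)] for i in range(q_len)]


if __name__ == "__main__":
    for row in causal_mask(2, 4, bottom_right=True):
        print(["x" if allowed else "." for allowed in row])
```

When `q_len == kv_len` the shift is zero and both alignments produce the same mask, which is consistent with using the bottom-right variant for every step.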
Mirror forward THD logic in backward:
- AllGather + reorder K/V using reorder_seq_chunks_after_a2a_before_attn_thd
- Per-step cu_seqlens reconstruction (same formula as forward)
- fused_attn_bwd with per-step mask types and padded cu_seqlens
- dQ assembly: copy valid token ranges per sequence per step
- dK/dV: accumulate full tensors from both steps (kernel zeros non-valid)
- Reverse-reorder via reorder_seq_chunks_before_a2a_after_attn_thd then reduce-scatter

Test results (2 GPUs, bf16, seqlens=[256,512,1024]):
out: exact 0 diff, dq: ~1e-9, dk: ~1e-8, dv: ~5e-4

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
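The dK/dV path (accumulate full-length gradients, then reverse-reorder and reduce-scatter) can be mimicked in a single process for one sequence. This is a sketch under assumptions, not the PR's implementation: `reduce_scatter_chunks` is a hypothetical stand-in for the collective, gradients are plain lists of numbers, and rank `r` is assumed to own chunks `r` and `2 * cp_size - 1 - r`.

```python
def reduce_scatter_chunks(per_rank_grads, cp_size):
    """Hypothetical single-process stand-in for reorder + reduce-scatter.

    per_rank_grads[r] is rank r's full-length dK (or dV) for one sequence,
    already summed over its two steps, with zeros where nothing was visible.
    """
    seqlen = len(per_rank_grads[0])
    assert seqlen % (2 * cp_size) == 0, "sequence must split evenly"
    chunk = seqlen // (2 * cp_size)
    # Reduce: sum the contributions of all ranks token by token.
    total = [sum(grad[t] for grad in per_rank_grads) for t in range(seqlen)]
    # Scatter: each rank keeps only the two chunks it owns.
    local = []
    for rank in range(cp_size):
        first, second = rank, 2 * cp_size - 1 - rank
        local.append(
            total[first * chunk:(first + 1) * chunk]
            + total[second * chunk:(second + 1) * chunk]
        )
    return local


if __name__ == "__main__":
    print(reduce_scatter_chunks([[1, 1, 1, 1], [1, 2, 3, 0]], cp_size=2))
```

Accumulating full tensors before the scatter keeps the per-step logic simple, since every rank's Q can contribute gradient to K/V tokens owned by other ranks.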
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: