
Cp thd swa with ag #2829

Draft
sudhakarsingh27 wants to merge 5 commits into NVIDIA:main from sudhakarsingh27:cp_thd_swa_with_ag

Conversation

@sudhakarsingh27
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

sudhakarsingh27 and others added 5 commits April 3, 2026 00:02
Forward pass implementation for THD format in AttnFuncWithCPAndKVAllGather:
- Remove THD assert, add format branching (skip Q reshape, K/V movedim)
- AllGather + reorder K/V using reorder_seq_chunks_after_a2a_before_attn_thd
- Virtual cu_seqlens approach: per-step cu_seqlens_q_padded and cu_seqlens_kv
  to select valid tokens without physical tensor slicing
- Step 0: padding_causal (first chunk), Step 1: padding_causal_bottom_right
  (second chunk, front-padded)
- Output assembly copies valid token ranges per sequence per step

Also includes test_fused_attn_split_q, which verifies the front-padding
mechanism, and debug logging in fused_attn_fwd.

Backward pass and multi-GPU test still TODO.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
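The "virtual cu_seqlens" idea above can be sketched as follows. This is a minimal, hypothetical illustration (the helper names `build_cu_seqlens` and `per_step_kv_cu_seqlens` and the exact per-step visibility formula are assumptions, not the PR's actual code): in THD format, sequences are packed along one token dimension and delimited by cumulative-length offsets, so shrinking those offsets per step selects a subset of valid tokens without physically slicing the tensor.

```python
import numpy as np

def build_cu_seqlens(seqlens):
    # Cumulative sequence lengths for a THD-packed tensor:
    # cu_seqlens[i] is the offset where sequence i starts.
    return np.concatenate([[0], np.cumsum(seqlens)])

def per_step_kv_cu_seqlens(seqlens, step, num_steps):
    # Hypothetical per-step variant: pretend each all-gather step makes a
    # proportional prefix of every sequence's KV tokens visible, and build
    # cu_seqlens over just those valid tokens (no tensor slicing needed).
    visible = [(length * (step + 1)) // num_steps for length in seqlens]
    return build_cu_seqlens(visible)
```

The attention kernel then reads only the token ranges the per-step cu_seqlens describe, which is what lets the PR skip physical slicing of the gathered K/V.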
Q always represents the last chunk in the visible KV range, so
padding_causal_bottom_right is the correct mask for all steps (not just
step 1). Also allow padding mask types for THD format and add multi-GPU
forward-only test that passes with exact zero diff on all ranks.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
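The mask choice above can be illustrated with a small sketch. Assuming the standard bottom-right-aligned causal convention (query i may attend to KV position j when j <= i + (seqlen_kv - seqlen_q)), the last queries line up with the last KV tokens, which matches the claim that Q is always the last chunk of the visible KV range:

```python
import numpy as np

def bottom_right_causal_mask(seqlen_q, seqlen_kv):
    # Bottom-right aligned causal mask: the causal diagonal is anchored at
    # the bottom-right corner, so query i sees KV positions
    # j <= i + (seqlen_kv - seqlen_q).
    i = np.arange(seqlen_q)[:, None]
    j = np.arange(seqlen_kv)[None, :]
    return j <= i + (seqlen_kv - seqlen_q)
```

With seqlen_q=2 and seqlen_kv=4, the first query already sees the three KV tokens that precede it, and the last query sees everything, as expected when Q is the trailing chunk.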
Mirror forward THD logic in backward:
- AllGather + reorder K/V using reorder_seq_chunks_after_a2a_before_attn_thd
- Per-step cu_seqlens reconstruction (same formula as forward)
- fused_attn_bwd with per-step mask types and padded cu_seqlens
- dQ assembly: copy valid token ranges per sequence per step
- dK/dV: accumulate full tensors from both steps (kernel zeros non-valid)
- Reverse-reorder via reorder_seq_chunks_before_a2a_after_attn_thd
  then reduce-scatter

Test results (2 GPUs, bf16, seqlens=[256,512,1024]):
  out: exact 0 diff, dq: ~1e-9, dk: ~1e-8, dv: ~5e-4

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
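The reorder steps in both passes (reorder_seq_chunks_after_a2a_before_attn_thd and its reverse) can be sketched as follows. This is a simplified assumption-laden illustration, not the PR's implementation: it assumes the load-balanced context-parallel split in which each sequence is cut into 2*cp_size chunks and rank r holds chunks r and 2*cp_size-1-r, so an all-gather along the rank dimension yields chunks out of natural order and they must be reordered before attention (and reverse-reordered before the reduce-scatter in backward).

```python
def reorder_after_allgather(chunks_by_rank, cp_size):
    # chunks_by_rank[r] = (first_chunk, second_chunk) held by rank r under
    # the assumed load-balanced split: rank r owns chunks r and
    # 2*cp_size-1-r of the sequence.
    ordered = [None] * (2 * cp_size)
    for r, (a, b) in enumerate(chunks_by_rank):
        ordered[r] = a
        ordered[2 * cp_size - 1 - r] = b
    return ordered

def reorder_before_reducescatter(ordered, cp_size):
    # Inverse mapping: regroup naturally ordered chunk gradients back into
    # per-rank pairs so a reduce-scatter returns each rank its own chunks.
    return [(ordered[r], ordered[2 * cp_size - 1 - r]) for r in range(cp_size)]
```

The two functions are inverses of each other, mirroring how the backward pass reverses the forward reorder before the reduce-scatter.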