
feat(gdn): separate input and output pool indices#2905

Open
feldsherov wants to merge 1 commit into flashinfer-ai:main from feldsherov:gdn-decode-separate-input-and-output-indices

Conversation


@feldsherov feldsherov commented Mar 28, 2026

📌 Description

Introduce a separate output indices parameter for gated_delta_rule_decode_pretranspose.

This addresses the decode part of the feature request in #2873.
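Concretely, the requested behavior can be modeled with a small pure-Python sketch of a state pool (names and shapes here are illustrative, not the FlashInfer API):

```python
def decode_step(pool, read_indices, write_indices, update):
    """Toy model of pool-mode decode: read each batch item's state from
    pool[read_indices[b]], apply an update, and write the result to
    pool[write_indices[b]], which may differ from the read slot."""
    for b, (r, w) in enumerate(zip(read_indices, write_indices)):
        state = pool[r]             # read slot
        pool[w] = update(state, b)  # write slot, possibly different
    return pool

# Pool of 4 scalar "states"; a batch of 2 reads slots 0 and 1 and writes
# the updated states to slots 2 and 3, leaving the read slots intact.
pool = [10.0, 20.0, 0.0, 0.0]
decode_step(pool, [0, 1], [2, 3], lambda s, b: s + 1.0)
print(pool)  # [10.0, 20.0, 11.0, 21.0]
```

When the write indices equal the read indices, this degenerates to the previous in-place behavior.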

🔍 Related Issues

#2873

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

I've only run tests/gdn/test_decode_delta_rule.py on an H200. I need help running the whole test suite.

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Added support for independent state output locations in gated delta network decoding operations, enabling users to control where updated states are written separately from where they are read.
    • Added comprehensive test coverage validating the new state output indexing functionality and verifying correct behavior across various usage scenarios and edge cases.


coderabbitai bot commented Mar 28, 2026

📝 Walkthrough

Walkthrough

This PR extends the GDN decode APIs and kernels to support separate read and write indices for state pooling. When output_state_indices is provided, updated states are written to specified pool slots instead of the slots used for reading via initial_state_indices. The feature is validated, plumbed through both bf16 and pretranspose kernel paths, and includes comprehensive test coverage.

Changes

  • Public API Extensions — flashinfer/gdn_decode.py: Added an optional output_state_indices parameter to gated_delta_rule_decode_pretranspose() with validation ensuring it is only used in pool mode, has shape [B], and dtype int32/int64. Updated the docstring to document the parameter and its default behavior.
  • BF16 Kernel Implementation — flashinfer/gdn_kernels/gdn_decode_bf16_state.py: Added the output_state_indices parameter to the Python API wrapper gated_delta_rule() and a gH_out_slot_indices write tensor to the CUDA kernels. Implemented resolution logic to default h_out_slot_indices to h_slot_indices when unspecified, plumbed it through kernel compilation and execution arguments, and updated the cache key to include pool size. Modified the write_pool_batch_idx derivation to use write indices with negative-index clamping.
  • Pretranspose Kernel Implementation — flashinfer/gdn_kernels/gdn_decode_pretranspose.py: Added an h0_out_indices tensor parameter to the CUDA kernel entry points and launch wrappers. Updated the pool-indexing logic to separate the read (h0_indices) and write (h0_out_indices) paths. Extended the run_pretranspose_decode() public API with output_state_indices and implemented resolution to convert it to int32 when use_pool_indexing is enabled.
  • Test Coverage — tests/gdn/test_decode_delta_rule.py: Added two parametrized tests (test_output_state_indices, test_output_state_indices_same_as_input) validating the separation of read and write pool indices, verifying correct state updates at write slots while preserving read slots and untouched pool entries.
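The index-resolution step described above for run_pretranspose_decode() can be sketched in plain Python (a toy model of the summarized behavior, not FlashInfer code; the real implementation operates on torch tensors):

```python
def resolve_h0_out_indices(h0_indices, output_state_indices, use_pool_indexing):
    """Toy model of the resolution summarized above: in pool mode an explicit
    output_state_indices supplies the write slots; otherwise writes default
    to the read slots (h0_indices)."""
    if use_pool_indexing and output_state_indices is not None:
        # Stands in for output_state_indices.to(torch.int32) in the real code.
        return [int(i) for i in output_state_indices]
    return h0_indices

reads = [3, 1, 0]
print(resolve_h0_out_indices(reads, None, True))       # [3, 1, 0]
print(resolve_h0_out_indices(reads, [2, 2, 5], True))  # [2, 2, 5]
```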

Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant API as gated_delta_rule_decode<br/>_pretranspose()
    participant Wrapper as run_pretranspose_decode()
    participant Kernel as CUDA Kernels<br/>(pretranspose/bf16)
    participant Pool as State Pool

    User->>API: initial_state (pool)<br/>initial_state_indices (read)<br/>output_state_indices (write)
    API->>API: Validate output_state_indices<br/>(pool mode, shape [B], int32/64)
    API->>Wrapper: Forward all indices
    Wrapper->>Wrapper: Resolve h0_out_indices<br/>(default to read indices)
    Wrapper->>Kernel: Launch with h0_indices (read)<br/>h0_out_indices (write)
    Kernel->>Pool: Read states from<br/>Pool[h0_indices[b]]
    Kernel->>Kernel: Compute updates via<br/>gated delta rule
    Kernel->>Pool: Write updated states to<br/>Pool[h0_out_indices[b]]
    Pool-->>User: Updated state output
```

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • bkryu
  • kahyunnam

Poem

🐰 Hop, hop, hop! The indices now take separate paths,
Read from here, write to there—no more taking baths!
State pools rejoice, their slots redirect with grace,
A rabbit's gift to kernels: different read-write space! 🌟

🚥 Pre-merge checks — ✅ 3 passed

  • Title check — ✅ Passed: The pull request title accurately and specifically describes the main change: introducing separate output pool indices for the gated_delta_rule_decode_pretranspose function.
  • Description check — ✅ Passed: The pull request description includes the required template sections with key information: a clear description of changes, a reference to related issue #2873, a completed pre-commit checklist, and added tests, though full test-suite status is noted as incomplete.
  • Docstring Coverage — ✅ Passed: No functions were found in the changed files to evaluate docstring coverage, so the check was skipped.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the output_state_indices parameter to the Gated Delta Rule decode kernels, enabling the updated state to be written to a different pool slot than the one read from. The changes include updates to the high-level Python API, the underlying CUDA kernels in gdn_decode_bf16_state.py and gdn_decode_pretranspose.py, and the addition of verification tests. Review feedback recommends grouping index reads within the kernels to improve consistency and instruction-level parallelism, as well as simplifying redundant logic in the pretranspose runner.

Comment on lines 754 to +759

```python
pool_batch_idx = gH_slot_indices[batch_idx]
if pool_batch_idx < 0:
    pool_batch_idx = cutlass.Int32(0)
write_pool_batch_idx = gH_out_slot_indices[batch_idx]
if write_pool_batch_idx < 0:
    write_pool_batch_idx = cutlass.Int32(0)
```

Severity: medium

This logic for handling negative indices is duplicated in gated_delta_rule_decode_kernel_seqlen234_unified and gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk. To improve consistency and potentially instruction-level parallelism, consider grouping the reads together before the checks, as done in gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk.

Additionally, to reduce code duplication across kernels, you could introduce a cute.jit helper function at the module level to handle this pattern.

Suggested change

Before:

```python
pool_batch_idx = gH_slot_indices[batch_idx]
if pool_batch_idx < 0:
    pool_batch_idx = cutlass.Int32(0)
write_pool_batch_idx = gH_out_slot_indices[batch_idx]
if write_pool_batch_idx < 0:
    write_pool_batch_idx = cutlass.Int32(0)
```

After:

```python
pool_batch_idx = gH_slot_indices[batch_idx]
write_pool_batch_idx = gH_out_slot_indices[batch_idx]
if pool_batch_idx < 0:
    pool_batch_idx = cutlass.Int32(0)
if write_pool_batch_idx < 0:
    write_pool_batch_idx = cutlass.Int32(0)
```
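The proposed shared helper could look roughly like this pure-Python sketch (in the actual kernels it would be a cute.jit function operating on cutlass.Int32 values, which this toy version does not attempt to model):

```python
def clamp_negative(idx):
    """Sketch of the suggested module-level helper: map any negative
    (padding) slot index to 0 so the kernel touches a valid slot."""
    return 0 if idx < 0 else idx

# Grouped reads first, then clamping, mirroring the suggested ordering:
slot_indices = [-1, 5, 2]
out_slot_indices = [3, -1, 2]
read_idx = [clamp_negative(i) for i in slot_indices]       # [0, 5, 2]
write_idx = [clamp_negative(i) for i in out_slot_indices]  # [3, 0, 2]
```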

Comment on lines +976 to +979

```python
if use_pool_indexing and output_state_indices is not None:
    h0_out_indices = output_state_indices.to(torch.int32)
else:
    h0_out_indices = h0_indices
```

Severity: medium

The use_pool_indexing check here is redundant. The public API gated_delta_rule_decode_pretranspose already asserts that output_state_indices can only be provided when use_pool_indexing is true.

You can simplify this logic for better readability.

Suggested change

Before:

```python
if use_pool_indexing and output_state_indices is not None:
    h0_out_indices = output_state_indices.to(torch.int32)
else:
    h0_out_indices = h0_indices
```

After:

```python
if output_state_indices is not None:
    h0_out_indices = output_state_indices.to(torch.int32)
else:
    h0_out_indices = h0_indices
```


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 748e9c86-0f49-4322-be15-919a3f8c3f91

📥 Commits

Reviewing files that changed from the base of the PR and between 904fa8c and 93687a1.

📒 Files selected for processing (4)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py
  • tests/gdn/test_decode_delta_rule.py

Comment on lines +195 to +206

```python
if output_state_indices is not None:
    assert use_pool, (
        "output_state_indices can only be used with initial_state (pool mode)"
    )
    assert output_state_indices.shape == (B,), (
        f"Expected output_state_indices shape [{B}], "
        f"got {output_state_indices.shape}"
    )
    assert output_state_indices.dtype in (torch.int32, torch.int64), (
        f"output_state_indices must be int32 or int64, "
        f"got {output_state_indices.dtype}"
    )
```

⚠️ Potential issue | 🔴 Critical

Reject in-place remaps that alias another batch item's source slot.

output_state_indices still writes back into the same initial_state buffer during the same kernel launch. If two batch items target the same write slot, or one item writes a slot another item is still reading via initial_state_indices, the final state becomes CTA-order dependent and no longer matches gather→compute→scatter semantics. Please either validate a safe mapping here or route overlapping remaps through a staged fallback.
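The staged fallback mentioned here can be illustrated with a pure-Python toy model (illustrative only; decode_inplace and decode_staged are hypothetical names, and the real hazard involves concurrent CTAs rather than a sequential Python loop):

```python
def decode_inplace(pool, reads, writes, f):
    """Hazardous variant: a later item may read a slot that an earlier
    item already overwrote, so the result depends on iteration order."""
    for r, w in zip(reads, writes):
        pool[w] = f(pool[r])

def decode_staged(pool, reads, writes, f):
    """Staged gather -> compute -> scatter: every read completes before
    any write, so read/write slot aliasing across items is harmless."""
    staged = [f(pool[r]) for r in reads]  # gather sources + compute into temp
    for w, v in zip(writes, staged):      # scatter results
        pool[w] = v

f = lambda s: s * 10
# Item 1 reads slot 1 while item 0 writes slot 1 (a swap-like remap):
pool_a = [1.0, 2.0]
decode_inplace(pool_a, reads=[0, 1], writes=[1, 0], f=f)
pool_b = [1.0, 2.0]
decode_staged(pool_b, reads=[0, 1], writes=[1, 0], f=f)
print(pool_a)  # [100.0, 10.0] -- item 1 saw item 0's freshly written value
print(pool_b)  # [20.0, 10.0]  -- matches gather->compute->scatter semantics
```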

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 195 - 206, The output_state_indices
path currently allows in-place remaps that can alias other batch items' source
slots, making final state CTA-order dependent; in the block that checks
output_state_indices (and uses use_pool and initial_state /
initial_state_indices), validate that output_state_indices contains no duplicate
targets and that none of its target indices overlap any indices in
initial_state_indices (or raise a clear error); alternatively implement a staged
fallback: allocate a temporary buffer, gather sources into temp using
initial_state_indices, perform compute, then scatter results from temp to
initial_state using output_state_indices to avoid read/write races. Ensure
checks/reference to output_state_indices, initial_state, initial_state_indices
and use_pool are used so the change locates the remap logic.

⚠️ Potential issue | 🔴 Critical

Validate output_state_indices against the pool before dispatch.

The new arg is only shape/dtype-checked. A CPU tensor here will fail late, and a negative or >= pool_size write index can either become an out-of-bounds store on the float32 pretranspose path or silently alias slot 0 on the bf16 path. Please reject non-local or out-of-range write indices here unless you want explicit write-side padding semantics.

💡 Suggested guard

```diff
     if output_state_indices is not None:
         assert use_pool, (
             "output_state_indices can only be used with initial_state (pool mode)"
         )
         assert output_state_indices.shape == (B,), (
             f"Expected output_state_indices shape [{B}], "
             f"got {output_state_indices.shape}"
         )
         assert output_state_indices.dtype in (torch.int32, torch.int64), (
             f"output_state_indices must be int32 or int64, "
             f"got {output_state_indices.dtype}"
         )
+        assert output_state_indices.device == initial_state.device, (
+            "output_state_indices must be on the same device as initial_state"
+        )
+        pool_size = int(initial_state.shape[0])
+        in_range = (output_state_indices >= 0) & (output_state_indices < pool_size)
+        assert in_range.all().item(), (
+            f"output_state_indices must be in [0, {pool_size})"
+        )
```
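A minimal pure-Python sketch of the suggested early range check (validate_write_indices is a hypothetical name; the real guard would operate on torch tensors as in the suggested diff):

```python
def validate_write_indices(indices, pool_size):
    """Fail fast if any write index falls outside [0, pool_size)."""
    bad = [i for i in indices if i < 0 or i >= pool_size]
    if bad:
        raise ValueError(
            f"output_state_indices must be in [0, {pool_size}), got {bad}"
        )
    return indices

validate_write_indices([0, 3, 1], pool_size=4)  # OK, returns the indices
try:
    validate_write_indices([0, 4], pool_size=4)
except ValueError as e:
    print(e)  # reports the out-of-range index 4
```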
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 195 - 206, The code currently only
checks shape/dtype of output_state_indices; add validation that
output_state_indices is on the same device as the pool (reject CPU/non-local
tensors) and that all values are within [0, pool_size-1] to prevent
out-of-bounds or aliasing when writing into the pool (when
use_pool/initial_state is active). In the gdn_decode logic where
output_state_indices is handled (the block that asserts use_pool and checks
shape/dtype), add checks for device equality to the pool tensor and use
torch.any((idx < 0) | (idx >= pool_size)) or equivalent to raise a clear
ValueError/Assertion if any index is out of range; keep references to
output_state_indices, use_pool, pool_size, and initial_state so the guard runs
early and fails fast.

Comment on lines +1756 to +1767

```python
# Read slots must be unchanged (we wrote to different slots)
torch.testing.assert_close(
    pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol
)

# Other slots must be unchanged
used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device)
used_mask[read_indices] = True
used_mask[write_indices] = True
torch.testing.assert_close(
    pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol
)
```

⚠️ Potential issue | 🟡 Minor

Use exact equality for slots that must stay untouched.

These assertions are checking for no mutation, not numerical closeness. Keeping atol/rtol=1e-3 can hide a small stray write, so the read slots and the untouched remainder should be compared with zero tolerance.

💡 Tighten the unchanged-slot assertions

```diff
     # Read slots must be unchanged (we wrote to different slots)
     torch.testing.assert_close(
-        pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol
+        pool_under_test[read_indices],
+        pool_orig[read_indices],
+        atol=0.0,
+        rtol=0.0,
     )

     # Other slots must be unchanged
     used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device)
     used_mask[read_indices] = True
     used_mask[write_indices] = True
     torch.testing.assert_close(
-        pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol
+        pool_under_test[~used_mask],
+        pool_orig[~used_mask],
+        atol=0.0,
+        rtol=0.0,
     )
```
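A tiny standalone illustration of why a nonzero tolerance can mask a stray write (pure Python, with math.isclose standing in for assert_close semantics):

```python
import math

orig = 1.0
mutated = orig + 5e-4  # a stray write smaller than the test's atol=1e-3

# A close-comparison with the test's tolerances passes...
passes_close = math.isclose(mutated, orig, rel_tol=1e-3, abs_tol=1e-3)
# ...while an exact check catches the mutation.
passes_exact = (mutated == orig)
print(passes_close, passes_exact)  # True False
```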
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 1756 - 1767, The tests
currently use torch.testing.assert_close to check that
pool_under_test[read_indices] and pool_under_test[~used_mask] match pool_orig
with nonzero atol/rtol; change these to exact-equality checks (e.g., use
torch.equal or torch.testing.assert_close(..., atol=0, rtol=0)) for the two
assertions involving pool_under_test, pool_orig, read_indices and the computed
used_mask/write_indices so any stray mutation is caught.
