System Info
- CPU architecture: x86_64
- GPU name: NVIDIA RTX PRO 4500 Blackwell
- GPU memory size: 32623 MiB
- NVIDIA driver version: 580.82.07
- NVIDIA-SMI version: 580.82.07
- CUDA version reported by nvidia-smi: 13.0
- Python: 3.12.11
- OS / container base: nvidia/cuda:13.0.0-cudnn-runtime-ubuntu24.04
- PyTorch: 2.9.1
- TensorRT-LLM versions tested:
  - 1.3.0rc11: engine build + inference both reproduced the issue
  - 1.2.0: issue still reproduced after downgrading to stable
- Model: google/gemma-3-1b-it
- Runtime path: Python LLM(...) API with async generation
- Engine path: TensorRT engine (bfloat16_trt_llm_1_2_0)
- Batching type: INFLIGHT
- Relevant runtime config: enable_block_reuse=True, max_seq_len=5500, max_attention_window=[512, 512, 512, 512, 512, 5500, 512, 512, 512, 512, 512, 5500, 512, 512, 512, 512, 512, 5500, 512, 512, 512, 512, 512, 5500, 512, 512]
- GPU memory usage observed during failure: 30495 MiB / 32623 MiB
Who can help?
@laikhtewari
Information
Tasks
Reproduction
We are serving Gemma 3 1B through the TensorRT engine path using the Python LLM(...) API with async generation.
The crash is reproducible when all of the following are true:
- batching_type="INFLIGHT"
- KvCacheConfig(enable_block_reuse=True, max_attention_window=<VSWA list>)
- Gemma 3 engine built with max_seq_len=5500
The issue disappears when enable_block_reuse=False.
The VSWA list we use at runtime is a single-line list:
[512, 512, 512, 512, 512, 5500, 512, 512, 512, 512, 512, 5500, 512, 512, 512, 512, 512, 5500, 512, 512, 512, 512, 512, 5500, 512, 512]
This is a 26-layer VSWA configuration: layers 6, 12, 18, and 24 use the full window of 5500, and every other layer, including the final two, uses the local window of 512.
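For reference, the same 26-entry list can be generated programmatically (a sketch; the constants are just the values described above, and the names are ours):

```python
# Build the 26-layer VSWA window list: every 6th layer (6, 12, 18, 24,
# 1-indexed) uses the full window; all other layers use the local window.
LOCAL_WINDOW = 512
FULL_WINDOW = 5500
NUM_LAYERS = 26

attention_windows = [
    FULL_WINDOW if layer % 6 == 0 else LOCAL_WINDOW
    for layer in range(1, NUM_LAYERS + 1)
]
```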
Simplified runtime setup:

```python
from transformers import AutoTokenizer

from tensorrt_llm import SamplingParams
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm.llmapi import KvCacheConfig

tokenizer = AutoTokenizer.from_pretrained("/path/to/gemma-3-1b-it")

attention_windows = [512, 512, 512, 512, 512, 5500, 512, 512, 512, 512, 512, 5500,
                     512, 512, 512, 512, 512, 5500, 512, 512, 512, 512, 512, 5500,
                     512, 512]

llm = LLM(
    "/path/to/engine_dir",
    tokenizer=tokenizer,
    kv_cache_config=KvCacheConfig(
        enable_block_reuse=True,
        max_attention_window=attention_windows,
    ),
    batching_type="INFLIGHT",
    max_batch_size=<engine build_config.max_batch_size>,  # placeholder
    max_seq_len=5500,
)

sampling_params = SamplingParams(max_tokens=64)

# Called from an async context:
output = await llm.generate_async(
    "test prompt",
    sampling_params=sampling_params,
)
```
Expected behavior
- Either enable_block_reuse=True should work for this Gemma 3 VSWA configuration, or TensorRT-LLM should reject this configuration cleanly if it is unsupported.
- It should not hit an internal assertion in forwardAsync.
actual behavior
When enable_block_reuse=True, TensorRT-LLM crashes in inflight batching with:
Assertion failed: decoderFinishedEvent must be nullopt.
When enable_block_reuse=False, the exceptions disappear and the system is stable.
Observed failure:

```
Traceback (most recent call last):
  File "/gemma3_1b_service/.venv/lib/python3.12/site-packages/tensorrt_llm/llmapi/utils.py", line 40, in wrapper
    return func(*args, **kwargs)
  File "/root/.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "/gemma3_1b_service/.venv/lib/python3.12/site-packages/tensorrt_llm/executor/result.py", line 474, in _handle_response
    handler(response.error_msg)
  File "/gemma3_1b_service/.venv/lib/python3.12/site-packages/tensorrt_llm/executor/executor.py", line 280, in _handle_background_error
    raise error
tensorrt_llm.executor.utils.RequestError: Encountered an error in forwardAsync function: [TensorRT-LLM][ERROR] Assertion failed: decoderFinishedEvent must be nullopt. (../tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp:1132)
1  0x7fef27d8197f tensorrt_llm::_v1::common::throwRuntimeError(char const*, int, char const*) + 76
2  0x7feee876a4a8 /gemma3_1b_service/.venv/lib/python3.12/site-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0x1b6a4a8) [0x7feee876a4a8]
3  0x7feee9ec51ac tensorrt_llm::executor::Executor::Impl::forwardAsync(...) + 444
4  0x7feee9ecca7e tensorrt_llm::executor::Executor::Impl::executionLoop() + 1550
```
After the first failure, similar errors continue to be emitted for subsequent requests.
additional notes
We reproduced the same issue on both tensorrt-llm==1.3.0rc11 and tensorrt-llm==1.2.0, so it is not specific to the release candidate.
The strongest signal so far is that the crash is specifically tied to KV cache block reuse. Disabling block reuse makes the service stable without other changes.
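For completeness, this is the only change in the configuration we currently run stably (a minimal sketch, assuming the same attention_windows list shown in the repro above):

```python
from tensorrt_llm.llmapi import KvCacheConfig

# Workaround: identical runtime config, but with block reuse disabled.
kv_cache_config = KvCacheConfig(
    enable_block_reuse=False,  # True triggers the forwardAsync assertion
    max_attention_window=attention_windows,  # same 26-entry VSWA list as above
)
```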
GPU memory usage was high at the time of failure (30495 MiB / 32623 MiB), but the observed error is not an OOM; it is an internal assertion in the executor path.
This may be related to existing VSWA / KV-cache-reuse issues on the TRT engine path, but our case is a hard runtime failure rather than ineffective caching:
- #12563: Prefix caching (block reuse) ineffective for VSWA models like Gemma 3 on TRT engine path
- #2983: disable kv cache reuse when minimum window size is reached, instead of maximum window size
Before submitting a new issue...