Skip to content

feat(gemma4): feature-complete backend with DFlash + MTP + sparse-FA decode (supersedes PR #175 skeleton)#193

Draft
dusterbloom wants to merge 4 commits into
Luce-Org:mainfrom
dusterbloom:feature/gemma4-feature-complete
Draft

feat(gemma4): feature-complete backend with DFlash + MTP + sparse-FA decode (supersedes PR #175 skeleton)#193
dusterbloom wants to merge 4 commits into
Luce-Org:mainfrom
dusterbloom:feature/gemma4-feature-complete

Conversation

@dusterbloom
Copy link
Copy Markdown
Contributor

Supersedes #171, #176, #179, #184, #185 — those were split against pre-#175 main. After #175 merged howard0su's gemma4/ skeleton into main, the split chain was overtaken by upstream work. This PR re-cuts against current main as a single cohesive replacement: howard0su's 341-LoC AR-only Gemma4Backend skeleton → our feature-complete impl (DFlash + MTP + sparse-FA decode + per-layer embedding + MoE expert metadata + asymmetric KV).

Validation (RTX 3090, Gemma4-26B-A4B-it Q4_K_M, apples-to-apples binary built from this branch's libdflash27b.a)

Cell Config Prefill (tok/s) Decode (tok/s) TTFT MTP accept VRAM
1 MoE MTP γ=1 + TQ3+sparse-FA @ 16k (8k-tok prompt) 2137 43.86 46 ms 0.56 18.54 GB
2 MoE MTP γ=1 + TQ3+sparse-FA @ 64k (50k-tok prompt) 1544 20.86 99 ms 0.52 19.03 GB

Byte-identical to upstream llama-cli AR Q4 @ temp=0 greedy on this same model (Q4_0 KV is lossless; our forward implements the same Gemma4 architecture; identical argmax).

Slop-fix parity: cell 1 on pre-slop binary (May-13 feature/gemma4-support tip) = 42.85 tok/s decode vs. 43.86 tok/s on this branch's post-slop binary — within run-to-run noise. Confirms S1–S10 cleanups are zero-perf-cost.

Scope by layer

Core (target forward + AR decode + loader)

  • gemma4/gemma4_loader.cpp — Per-layer-embedding (PLE) table, MoE expert metadata, asymmetric KV override for MTP donor layers
  • gemma4/gemma4_graph.cpp — iSWA target forward in pure ggml, per-layer-embedding decode path
  • gemma4/gemma4_backend.cppdecode_autoregressive, prefill, snapshot save/restore
  • gemma4/gemma4_daemon.{h,cpp} — protocol wrapper, mirrors run_laguna_daemon
  • gemma4/gemma4_internal.h — 17-line passthrough to ../internal.h (struct defs sit alongside other archs)
  • Replaces howard's Gemma4Weights/Gemma4Cache names with our existing GemmaTargetWeights/GemmaTargetCache family (canonical across all archs in this tree)

TQ3 KV (TurboQuant)

  • Ring-buffer + SWA boundary correctness for long contexts
  • Asymmetric KV override: when MTP loads, donor layer K/V forced to Q8 (TQ3 unreadable by MTP cross-attn)

Sparse-FA decode dispatch (commit 2 — refactor: rename use_pflash → use_sparse_fa)

  • use_sparse_fa flag + F5 v2 gate at gemma4_graph.cpp can_sparse_fa predicate
  • Decode-only (n_tokens == 1) because CUDA dispatcher at deps/llama.cpp/ggml/src/ggml-cuda/fattn.cu:572-576 has no sparse-FA kernel for head_dim=512 + mask + S>1 (BEST_FATTN_KERNEL_NONE abort)
  • Not pFlash — see naming clarification below

DFlash speculative decode

  • gemma4_dflash_graph.cpp (top-level — separate model, the drafter)
  • gemma4_backend.cpp::decode_dflash (425 LoC inner loop, DDTree verify)
  • target_feat capture in target graph (hidden states fed to draft KV prefill)
  • Draft KV prefill helper (run_draft_kv_prefill)

MTP γ=1 (Multi-Token Prediction)

  • gemma4_mtp_graph.cpp (top-level — separate head, the MTP cross-attn)
  • gemma4_backend.cpp::decode_mtp
  • mtp_h_prev capture (post-output-norm hidden, fed to MTP cross-attn)
  • Asymmetric KV override activation

Naming clarification (commit 2)

The previous internal naming used use_pflash / pflash_alpha / s_pflash_tq3 / DFLASH_PFLASH_TQ3 to refer to the ggml_flash_attn_sparse op. This collides with the PFlash product (Python speculative-prefill compression module at pflash/, which calls pflash_register_ggml_kernel() to register its BSA forward as the implementation of ggml_flash_attn_sparse).

Commit 2 renames the Gemma4-scoped flag to use_sparse_fa etc. PFlash product files (pflash_ggml_adapter.{cpp,h}, pflash/ dir, README/RESULTS prose) are untouched — those are correctly named.

A real pFlash (= prefill flash, a D=512 sparse-FA-with-mask CUDA kernel) is future work, separate from this PR, blocked on a D=512 kernel that doesn't exist yet in upstream ggml-cuda.

Out of scope (future PRs)

  • MTP γ=2 — needed for >1× chain speculative decode (1.5-2.5 days estimated)
  • Real pFlash D=512 sparse-FA-with-mask kernel — would widen the F5 v2 gate to allow prefill dispatch (20-32 senior-CUDA-days estimated per a private kernel-author scoping)
  • Bench harness (test_gemma4_dflash.cpp) — currently lives only on feature/gemma4-support branch; will be brought into dflash/test/gemma4/ as a follow-up

Commits

  1. b0b4b7e feat(gemma4): feature-complete backend with DFlash + MTP + sparse-FA decode (supersedes PR #175 skeleton) — 22 files, +7887/-1081
  2. 9b311d0 refactor(gemma4): rename use_pflash → use_sparse_fa (clarify ggml_flash_attn_sparse is not the PFlash product) — 8 files, +46/-46

Closes

#171 (target loader), #176 (target graph), #179 (sparse-FA integration), #184 (backend), #185 (daemon routing) — all superseded by this PR.

…decode (supersedes PR Luce-Org#175 skeleton)

Replaces howard0su's PR Luce-Org#175 Gemma4 skeleton with a feature-complete
implementation. Howard's PR landed an AR-only backend (341 LoC) sufficient
to validate the daemon protocol; this PR brings the production paths.

What changes
------------
gemma4/gemma4_backend.{cpp,h}  341+91   → 1182+155
  - decode_autoregressive (AR)
  - decode_dflash         (speculative decode with DDTree drafter)
  - decode_mtp            (Multi-Token Prediction, γ=1)
  - snapshot_save / snapshot_restore
  - park / unpark for speculative draft

gemma4/gemma4_graph.cpp        448 → 1218
  - iSWA target forward in pure ggml
  - target_feat capture (hidden states for DFlash draft KV prefill)
  - mtp_h_prev capture   (post-output-norm hidden for MTP cross-attn)
  - sparse-FA dispatch via ggml_flash_attn_sparse (full-attn layers, decode-only)
  - F5 v2 gate (n_tokens==1) — CUDA dispatcher has no sparse kernel
    for head_dim=512+mask+S>1 (fattn.cu:572-576 abort)
  - Per-layer embedding decode handling

gemma4/gemma4_loader.cpp       370 → 1197
  - Per-layer-embedding (PLE) table loading
  - MoE expert metadata (top-k routing, expert_count, shared_exp)
  - Asymmetric KV override (MTP donor layers forced to Q8)

gemma4/gemma4_daemon.{cpp,h}    29+20 → 43+46
  - Extended Gemma4BackendConfig (draft_method, mtp_gamma,
    sparse_fa_alpha, draft_kv_cap_override, max_ctx)

gemma4/gemma4_internal.h       184 → 17  (passthrough to ../internal.h)
  - Internal struct definitions consolidated into shared ../internal.h
    where they sit alongside Qwen3/Laguna/Qwen35 struct families
    (GemmaTargetWeights / GemmaTargetCache / MtpLayerWeights / SwaView /
    Gemma4GraphInputs / Gemma4GraphOutputs / GemmaDraftLayer /
    GemmaDraftWeights / MtpDrafterWeights / MtpStepGraph)
  - Howard's Gemma4Weights / Gemma4Cache / Gemma4Snapshot struct names
    are NOT carried — our pre-existing GemmaTarget* family is the
    canonical naming across all archs in this tree

New files
---------
gemma4/gemma4_runtime_helpers.{cpp,h}  shared graph builders + masks
gemma4_dflash_graph.cpp                drafter step graph (top-level — separate model)
gemma4_mtp_graph.cpp                   MTP step graph (top-level — separate head)
include/gemma4.h                       public C-ish API surface

Tests
-----
test/gemma4/smoke_load_gemma4_target.cpp   loader smoke (Dense 31B Q4_K_M)
test/gemma4/smoke_load_gemma4_draft.cpp    drafter smoke
test/gemma4/smoke_gemma4_target_forward.cpp  AR forward smoke
test/gemma4/smoke_gemma4_draft_forward.cpp   drafter forward smoke
test/gemma4/test_gemma4_kv_tq3.cpp         TQ3_0 KV cache correctness at 16k+
test/gemma4/test_mtp_loader.cpp            MTP assistant GGUF loader
test/gemma4/test_mtp_graph_shapes.cpp      MTP step graph shape invariants

Validated configurations (RTX 3090, Gemma4-31B-it Q4_K_M)
--------------------------------------------------------
  AR + Q4 KV @ 16k:     678 pp / 22.6 tg tok/s  (+80% / +11% vs upstream)
  AR + TQ3 KV @ 16k:    739 pp / 16.6 tg
  MTP γ=1 + Q4 @ 16k:   691 pp / 21.4 tg
  Byte-identical to upstream llama-cli AR Q4 @ temp=0 greedy

Naming clarification
--------------------
The `use_pflash` / `pflash_alpha` / `s_pflash_tq3` / `DFLASH_PFLASH_TQ3`
identifiers in this commit refer to the ggml_flash_attn_sparse op, NOT
the PFlash product (Python speculative-prefill compression module at
pflash/). A follow-up commit on this branch renames them to use_sparse_fa
etc. to remove the ambiguity.

Closes (when merged): Luce-Org#171, Luce-Org#176, Luce-Org#179, Luce-Org#184, Luce-Org#185 — those were the
original split-PR series against pre-Luce-Org#175 main; reworked into this single
PR after Luce-Org#175 brought the gemma4/ skeleton into main.
…sh_attn_sparse is not the PFlash product)

The previous commit's Gemma4 backend used `use_pflash` / `pflash_alpha`
/ `s_pflash_tq3` / `pflash_supports` / `can_pflash` / `DFLASH_PFLASH_TQ3`
identifiers and "pFlash" prose to refer to the `ggml_flash_attn_sparse`
op. This created a name collision with the **PFlash** product
(`pflash/` Python module — speculative-prefill compression in front of
DFlash, shipped on main).

The two are different things:

- **PFlash** (the product) — Python module at `pflash/`, drafter +
  scoring + select pipeline for compressing the prompt before it
  reaches the target. Registers its BSA forward kernel as the
  implementation of `ggml_flash_attn_sparse` via
  `pflash_register_ggml_kernel()` (`pflash_ggml_adapter.cpp:31`).

- **sparse FA** — the ggml op `ggml_flash_attn_sparse`. Maskless flash
  attention kernel that any backend can dispatch. Decode-only for
  Gemma4's head_dim=512 because the CUDA dispatcher has no sparse
  kernel for `head_dim=512 + mask + n_tokens>1`
  (`fattn.cu:572-576 BEST_FATTN_KERNEL_NONE`).

This commit renames all Gemma4-scoped identifiers and comments so a
reader can tell which one our backend code is talking about. PFlash
product files (`pflash_ggml_adapter.{cpp,h}`, `pflash_daemon`,
`bench_laguna_pflash`, the `pflash/` directory, README/RESULTS prose)
are UNCHANGED — those are correctly named.

Rename map
----------
  use_pflash          → use_sparse_fa
  pflash_alpha        → sparse_fa_alpha
  s_pflash_tq3        → s_sparse_fa_tq3
  pflash_supports     → sparse_fa_supports
  can_pflash          → can_sparse_fa
  DFLASH_PFLASH_TQ3   → DFLASH_SPARSE_FA_TQ3
  "pFlash sparse path" / "pFlash sparse-FA dispatch" / etc.
                      → "sparse-FA path" / "sparse-FA dispatch"

Real pFlash (= prefill flash, the D=512 sparse-FA-with-mask CUDA
kernel) is future work — tracked separately as the "D=512 sparse-FA
kernel" plan derived from the F5 v2 gate analysis. That work will
introduce a properly-named `pflash_*` API distinct from the
`use_sparse_fa` dispatch flag.

Scope
-----
Touches only this PR's Gemma4 content. No changes to PFlash product
files or to other backends (Qwen3, Laguna, Qwen35).
Adds best-first tree-structured speculative verification (DDTree) to the
Gemma4 DFlash backend, alongside the chain-verify path. budget=0 keeps
chain semantics; budget>0 activates tree verify.

Backend additions:
- Gemma4BackendConfig / Gemma4DaemonArgs: ddtree_budget, ddtree_temp,
  ddtree_chain_seed, ignore_eos fields plumbed daemon -> backend.
- StepGraph.argmax_tokens [N] i32 output tensor for in-graph per-slot
  argmax (avoids vocab*N logits round-trip).
- build_gemma4_step_tree(): tree-aware target forward with custom
  visibility mask + per-slot positions (siblings share parent position,
  do not see each other). Padded to n_max = budget+1 for gallocr reuse.
- decode_dflash ddtree_enabled branch: extract_draft_topk -> build_ddtree
  -> tree-verify forward -> follow_verified_tree. Persistent tree_sg
  (allocator + graph) reused across iterations.
- Inject target's tok_embd into draft weights post-load (draft
  safetensors ship without tok_embd; tied LM head per
  gemma4_dflash_graph.cpp:777).

Bug fixes (Codex review):
- compact_gemma4_tree_path: cudaMemcpyAsync(null stream) -> synchronous
  cudaMemcpy. Subsequent graph runs use ggml's non-blocking backend
  streams which are not ordered after the null-stream copies, risking
  stale reads on sibling acceptance.
- [spec] stats line: stdout -> stderr. The daemon protocol owns stdout
  (line-oriented ok/err); the stray stats line corrupted client parsing
  for generate / bare-prompt requests.
- DDTree truncation state: when commit_n < accept_depth (early stop
  from in-loop EOS or n_gen cap), cur_tok now uses the verified target
  argmax at the actually-committed leaf instead of next_token (which
  was computed for the position beyond the full accept_depth). Matches
  chain-path semantics. Affects daemon mid-stream cancel + snapshot-
  restore.

Perf cleanups:
- OpenMP linked into dflash27b. The #pragma omp parallel for in
  ddtree.cpp extract_draft_topk was a no-op before -- only test_dflash
  was wired with OpenMP. Now the 262K-vocab top-K scan parallelizes
  across L=draft positions.
- build_moe_ffn: replace cont -> gelu -> mul with ggml_geglu_split,
  matching upstream llama.cpp's gate/up activation path. Removes two
  explicit ggml_cont memory copies on the MoE FFN hot path. Paired
  bench (10 pairs x 3 prompts, RTX 3090 @ 350W, kv_q8_0):
    MoE AR med_delta = +3.55 tok/s
    positive in 28/30 paired runs (sign-test p ~ 1e-6)

CLI / drivers:
- test_dflash gains --budget flag and DFLASH_GEMMA4_DDTREE_{BUDGET,TEMP,
  CHAIN_SEED} env vars piped through the gemma4 dispatch.
- New thin driver test_gemma4_dflash_backend.cpp (~180 LoC) exercises
  Gemma4Backend directly.
- New fat driver test_gemma4_dflash.cpp (~3.3k LoC) inline AR + spec
  decode loop with prefill timing the thin driver lacks.
- CMakeLists builds both as test_gemma4_dflash and
  test_gemma4_dflash_fat respectively.

Validated on RTX 3090 @ 350W (unlocked from 203W default), KV q4_0,
sparse-FA on, n_predict=128, 5 HE-style prompts x 3 runs each:
- MoE Gemma4-26B-A4B AR baseline (fat driver no-draft, kv_q8_0):
  73.56 tok/s med, public llama.cpp range 80-110 -- gap closing.
- Dense Gemma4-31B AR baseline (kv_q8_0): 34.12 tok/s med, public
  range 30-34 -- dead center.
- MoE chain spec dm=16: 94.44 med over AR baseline 73.56 = 1.28x.
- Dense DDTree b=22 dm=11: 45.13 med over AR baseline 34.12 = 1.32x.

Notes:
- --pflash gate is "prompts >= 4096 tokens"; in short-prompt decode
  benches sparse-FA is inactive, so reported speedup is pure spec
  decode + perf cleanups.
- Output tokens are not bit-identical across (budget, dm) cells at
  temp=0 -- different verify batch shapes break BF16/Q4 argmax ties
  differently. Spec-decode is algorithmically lossless; semantically
  equivalent text across configs.

Follow-ups (not in this commit):
- GPU top-K via ggml_top_k + ggml_log_softmax + ggml_get_rows to
  eliminate vocab D2H per spec step
- Persist gallocr in draft_step_build + draft_kv_prefill_create
- pflash sparse-FA kernel correctness audit (q_len=1, kv_len=1024)
- Adaptive draft-block tuning for MoE (separate commit)
Enables the existing AdaptiveDraftMax rolling-window estimator on MoE
target weights only (n_layer <= 40 heuristic — Gemma4-26B-A4B has 30
layers, Dense-31B has 61). When fill ratio drops <0.35 the draft block
halves; when >0.78 it doubles. Floor min_q=8 prevents the one-way
ratchet from collapsing to dm=1 on prompts where the drafter has a
bad early window (unfloored adaptive lost 30-60% in the sanity sweep).

Dense path is intentionally left at fixed dm=16: empirical sweep showed
adaptive regresses 30-60% on hard-to-draft Dense prompts (chain Dense
AR p3 collapsed from 32 to 12-18 tok/s at any floor) because Dense
weight reads amortize across the full block regardless of acceptance.

DFLASH_NO_ADAPTIVE=1 env var disables the heuristic so the same binary
can produce adapt-ON and adapt-OFF measurements in a single session
(prevents cross-binary thermal/build artifacts).

Bench (RTX 3090 @ 350W, KV q4_0, sparse-FA on, 5 HE-style prompts x 3
runs, n_predict=128, all numbers vs AR baseline from fat driver no-draft
at kv_q8_0 = 73.56 tok/s median for MoE):
- MoE chain dm=16 NO_ADAPTIVE=1: 94.44 tok/s med  = 1.28x over AR
- MoE chain dm=16 adapt+floor=8 : 122.03 tok/s med = 1.66x over AR
- MoE DDTree b=16 dm=8 + adapt  : 102.29 tok/s med = 1.39x over AR

Adaptive alone beats DDTree+adaptive on MoE in this bench — DDTree's
per-step overhead doesn't pay back when adaptive is already capturing
the prompt-level acceptance variance. DDTree remains useful on Dense
where adaptive is off (Dense DDTree b=22 dm=11: 45.13 = 1.32x over AR).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant