24 commits
513659f
Add initial inference server
yachty66 Apr 9, 2026
9691cd9
Add custom inference engine (GatedDeltaNet + GQA + MoE)
yachty66 Apr 9, 2026
db39fd7
Fix q_norm/k_norm: use RMSNorm instead of LayerNorm (no bias)
yachty66 Apr 9, 2026
54503ef
Fix linear attn einsum: keep val-sub-head dim in outer product
yachty66 Apr 9, 2026
0e9c54c
Fix dtype mismatch: cast softmax output back to bfloat16
yachty66 Apr 9, 2026
417dd36
Fix generate: return [1, num_generated] via stack not cat
yachty66 Apr 9, 2026
c057806
Fix causal conv1d (manual left-pad) and add Q normalization
yachty66 Apr 9, 2026
babeb0b
Fix conv_buf: save from padded input not raw qkv_t
yachty66 Apr 9, 2026
1e8eddc
Add debug comparison script (HF vs our model)
yachty66 Apr 9, 2026
e84abfd
Add layer-by-layer debug comparison script
yachty66 Apr 9, 2026
7cb8c42
Add source inspection script
yachty66 Apr 9, 2026
0687f71
Add source2 inspection script
yachty66 Apr 9, 2026
244e34b
Fix all bugs found from source: RMSNorm formula, conv silu, Q/K expan…
yachty66 Apr 9, 2026
2f08440
Add layer 0 sublayer debug script
yachty66 Apr 9, 2026
f3d656d
Add GatedDeltaNet internals debug script
yachty66 Apr 9, 2026
61754ea
Fix syntax error in debug_deltanet.py
yachty66 Apr 9, 2026
da80400
Fix delta rule: decay state first, then compute delta from decayed state
yachty66 Apr 9, 2026
5bcc748
Fix debug script delta rule order to match fix
yachty66 Apr 9, 2026
4b42aed
Dump HF chunk_gated_delta_rule torch fallback source
yachty66 Apr 9, 2026
9cd1d28
Fix delta rule: scale query by 1/sqrt(d_k) before recurrence
yachty66 Apr 9, 2026
5aee84e
Stop generation on <|im_end|> as well as <|endoftext|>
yachty66 Apr 9, 2026
6722f73
Use FLA chunk kernel for prefill + data-parallel across all GPUs
yachty66 Apr 9, 2026
cae35f9
Cast FLA kernel outputs back to input dtype
yachty66 Apr 9, 2026
647ccbb
Load model replicas across GPUs in parallel
yachty66 Apr 9, 2026
64 changes: 64 additions & 0 deletions debug_compare.py
@@ -0,0 +1,64 @@
"""
Compare our custom model output against HF reference layer by layer.
Run with: python debug_compare.py
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "/mnt/data/GPUINFERENCE/Qwen3.5-35B-A3B"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is 2+2?"}],
    tokenize=False, add_generation_prompt=True, enable_thinking=False,
)
input_ids = tokenizer(text, return_tensors="pt").input_ids.cuda()
print(f"Input ids shape: {input_ids.shape}, tokens: {input_ids[0].tolist()}")

# ── HF reference ──────────────────────────────────────────────────────────────
print("\nLoading HF reference model...")
hf = AutoModelForCausalLM.from_pretrained(MODEL_PATH, dtype=torch.bfloat16, device_map="cuda")
hf.eval()

with torch.no_grad():
    hf_out = hf(input_ids)
    hf_logits = hf_out.logits[:, -1, :]  # [1, V]

hf_top = hf_logits.topk(5)
print(f"HF top-5 token ids : {hf_top.indices[0].tolist()}")
print(f"HF top-5 tokens : {[tokenizer.decode([i]) for i in hf_top.indices[0].tolist()]}")
print(f"HF has NaN: {hf_logits.isnan().any().item()}")

del hf
torch.cuda.empty_cache()

# ── Our model ─────────────────────────────────────────────────────────────────
print("\nLoading our model...")
from server.model.loader import load_model
from server.model.rope import build_rope_freqs

our = load_model(MODEL_PATH, device="cuda")

T = input_ids.shape[1]
inv_freq = build_rope_freqs(T + 100, input_ids.device)
position_ids = torch.arange(T, device=input_ids.device).unsqueeze(0)

with torch.no_grad():
    our_logits, _, _, _ = our.forward_step(
        input_ids, position_ids, inv_freq,
        [None] * 40, [None] * 40, [None] * 40,
    )
    our_logits = our_logits[:, -1, :]

our_top = our_logits.topk(5)
print(f"Our top-5 token ids: {our_top.indices[0].tolist()}")
print(f"Our top-5 tokens : {[tokenizer.decode([i]) for i in our_top.indices[0].tolist()]}")
print(f"Our has NaN: {our_logits.isnan().any().item()}")
print(f"Our has Inf: {our_logits.isinf().any().item()}")

# ── Summary ───────────────────────────────────────────────────────────────────
max_diff = (hf_logits.float() - our_logits.float()).abs().max().item()
print(f"\nMax logit diff (HF vs ours): {max_diff:.4f}")
print(f"HF argmax: {hf_logits.argmax().item()} -> '{tokenizer.decode([hf_logits.argmax().item()])}'")
print(f"Our argmax: {our_logits.argmax().item()} -> '{tokenizer.decode([our_logits.argmax().item()])}'")
173 changes: 173 additions & 0 deletions debug_deltanet.py
@@ -0,0 +1,173 @@
"""
Deep dive into GatedDeltaNet layer 0 internals.
Hooks on HF model's linear_attn to compare intermediate tensors.
"""
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "/mnt/data/GPUINFERENCE/Qwen3.5-35B-A3B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is 2+2?"}],
    tokenize=False, add_generation_prompt=True, enable_thinking=False,
)
input_ids = tokenizer(text, return_tensors="pt").input_ids.cuda()
B, T = input_ids.shape

print(f"Input: {B}x{T} tokens")

# ── Monkey-patch HF GatedDeltaNet forward to capture intermediates ────────────
print("Loading HF model...")
hf = AutoModelForCausalLM.from_pretrained(MODEL_PATH, dtype=torch.bfloat16, device_map="cuda")
hf.eval()

hf_vals = {}
la = hf.model.layers[0].linear_attn # the GatedDeltaNet module

orig_forward = la.forward.__func__

def patched_forward(self, hidden_states, cache_params=None, attention_mask=None):
    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import apply_mask_to_padding_states
    hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
    batch_size, seq_len, _ = hidden_states.shape

    mixed_qkv = self.in_proj_qkv(hidden_states).transpose(1, 2)
    z = self.in_proj_z(hidden_states)
    b = self.in_proj_b(hidden_states)
    a = self.in_proj_a(hidden_states)

    # conv + silu
    mixed_qkv_conv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
    hf_vals['qkv_after_conv'] = mixed_qkv_conv.transpose(1, 2).detach().clone()

    mixed_qkv_conv = mixed_qkv_conv.transpose(1, 2)
    query, key, value = torch.split(mixed_qkv_conv, [self.key_dim, self.key_dim, self.value_dim], dim=-1)
    query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
    key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
    value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

    beta = b.sigmoid()
    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)

    hf_vals['beta'] = beta.detach().clone()
    hf_vals['g'] = g.detach().clone()
    hf_vals['z'] = z.detach().clone()

    if self.num_v_heads // self.num_k_heads > 1:
        query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
        key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

    hf_vals['q_expanded'] = query.detach().clone()
    hf_vals['k_expanded'] = key.detach().clone()
    hf_vals['v'] = value.detach().clone()

    core_out, _ = self.chunk_gated_delta_rule(
        query, key, value, g=g, beta=beta,
        initial_state=None, output_final_state=False,
        use_qk_l2norm_in_kernel=True,
    )
    hf_vals['core_out'] = core_out.detach().clone()

    core_out = core_out.reshape(-1, self.head_v_dim)
    z2 = z.reshape(batch_size, seq_len, -1, self.head_v_dim).reshape(-1, self.head_v_dim)
    core_out = self.norm(core_out, z2)
    hf_vals['normed_out'] = core_out.reshape(batch_size, seq_len, -1).detach().clone()

    return self.out_proj(core_out.reshape(batch_size, seq_len, -1))

import types
la.forward = types.MethodType(patched_forward, la)

with torch.no_grad():
    x_embed = hf.model.embed_tokens(input_ids)
    ln_out = hf.model.layers[0].input_layernorm(x_embed)
    la(ln_out)

del hf
torch.cuda.empty_cache()

# ── Our model ─────────────────────────────────────────────────────────────────
print("Loading our model...")
from server.model.loader import load_model
from server.model.rope import build_rope_freqs
import torch.nn.functional as F

our = load_model(MODEL_PATH, device="cuda")
our.eval()

inv_freq = build_rope_freqs(T + 100, input_ids.device)
position_ids = torch.arange(T, device=input_ids.device).unsqueeze(0)

la_our = our.layers[0].attn
from server.model.linear_attn import _CONV_DIM, _K_DIM, _V_DIM, _EXPAND, _EXP_HEADS, _HEAD_DIM, LINEAR_CONV_KERNEL

def diff(name, hf_t, our_t):
    hf_t = hf_t.float(); our_t = our_t.float()
    d = (hf_t - our_t).abs().max().item()
    rel = d / (hf_t.abs().mean().item() + 1e-8)
    print(f"  {name:<25} HF_norm={hf_t.norm():.3f} Our_norm={our_t.norm():.3f} "
          f"MaxDiff={d:.4f} Rel={rel:.4f} {'OK' if d < 0.1 else 'FAIL'}")

with torch.no_grad():
    x = our.embed_tokens(input_ids)
    h = our.layers[0].input_layernorm(x)
    k = LINEAR_CONV_KERNEL

    qkv = la_our.in_proj_qkv(h)
    z = la_our.in_proj_z(h)
    b = la_our.in_proj_b(h)
    a = la_our.in_proj_a(h)

    # Conv
    qkv_t = qkv.transpose(1, 2)
    pad = qkv_t.new_zeros(B, _CONV_DIM, k - 1)
    qkv_padded = torch.cat([pad, qkv_t], dim=2)
    qkv_conv = F.conv1d(qkv_padded, la_our.conv1d_weight, groups=_CONV_DIM)
    qkv_conv = F.silu(qkv_conv)
    our_qkv_conv = qkv_conv.transpose(1, 2)

    print("\n--- GatedDeltaNet internals ---")
    diff("qkv_after_conv", hf_vals['qkv_after_conv'], our_qkv_conv)

    # Split
    q = our_qkv_conv[..., :_K_DIM].view(B, T, 16, _HEAD_DIM)
    kk = our_qkv_conv[..., _K_DIM:_K_DIM*2].view(B, T, 16, _HEAD_DIM)
    v = our_qkv_conv[..., _K_DIM*2:].view(B, T, 32, _HEAD_DIM)

    q = q.repeat_interleave(_EXPAND, dim=2)
    kk = kk.repeat_interleave(_EXPAND, dim=2)

    diff("q_expanded", hf_vals['q_expanded'], q)
    diff("k_expanded", hf_vals['k_expanded'], kk)
    diff("v", hf_vals['v'], v)

    # beta and decay
    beta = torch.sigmoid(b.float()).to(h.dtype)
    g_our = (-la_our.A_log.float().exp() * F.softplus(a.float() + la_our.dt_bias.float()))
    decay = g_our.exp().to(h.dtype)

    diff("beta", hf_vals['beta'], beta)
    diff("g (log decay)", hf_vals['g'].to(h.dtype), g_our.to(h.dtype))
    diff("z", hf_vals['z'], z)

    # Sequential delta rule (decay first, then delta from decayed state)
    q_n = F.normalize(q.float(), dim=-1).to(h.dtype)
    kk_n = F.normalize(kk.float(), dim=-1).to(h.dtype)
    q_n = q_n * (_HEAD_DIM ** -0.5)  # HF scales query by 1/sqrt(d_k) inside the kernel
    state = h.new_zeros(B, _EXP_HEADS, _HEAD_DIM, _HEAD_DIM)
    outputs = []
    for t in range(T):
        q_t = q_n[:, t]; k_t = kk_n[:, t]; v_t = v[:, t]
        b_t = beta[:, t]; d_t = decay[:, t]
        state = state * d_t.unsqueeze(-1).unsqueeze(-1)   # decay first
        Sk = torch.einsum("bhi,bhij->bhj", k_t, state)    # delta from decayed state
        delta = v_t - Sk
        update = torch.einsum("bhi,bhj->bhij", k_t, b_t.unsqueeze(-1) * delta)
        state = state + update
        outputs.append(torch.einsum("bhi,bhij->bhj", q_t, state))

    core_out = torch.stack(outputs, dim=1)  # [B, T, 32, 128]
    diff("core_out", hf_vals['core_out'], core_out)

print("\nDone.")
72 changes: 72 additions & 0 deletions debug_fla_source.py
@@ -0,0 +1,72 @@
"""Dump the torch fallback of chunk_gated_delta_rule that HF is actually calling.

We can do this WITHOUT loading the 35B model — just import the class and
introspect the method (and follow it through any wrappers / module-level
fallbacks).
"""
import inspect
import sys

from transformers.models.qwen3_5_moe import modeling_qwen3_5_moe as M

print("=" * 70)
print(f"modeling file: {M.__file__}")
print("=" * 70)

# 1. Find the GatedDeltaNet class
gdn_cls = None
for name in dir(M):
    obj = getattr(M, name)
    if inspect.isclass(obj) and "GatedDeltaNet" in name:
        gdn_cls = obj
        print(f"Found class: {name}")
        break

if gdn_cls is None:
    print("ERROR: no GatedDeltaNet class found in modeling file")
    sys.exit(1)

# 2. Inspect chunk_gated_delta_rule on the class
fn = getattr(gdn_cls, "chunk_gated_delta_rule", None)
print(f"\nchunk_gated_delta_rule attr: {fn!r}")
print(f"type: {type(fn)}")

if fn is not None:
    try:
        print(f"defined in: {inspect.getsourcefile(fn)}")
    except Exception as e:
        print(f"  (no sourcefile: {e})")
    try:
        print("\n--- SOURCE ---")
        print(inspect.getsource(fn))
    except Exception as e:
        print(f"  (no source: {e})")

# 3. Also dump every module-level callable whose name mentions delta/recurrent/chunk
print("\n" + "=" * 70)
print("Module-level helpers (delta / recurrent / chunk):")
print("=" * 70)
for name in dir(M):
    obj = getattr(M, name)
    if callable(obj) and any(s in name.lower() for s in ("delta", "recurrent", "chunk")):
        print(f"\n--- {name} ({type(obj).__name__}) ---")
        try:
            print(f"defined in: {inspect.getsourcefile(obj)}")
        except Exception as e:
            print(f"  (no sourcefile: {e})")
        try:
            print(inspect.getsource(obj))
        except Exception as e:
            print(f"  (no source: {e})")

# 4. Show all imports of the modeling file so we can see what it pulls from fla
print("\n" + "=" * 70)
print("Imports in modeling file:")
print("=" * 70)
import ast
import pathlib
src = pathlib.Path(M.__file__).read_text()
tree = ast.parse(src)
for node in ast.walk(tree):
    if isinstance(node, (ast.Import, ast.ImportFrom)):
        print(ast.unparse(node))