Skip to content

Add Sapiens2 + RTMDet (top-down pose pipeline)#1081

Open
Blaizzy wants to merge 4 commits into
mainfrom
pc/add-sapiens2
Open

Add Sapiens2 + RTMDet (top-down pose pipeline)#1081
Blaizzy wants to merge 4 commits into
mainfrom
pc/add-sapiens2

Conversation

@Blaizzy
Copy link
Copy Markdown
Owner

@Blaizzy Blaizzy commented Apr 26, 2026

Summary

Ports two related dense-prediction models from Meta to mlx-vlm:

  • Sapiens2 — high-resolution ViT family pretrained on 1B human images. Backbone is shared across four task heads (pose / seg / normal / pointmap) × four sizes (0.4 B / 0.8 B / 1 B / 5 B). Inference goes through Sapiens2Predictor (it is not a generative VLM, so it does not flow through mlx_vlm.generate).
  • RTMDetfacebook/sapiens-pose-bbox-detector, the single-class person detector that is the first stage of Sapiens2's top-down pose pipeline.

Together they enable the full top-down pose flow: RTMDet → per-person crops → Sapiens2 pose → stitched keypoints in original-image coords.

What's in the PR

mlx_vlm/models/sapiens2/

  • vision.py — Sapiens2Backbone (PatchEmbed → cls + 8 storage tokens + 64×48 patches → 24/32/40/56 × TransformerBlock → RMSNorm → drop cls/storage). Partial GQA (full MHA in the outer 8 layers + GQA/2 inside), SwiGLU FFN.
  • 2-D RoPE in bf16 (matches PT's pos_embed_rope_dtype="bf16" default — running rope in fp32 inflates backbone max-diff from ~2.5e-3 to ~5e-2).
  • heads.py — pose / seg / normal / pointmap heads. Manual pixel_shuffle() (no MLX builtin), and a channels-first transpose before the pointmap scale branch's Flatten + Linear to match PT's (B, C, H, W) ordering (skipping it gives a wildly wrong scale).
  • sapiens2.py — top-level Model with shape-heuristic, idempotent sanitize() so mlx_vlm.utils.load() does not double-transpose conv weights when re-loading a sanitized checkpoint.
  • processing_sapiens2.py — 1024×768 ImageNet-norm processor with save_pretrained for round-tripping through mlx_vlm.convert.
  • generate.pySapiens2Predictor with three pose modes: whole-image, detector= (RTMDet or RF-DETR), and BYO person_boxes=. Crops to 3:4 with 25% margin, runs pose, stitches keypoints back to original-image coordinates.
  • convert.py — HF → MLX (task + size inferred from repo name).
  • readme.py — per-quant README generator.

mlx_vlm/models/rtmdet/

  • backbone.py — CSPNeXt (stem + 4 stages, stage4 includes SPPBottleneck), CSPLayer with optional ChannelAttention (hardsigmoid as clip((x+3)/6, 0, 1)).
  • neck.py — CSPNeXtPAFPN (top-down + bottom-up + out_convs at strides 8/16/32).
  • head.pyRTMDetSepBNHead with per-level cls/reg towers + 1×1 projections. Critically, applies exp() * stride to the reg branch (exp_on_reg=True for RTMDet-m/l/x) — without this, predicted boxes collapse to a few pixels.
  • processing_rtmdet.py — letterbox + ImageNet norm on [0, 255] (mean/std × 255), matching mmdet's DetDataPreprocessor.
  • generate.py — multi-level point-based decode + class-agnostic NMS + letterbox-undo.
  • convert.py.pth → MLX.

mlx_vlm/utils.py

  • Two lines added to MODEL_REMAPPING (sapiens2, rtmdet).

Closeness vs PyTorch

Synthetic input (fixed-seed np.random.standard_normal((1, 3, 1024, 768)), fp32) on Sapiens2-0.4 B:

task max |Δ| mean |Δ| agreement
seg 1.96e-01 4.77e-04 100.00% argmax match
pose 5.05e-03 1.91e-05 96.10% kpt-argmax
normal ~1.0 3.92e-03 dense regression
pointmap 3.79e-01 4.00e-04 scale Δ ≈ 1e-4

Drift is dominated by the backbone (max ≈ 2.5e-3 on featmap output) and is amplified by the heads' spatial upsamplers; argmax-based outputs are unaffected. Real-image agreement: seg 99.99% pixel match, pose 94.5% identical kpt with median Δ = 0.0 px.

Quantization

Round-trips cleanly through python -m mlx_vlm.convert for all eight quant modes (affine 4/5/6/8-bit, mxfp4, mxfp8, nvfp4, plus bf16). Quality vs fp32 on real-image seg argmax: 8-bit 99.99 %, 4-bit 99.79 %, mxfp4 99.69 %, all > 99.6 %.

Numerics notes (also in sapiens2/README.md)

  • RoPE in bf16 — matches PT's default; fp32 rope inflates max-diff ~20×.
  • Conv weight layouts — Conv2d (out, kH, kW, in) as usual; ConvTranspose2d in deconv heads needs (out, kH, kW, in) from PT's (in, out, kH, kW), handled in Model.sanitize via a deconv_layers. key hint.
  • Channels-first Flatten — PT's Sequential(Flatten, Linear) on (B, C, H, W) walks channels outermost; in MLX channels-last we transpose(0, 3, 1, 2) before flattening.
  • PixelShuffle — manual reshape (C, r_h, r_w) + transpose since MLX has no builtin.
  • Idempotent sanitize — shape heuristic (s[1] == s[2] and s[1] <= 16 ⇒ already MLX-layout) lets mlx_vlm.utils.load() be called repeatedly on already-sanitized checkpoints without double-transposing.

Test plan

  • Verify Sapiens2 forward parity vs PT for all 4 tasks (synthetic + real image)
  • Verify pose top-down pipeline with RTMDet detector (single image, multi-person crowd)
  • Verify pose top-down pipeline with RF-DETR detector
  • Verify pose with BYO person_boxes=
  • Round-trip all 8 quant modes through mlx_vlm.convert
  • Verify quantized seg argmax > 99 % match vs fp32
  • Verify RTMDet exp_on_reg produces tight bboxes (without it, boxes collapse)
  • Verify processor save_pretrained round-trips through mlx_vlm.convert

🤖 Generated with Claude Code

Blaizzy and others added 4 commits April 24, 2026 14:01
…ration. Refactor weight transpose logic in Model class for clarity and correctness.
… READMEs

Adds a runnable matplotlib snippet for each task plus the rendered output
PNG, so the README shows what the predictor actually produces:

  rtmdet/   detection.png      — single-class person bbox overlay
  sapiens2/ seg.png            — 29-class body-part mask overlay
            pose.png           — 308-keypoint top-down pose (RTMDet detector)
            pose_multi.png     — multi-person pose with RF-DETR detector
            normal.png         — surface normals (XYZ→RGB)
            pointmap.png       — pointmap z-channel depth (turbo colormap)

Sample images (assets/sample.jpg, assets/crowd.jpg) are Unsplash photos
under the Unsplash license; visualizations were produced by running the
ported MLX models locally.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant