Add Sapiens2 + RTMDet (top-down pose pipeline)#1081
Open
Blaizzy wants to merge 4 commits into
Open
Conversation
…ration. Refactor weight transpose logic in Model class for clarity and correctness.
… READMEs
Adds a runnable matplotlib snippet for each task plus the rendered output
PNG, so the README shows what the predictor actually produces:
rtmdet/ detection.png — single-class person bbox overlay
sapiens2/ seg.png — 29-class body-part mask overlay
pose.png — 308-keypoint top-down pose (RTMDet detector)
pose_multi.png — multi-person pose with RF-DETR detector
normal.png — surface normals (XYZ→RGB)
pointmap.png — pointmap z-channel depth (turbo colormap)
Sample images (assets/sample.jpg, assets/crowd.jpg) are Unsplash photos
under the Unsplash license; visualizations were produced by running the
ported MLX models locally.
Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Ports two related dense-prediction models from Meta to mlx-vlm:
Sapiens2Predictor(it is not a generative VLM, so it does not flow throughmlx_vlm.generate).facebook/sapiens-pose-bbox-detector, the single-class person detector that is the first stage of Sapiens2's top-down pose pipeline.Together they enable the full top-down pose flow: RTMDet → per-person crops → Sapiens2 pose → stitched keypoints in original-image coords.
What's in the PR
mlx_vlm/models/sapiens2/vision.py— Sapiens2Backbone (PatchEmbed → cls + 8 storage tokens + 64×48 patches → 24/32/40/56 × TransformerBlock → RMSNorm → drop cls/storage). Partial GQA (full MHA in the outer 8 layers + GQA/2 inside), SwiGLU FFN.pos_embed_rope_dtype="bf16"default — running rope in fp32 inflates backbone max-diff from ~2.5e-3 to ~5e-2).heads.py— pose / seg / normal / pointmap heads. Manualpixel_shuffle()(no MLX builtin), and a channels-first transpose before the pointmap scale branch'sFlatten + Linearto match PT's(B, C, H, W)ordering (skipping it gives a wildly wrong scale).sapiens2.py— top-levelModelwith shape-heuristic, idempotentsanitize()somlx_vlm.utils.load()does not double-transpose conv weights when re-loading a sanitized checkpoint.processing_sapiens2.py— 1024×768 ImageNet-norm processor withsave_pretrainedfor round-tripping throughmlx_vlm.convert.generate.py—Sapiens2Predictorwith three pose modes: whole-image,detector=(RTMDet or RF-DETR), and BYOperson_boxes=. Crops to 3:4 with 25% margin, runs pose, stitches keypoints back to original-image coordinates.convert.py— HF → MLX (task + size inferred from repo name).readme.py— per-quant README generator.mlx_vlm/models/rtmdet/backbone.py— CSPNeXt (stem + 4 stages, stage4 includes SPPBottleneck), CSPLayer with optional ChannelAttention (hardsigmoid asclip((x+3)/6, 0, 1)).neck.py— CSPNeXtPAFPN (top-down + bottom-up + out_convs at strides 8/16/32).head.py—RTMDetSepBNHeadwith per-level cls/reg towers + 1×1 projections. Critically, appliesexp() * strideto the reg branch (exp_on_reg=Truefor RTMDet-m/l/x) — without this, predicted boxes collapse to a few pixels.processing_rtmdet.py— letterbox + ImageNet norm on[0, 255](mean/std × 255), matching mmdet'sDetDataPreprocessor.generate.py— multi-level point-based decode + class-agnostic NMS + letterbox-undo.convert.py—.pth→ MLX.mlx_vlm/utils.pyMODEL_REMAPPING(sapiens2,rtmdet).Closeness vs PyTorch
Synthetic input (fixed-seed
np.random.standard_normal((1, 3, 1024, 768)), fp32) on Sapiens2-0.4 B:Drift is dominated by the backbone (max ≈ 2.5e-3 on featmap output) and is amplified by the heads' spatial upsamplers; argmax-based outputs are unaffected. Real-image agreement: seg 99.99% pixel match, pose 94.5% identical kpt with median Δ = 0.0 px.
Quantization
Round-trips cleanly through
python -m mlx_vlm.convertfor all eight quant modes (affine 4/5/6/8-bit, mxfp4, mxfp8, nvfp4, plus bf16). Quality vs fp32 on real-image seg argmax: 8-bit 99.99 %, 4-bit 99.79 %, mxfp4 99.69 %, all > 99.6 %.Numerics notes (also in
sapiens2/README.md)(out, kH, kW, in)as usual;ConvTranspose2din deconv heads needs(out, kH, kW, in)from PT's(in, out, kH, kW), handled inModel.sanitizevia adeconv_layers.key hint.Sequential(Flatten, Linear)on(B, C, H, W)walks channels outermost; in MLX channels-last wetranspose(0, 3, 1, 2)before flattening.(C, r_h, r_w)+ transpose since MLX has no builtin.s[1] == s[2] and s[1] <= 16⇒ already MLX-layout) letsmlx_vlm.utils.load()be called repeatedly on already-sanitized checkpoints without double-transposing.Test plan
person_boxes=mlx_vlm.convertexp_on_regproduces tight bboxes (without it, boxes collapse)save_pretrainedround-trips throughmlx_vlm.convert🤖 Generated with Claude Code