Skip to content

Fix save_model to include fine-tuning head weights#371

Merged
bputzeys merged 3 commits into
helicalAI:mainfrom
LiudengZhang:fix/save-load-fine-tuning-head
Apr 28, 2026
Merged

Fix save_model to include fine-tuning head weights#371
bputzeys merged 3 commits into
helicalAI:mainfrom
LiudengZhang:fix/save-load-fine-tuning-head

Conversation

@LiudengZhang
Copy link
Copy Markdown
Contributor

Summary

  • save_model() currently calls torch.save(self.model.state_dict(), path), which only persists backbone weights and silently drops the fine_tuning_head
  • Users who fine-tune, save, and reload get random head weights instead of trained ones
  • Affects all 7 model types (Geneformer, scGPT, UCE, HyenaDNA, Caduceus, HelixmRNA, Mamba2mRNA)
  • Changed to self.state_dict() which includes both backbone and head
  • Updated load_model to auto-detect three checkpoint formats for backward compatibility (full, backbone-only, legacy pickle)

Test plan

  • Fine-tune a model, save, reload — verify head weights are preserved
  • Load a v2.0.0 backbone-only checkpoint — verify warning is logged and backbone loads correctly
  • Load a pre-v2.0.0 legacy pickle checkpoint — verify backward compatibility

🤖 Generated with Claude Code

@dmiv-helical
Copy link
Copy Markdown
Contributor

Hi @LiudengZhang and welcome to the community!

Please open this PR against the main branch.
Also, this item in your Test plan:

Fine-tune a model, save, reload — verify head weights are preserved

is better as a unit test.

save_model only persisted self.model.state_dict(), silently
discarding the trained fine-tuning head (ClassificationHead /
RegressionHead) weights.  Switch to self.state_dict() so both
the backbone and the head are saved.

load_model now auto-detects the checkpoint format:
- full checkpoint (model + head keys)  -> self.load_state_dict()
- backbone-only (v2.0.0 checkpoint)    -> strict=False, warn
- legacy pickle (pre-v2.0.0)           -> extract & load backbone
@LiudengZhang LiudengZhang force-pushed the fix/save-load-fine-tuning-head branch from 28e5ba6 to 1235729 Compare April 28, 2026 01:10
@LiudengZhang LiudengZhang changed the base branch from release to main April 28, 2026 01:10
@LiudengZhang
Copy link
Copy Markdown
Contributor Author

Thanks for the welcome and the feedback! I've retargeted the PR to main and added a unit test that verifies fine-tuning head weights survive save/reload (sets all head params to a sentinel value, saves, loads into a fresh model, and asserts equality). Let me know if anything else needs adjusting.

Comment thread ci/tests/test_geneformer/test_fine_tuning.py Outdated
@bputzeys
Copy link
Copy Markdown
Collaborator

Thanks for this PR! Good one :)

@bputzeys bputzeys merged commit f133d5a into helicalAI:main Apr 28, 2026
4 checks passed
bputzeys added a commit that referenced this pull request Apr 28, 2026
…g head (#372)

* Merge pull request #371 from LiudengZhang/fix/save-load-fine-tuning-head

Fix save_model to include fine-tuning head weights

* Bump version from 2.0.1 to 2.0.2

---------

Co-authored-by: LiudengZhang <99156394+LiudengZhang@users.noreply.github.com>
@LiudengZhang
Copy link
Copy Markdown
Contributor Author

Thanks @bputzeys and @dmiv-helical for the reviews! Good catch on the comment — noted for next time.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants