Description
I have a simple flow for my use case where I am using orbax with pure JAX to save and load models. Here is the project.
The Flow
- Build a model (a PyTree) in JAX using a config file (e.g. size of layers, number of layers, etc.)
- Train the model and save the checkpoint using orbax
- Restore the checkpoint for inference or for continued training.
Problem
Say I build a model with 12 layers, train it, and save the checkpoint. A few days later I train a bigger model, say one with 16 layers. If I forget to provide the correct checkpoint path for inference, orbax does not complain about the size mismatch between the current model instance and the checkpoint; it silently loads partial weights and never errors out. This is undesirable behavior and can be a serious footgun for anyone who trains lots of models and has lots of checkpoints for a use case.
Here is the code I am using to restore the checkpoint. Can anything be done about it? I want it to error out if there is a mismatch instead of silently restoring weights partially.
import jax
import orbax.checkpoint as ocp
from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding


def get_sharding_for_checkpoint(x, mesh):
    """Obtain shardings for the leaves of the pytree."""
    # Scalars are always replicated across the mesh.
    if hasattr(x, "ndim") and x.ndim == 0:
        return NamedSharding(mesh, P())
    if isinstance(x, jax.Array) and hasattr(x, "sharding"):
        # Ensure small optimizer leaves (e.g., Muon scalars/vectors) are replicated,
        # not left on a single device, to match model param shardings during train_step.
        if isinstance(x.sharding, SingleDeviceSharding):
            return NamedSharding(mesh, P())
        return x.sharding
    else:
        return NamedSharding(mesh, P())

def load_weights_from_checkpoint(path, params, mesh):
    print(f"Restoring params from: {path}")
    # 1. Create an abstract 'target' structure for validation
    abstract_params = jax.tree.map(ocp.utils.to_shape_dtype_struct, params)
    restore_args = jax.tree.map(
        lambda leaf: ocp.ArrayRestoreArgs(
            sharding=get_sharding_for_checkpoint(leaf, mesh), strict=True),
        params,
    )
    with ocp.PyTreeCheckpointer() as ckptr:
        restored = ckptr.restore(
            path,
            args=ocp.args.PyTreeRestore(
                # Use the abstract structure here to trigger validation
                item=abstract_params,
                restore_args=restore_args,
            ),
        )
    return restored
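
For illustration, the kind of check being asked for could look roughly like the sketch below. It compares two trees by leaf path, shape, and dtype, and raises if they differ. It is a minimal sketch and not part of orbax: the names `Leaf`, `make_params`, `flatten_with_paths`, and `assert_same_structure` are hypothetical. It is written in plain Python and only walks nested dicts, lists, and tuples. Custom pytree nodes would need something like `jax.tree_util.tree_flatten_with_path` instead. How the checkpoint's stored structure is obtained is an assumption left open here, since it depends on the orbax version.

```python
from collections import namedtuple

# Hypothetical stand-in for a leaf such as jax.ShapeDtypeStruct (demo only).
Leaf = namedtuple("Leaf", ["shape", "dtype"])


def make_params(num_layers):
    """Build a toy parameter tree with one weight per layer (demo only)."""
    return {f"layer_{i}": {"w": Leaf((4, 4), "float32")} for i in range(num_layers)}


def flatten_with_paths(tree, prefix=""):
    """Map 'a/b/0'-style paths to (shape, dtype) for every leaf in the tree."""
    if isinstance(tree, dict):
        items = [(str(k), v) for k, v in tree.items()]
    elif isinstance(tree, (list, tuple)) and not hasattr(tree, "shape"):
        items = [(str(i), v) for i, v in enumerate(tree)]
    else:
        # Anything else is treated as a leaf and described by shape and dtype.
        shape = tuple(getattr(tree, "shape", ()))
        dtype = str(getattr(tree, "dtype", type(tree).__name__))
        return {prefix: (shape, dtype)}
    out = {}
    for key, value in items:
        path = f"{prefix}/{key}" if prefix else key
        out.update(flatten_with_paths(value, path))
    return out


def assert_same_structure(expected, actual):
    """Raise ValueError if the two trees differ in paths, shapes, or dtypes."""
    exp = flatten_with_paths(expected)
    act = flatten_with_paths(actual)
    missing = sorted(exp.keys() - act.keys())
    unexpected = sorted(act.keys() - exp.keys())
    mismatched = sorted(p for p in exp.keys() & act.keys() if exp[p] != act[p])
    if missing or unexpected or mismatched:
        raise ValueError(
            f"Tree mismatch. Missing from actual: {missing}; "
            f"unexpected in actual: {unexpected}; "
            f"shape/dtype differs at: {mismatched}"
        )


if __name__ == "__main__":
    # Same architecture on both sides: passes silently.
    assert_same_structure(make_params(12), make_params(12))
    # A 16-layer model checked against a 12-layer tree raises ValueError.
    try:
        assert_same_structure(make_params(16), make_params(12))
    except ValueError as err:
        print(err)
```

In the restore function above, one possible use would be `assert_same_structure(abstract_params, restored)` right before `return restored`. Whether that catches the partial restore depends on what orbax actually returns in the mismatched case, which is not verified here.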