
Strict checkpoint restore not working as expected #2972

@AakashKumarNain

Description


I have a simple flow for my use case where I am using orbax with pure JAX to save and load models. Here is the project.

The Flow

  1. Build a model (a PyTree) in JAX from a config file (e.g., layer sizes, number of layers)
  2. Train the model and save the checkpoint using orbax
  3. Restore the checkpoint for inference or for continued training.
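For illustration, step 1 might look like the minimal sketch below. The config fields (`num_layers`, `d_model`) and the helper `init_params` are hypothetical and not taken from the project. The point is that the layer count in the config determines the structure of the resulting PyTree, so a 12-layer model and a 16-layer model produce different trees.

```python
import jax
import jax.numpy as jnp


def init_params(num_layers: int, d_model: int, key: jax.Array) -> dict:
    """Hypothetical initializer: one weight/bias pair per layer."""
    keys = jax.random.split(key, num_layers)
    params = {}
    for i, k in enumerate(keys):
        params[f"layer_{i}"] = {
            "w": jax.random.normal(k, (d_model, d_model)) * 0.02,
            "b": jnp.zeros((d_model,)),
        }
    return params


key = jax.random.PRNGKey(0)
small = init_params(num_layers=12, d_model=8, key=key)
large = init_params(num_layers=16, d_model=8, key=key)

# Per-leaf shapes are identical, but the number of leaves (and the tree structure) differs.
print(len(jax.tree.leaves(small)), len(jax.tree.leaves(large)))  # 24 32
print(jax.tree.structure(small) == jax.tree.structure(large))    # False
```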

Problem

Say I build a model with 12 layers, train it, and save the checkpoint. A few days later, I train a bigger model, say one with 16 layers. If I forget to provide the correct checkpoint path for inference, orbax does not complain about the size mismatch between the current model instance and the checkpoint; instead, it silently loads partial weights and never raises an error. This is undesirable behavior and a serious footgun for anyone who trains many models and keeps many checkpoints for the same use case.
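A minimal sketch of the requested fail-fast behavior, written in plain JAX rather than with orbax: compare the expected tree against the tree you are about to trust and raise on any difference. `assert_trees_match` and `make` are hypothetical helpers, not orbax APIs. The check accepts any pytrees whose leaves expose `.shape` and `.dtype` (for example `jax.ShapeDtypeStruct` leaves, or real arrays) and compares structure first, then per-leaf shape and dtype.

```python
import jax
import jax.numpy as jnp


def assert_trees_match(expected, actual) -> None:
    """Hypothetical helper: raise ValueError unless both pytrees agree exactly.

    Leaves are compared by shape and dtype only; values are not inspected.
    """
    expected_def = jax.tree.structure(expected)
    actual_def = jax.tree.structure(actual)
    if expected_def != actual_def:
        raise ValueError(
            f"Tree structure mismatch:\n  expected: {expected_def}\n  actual:   {actual_def}"
        )
    # Same structure, so leaves line up one-to-one in flattening order.
    paths_and_leaves, _ = jax.tree_util.tree_flatten_with_path(expected)
    for (path, exp), act in zip(paths_and_leaves, jax.tree.leaves(actual)):
        if exp.shape != act.shape or exp.dtype != act.dtype:
            raise ValueError(
                f"Leaf {jax.tree_util.keystr(path)}: expected {exp.shape} {exp.dtype}, "
                f"got {act.shape} {act.dtype}"
            )


def make(num_layers: int, d_model: int = 4) -> dict:
    # Hypothetical abstract (shape/dtype-only) stand-in for a model with `num_layers` layers.
    return {
        f"layer_{i}": {"w": jax.ShapeDtypeStruct((d_model, d_model), jnp.float32)}
        for i in range(num_layers)
    }


assert_trees_match(make(12), make(12))  # passes silently
try:
    assert_trees_match(make(16), make(12))
except ValueError as e:
    print("caught:", str(e).splitlines()[0])  # caught: Tree structure mismatch:
```

Checking the structure before the shapes matters here: a 16-layer target and a 12-layer source differ in their set of leaves, which a per-leaf shape comparison alone would never see.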

Here is the code I am using to restore the checkpoint. Can anything be done about this silent partial restore? I want orbax to error out on a mismatch instead of silently restoring only part of the weights.

```python
import jax
import orbax.checkpoint as ocp
from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding


def get_sharding_for_checkpoint(x, mesh):
    """Obtain shardings for the leaves of the pytree."""
    # Scalars are always replicated across the mesh.
    if hasattr(x, "ndim") and x.ndim == 0:
        return NamedSharding(mesh, P())
    if isinstance(x, jax.Array) and hasattr(x, "sharding"):
        # Ensure small optimizer leaves (e.g., Muon scalars/vectors) are replicated,
        # not left on a single device, to match model param shardings during train_step.
        if isinstance(x.sharding, SingleDeviceSharding):
            return NamedSharding(mesh, P())
        return x.sharding
    else:
        # Non-array leaves fall back to full replication.
        return NamedSharding(mesh, P())


def load_weights_from_checkpoint(path, params, mesh):
    print(f"Restoring params from: {path}")

    # Create an abstract 'target' structure (shapes/dtypes only) for validation.
    abstract_params = jax.tree.map(ocp.utils.to_shape_dtype_struct, params)

    # strict=True is meant to make shape mismatches an error rather than a silent pad/truncate.
    restore_args = jax.tree.map(
        lambda leaf: ocp.ArrayRestoreArgs(
            sharding=get_sharding_for_checkpoint(leaf, mesh), strict=True
        ),
        params,
    )

    with ocp.PyTreeCheckpointer() as ckptr:
        restored = ckptr.restore(
            path,
            args=ocp.args.PyTreeRestore(
                # Use the abstract structure here to trigger validation.
                item=abstract_params,
                restore_args=restore_args,
            ),
        )
    return restored
```
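As a possible stopgap, sketched under assumptions: diff the leaf paths of the expected tree (e.g. `abstract_params`) against whichever tree reflects what was actually loaded, and refuse to continue if any path is missing or unexpected. `diff_tree_paths` and `check_strict` are hypothetical names, not part of orbax, and the sketch uses only `jax.tree_util`, so it says nothing about how orbax itself decides what to restore.

```python
import jax
import jax.numpy as jnp


def diff_tree_paths(expected, actual):
    """Hypothetical helper: return (missing, unexpected) leaf paths as sorted strings."""
    def paths(tree):
        flat, _ = jax.tree_util.tree_flatten_with_path(tree)
        return {jax.tree_util.keystr(path) for path, _ in flat}

    exp, act = paths(expected), paths(actual)
    # missing: expected but absent; unexpected: present but not expected.
    return sorted(exp - act), sorted(act - exp)


def check_strict(expected, actual) -> None:
    """Hypothetical helper: raise if the two trees do not have exactly the same leaf paths."""
    missing, unexpected = diff_tree_paths(expected, actual)
    if missing or unexpected:
        raise ValueError(f"missing={missing}, unexpected={unexpected}")


# Toy stand-ins: a 3-layer target versus a 2-layer "checkpoint".
target = {f"layer_{i}": {"w": jnp.zeros((2, 2))} for i in range(3)}
ckpt = {f"layer_{i}": {"w": jnp.zeros((2, 2))} for i in range(2)}

missing, unexpected = diff_tree_paths(target, ckpt)
print("missing:", missing)
print("unexpected:", unexpected)
```

Here `check_strict(target, ckpt)` raises a `ValueError` naming the missing `layer_2` weight, which is the kind of explicit failure the report asks for, while `check_strict(target, target)` returns without error.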

Metadata

Assignees: No one assigned
Labels: type:bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
