Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
105f0e8
Enable AutoEP ZeRO-3 placement groups
tohtana May 20, 2026
4232ea3
Fix ZeRO-3 multi-gather wait forwarding
tohtana May 24, 2026
807085c
Gather AutoEP source ZeRO params
tohtana May 24, 2026
28e725a
Document constrained AutoEP ZeRO-3 support
tohtana Jun 11, 2026
5770e0f
Fix AutoEP zero.Init unit test import
tohtana Jun 11, 2026
573b525
Add 8 GPU AutoEP zero.Init validation
tohtana Jun 11, 2026
8c05ea7
Fix AutoEP ZeRO-3 expert gradient averaging
tohtana Jun 11, 2026
d7fac4a
Add partition-native AutoEP ZeRO-3 checkpoints
tohtana Jun 11, 2026
8f528ce
Address review findings for partition-native AutoEP checkpoints
tohtana Jun 12, 2026
b76fb57
Fail fast on pre-partitioned AutoEP expert params with mismatched groups
tohtana Jun 12, 2026
c177dbc
Update AutoEP checkpoint docs for ZeRO-3 partition-native support
tohtana Jun 12, 2026
fd4a861
Validate AutoEP universal topology loads
tohtana Jun 13, 2026
bac569d
Fix AutoEP topology fixture registration
tohtana Jun 13, 2026
6e7400d
Register AutoEP topology baseline as an explicit pytest fixture
tohtana Jun 13, 2026
fbb1bf9
Fix AutoEP topology universal router check
tohtana Jun 13, 2026
18569fc
Update config-json AutoEP ZeRO-3 universal topology-change note
tohtana Jun 13, 2026
50f6516
Support AutoEP universal module-only loads
tohtana Jun 13, 2026
f2d422d
Preserve fp32 masters for AutoEP module-only loads
tohtana Jun 13, 2026
a90deec
Address AutoEP ZeRO-3 review comments
tohtana Jun 23, 2026
dbcd098
Merge upstream master into AutoEP ZeRO-3 PR
tohtana Jun 23, 2026
dca6e1a
Merge branch 'master' into tohtana/autoep-zero3-zero-init
delock Jun 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions deepspeed/checkpoint/autoep_zero3_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""Shared validation for AutoEP ZeRO-3 checkpoint metadata."""

from deepspeed.checkpoint.constants import (
AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION,
AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY,
AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY,
AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT,
)

AUTOEP_METADATA_REQUIRED_FIELDS = frozenset({
'moe_layer_id',
'module_path',
'num_experts',
'num_local_experts',
'ep_size',
'expert_key_prefix',
})

AUTOEP_ZERO3_PARTITIONED_METADATA_FIELDS = frozenset({
AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY,
'ep_group_name',
'ep_rank',
'expert_data_parallel_rank',
'expert_data_parallel_world_size',
'global_expert_start',
'global_expert_end',
})


def is_autoep_zero3_partitioned_entry(entry):
return (isinstance(entry, dict)
and entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) == AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT)


def validate_autoep_zero3_partitioned_metadata(autoep_metadata,
require_partitioned=True,
expected_expert_prefixes=None,
version_context="This DeepSpeed build"):
if not isinstance(autoep_metadata, list):
raise RuntimeError(f"ds_autoep_layers metadata is malformed: expected list, got "
f"{type(autoep_metadata).__name__}")

seen_layer_ids = set()
seen_prefixes = set()
partitioned_count = 0

for entry in autoep_metadata:
if not isinstance(entry, dict):
raise RuntimeError(f"ds_autoep_layers entry is malformed: expected dict, got "
f"{type(entry).__name__}")
missing = AUTOEP_METADATA_REQUIRED_FIELDS - entry.keys()
if missing:
raise RuntimeError(f"ds_autoep_layers entry is invalid: missing fields {sorted(missing)}")

layer_id = entry['moe_layer_id']
if layer_id in seen_layer_ids:
raise RuntimeError(f"ds_autoep_layers metadata has duplicate moe_layer_id: {layer_id}")
seen_layer_ids.add(layer_id)

prefix = entry['expert_key_prefix']
if prefix in seen_prefixes:
raise RuntimeError(f"ds_autoep_layers metadata has duplicate expert_key_prefix: {prefix}")
seen_prefixes.add(prefix)

if not is_autoep_zero3_partitioned_entry(entry):
continue

missing = AUTOEP_ZERO3_PARTITIONED_METADATA_FIELDS - entry.keys()
if missing:
raise RuntimeError(f"AutoEP ZeRO-3 checkpoint metadata is invalid: missing fields {sorted(missing)}")
version = entry[AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY]
if version != AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION:
raise RuntimeError("Unsupported AutoEP ZeRO-3 checkpoint format version: "
f"{version}. {version_context} supports version "
f"{AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION}.")

num_experts = entry['num_experts']
num_local_experts = entry['num_local_experts']
ep_size = entry['ep_size']
if num_local_experts * ep_size != num_experts:
raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata is inconsistent: "
f"num_local_experts={num_local_experts}, ep_size={ep_size}, "
f"num_experts={num_experts}")

expected_start = entry['ep_rank'] * num_local_experts
expected_end = expected_start + num_local_experts
if entry['global_expert_start'] != expected_start or entry['global_expert_end'] != expected_end:
raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata has inconsistent global expert range: "
f"got [{entry['global_expert_start']}, {entry['global_expert_end']}), "
f"expected [{expected_start}, {expected_end})")

if expected_expert_prefixes is not None:
module_path = entry['module_path']
if module_path not in expected_expert_prefixes:
raise RuntimeError(f"AutoEP ZeRO-3 checkpoint metadata references missing module: {module_path}")
expected_prefix = expected_expert_prefixes[module_path]
if prefix != expected_prefix:
raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata has unexpected expert key prefix: "
f"got {prefix}, expected {expected_prefix}")

partitioned_count += 1

if require_partitioned and partitioned_count == 0:
raise RuntimeError("AutoEP ZeRO-3 partition-native checkpoint metadata was expected but no "
"partitioned AutoEP layer entries were found")
4 changes: 4 additions & 0 deletions deepspeed/checkpoint/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@
#########################################
AUTOEP_LAYERS_KEY = 'ds_autoep_layers'
AUTOEP_LAYERS_KEY_LEGACY = 'autoep_layers'
AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY = 'checkpoint_format'
AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT = 'zero3_partitioned'
AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY = 'checkpoint_format_version'
AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION = 1

#########################################
# Universal Checkpoint EP keys
Expand Down
Loading
Loading