From 105f0e8c45ecea7d608ee43dfcb498c197a8dbb4 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 20 May 2026 16:40:51 -0700 Subject: [PATCH 01/19] Enable AutoEP ZeRO-3 placement groups Signed-off-by: Masahiro Tanaka --- deepspeed/checkpoint/ds_to_universal.py | 52 ++-- deepspeed/module_inject/auto_ep_config.py | 5 + deepspeed/module_inject/auto_ep_layer.py | 5 + .../module_inject/auto_ep_presets/base.py | 1 + deepspeed/runtime/engine.py | 176 +++++++++--- deepspeed/runtime/pipe/engine.py | 12 +- .../runtime/zero/partition_parameters.py | 152 +++++++---- deepspeed/runtime/zero/stage3.py | 250 +++++++++++++----- docs/_pages/config-json.md | 11 +- docs/_tutorials/universal-checkpointing.md | 21 +- tests/unit/v1/moe/test_autoep_integration.py | 134 +++++++++- tests/unit/v1/moe/test_autoep_unit.py | 220 +++++++++++++++ 12 files changed, 842 insertions(+), 197 deletions(-) diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py index 5c392ca52ec2..7d8c9d728bad 100755 --- a/deepspeed/checkpoint/ds_to_universal.py +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -43,6 +43,9 @@ PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0, PARAMETER_WITH_SUB_PARAMS, + AUTOEP_LAYERS_KEY, + AUTOEP_LAYERS_KEY_LEGACY, + EXPERT_PARAMETER_PATTERNS, SubparamShape, ) @@ -436,6 +439,32 @@ def _get_model_state_files(checkpoint_dir): return _get_checkpoint_files(checkpoint_dir, "*_model_states.pt") +def _is_expert_model_state_file(checkpoint_file): + basename = os.path.basename(checkpoint_file) + return basename.startswith('layer_') and '_expert_' in basename + + +def _get_zero3_model_state_files(checkpoint_dir): + model_files = [f for f in _get_model_state_files(checkpoint_dir) if not _is_expert_model_state_file(f)] + + if len(model_files) == 0: + raise FileNotFoundError(f"can't find ZeRO Stage 3 model state files in directory '{checkpoint_dir}'") + + return model_files + + +def _raise_if_stage3_autoep_universal_conversion(model_files): + for model_file in model_files: + model_state = torch.load(model_file, map_location=torch.device('cpu'), weights_only=False) + autoep_metadata = model_state.get(AUTOEP_LAYERS_KEY) + if autoep_metadata is None: + autoep_metadata = model_state.get(AUTOEP_LAYERS_KEY_LEGACY) + + if autoep_metadata is not None: + raise NotImplementedError("Stage 3 universal checkpoint conversion with AutoEP is not supported. " + "Use regular same-topology ZeRO-3 checkpoint load for AutoEP checkpoints.") + + def _get_checkpoint_files(checkpoint_dir, glob_pattern): ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) @@ -510,11 +539,6 @@ def main(args): _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir) print('*** 2.5. Consolidating AutoEP expert files') - from deepspeed.checkpoint.constants import ( - AUTOEP_LAYERS_KEY, - AUTOEP_LAYERS_KEY_LEGACY, - EXPERT_PARAMETER_PATTERNS, - ) from deepspeed.checkpoint.autoep_universal import ( consolidate_autoep_expert_files, consolidate_autoep_optimizer_states, @@ -563,19 +587,8 @@ def main(args): else: # Stage 3 path - # Check for AutoEP metadata - Stage 3 + AutoEP is not supported - stage3_expert_files = glob.glob(os.path.join(args.input_folder, 'layer_*_expert_*_model_states.pt')) - stage3_model_files_for_meta = glob.glob(os.path.join(args.input_folder, 'mp_rank_*_model_states.pt')) - if stage3_model_files_for_meta: - _stage3_sd = torch.load(stage3_model_files_for_meta[0], - map_location=torch.device('cpu'), - weights_only=False) - _stage3_autoep = _stage3_sd.get('ds_autoep_layers') or _stage3_sd.get('autoep_layers') - if _stage3_autoep is not None: - raise NotImplementedError("Stage 3 universal checkpoint conversion with AutoEP is not supported. " - "AutoEP currently requires ZeRO Stage 1 or 2.") - - model_files = _get_model_state_files(args.input_folder) + model_files = _get_zero3_model_state_files(args.input_folder) + _raise_if_stage3_autoep_universal_conversion(model_files) param_shapes = _parse_model_states_stage3(model_files) dp_degree = len(model_files) @@ -596,8 +609,7 @@ def main(args): # Copy *model_states files into output folder, filtering out expert files for f in glob.glob(os.path.join(args.input_folder, '*model_states.pt')): - basename = os.path.basename(f) - if basename.startswith('layer_') and '_expert_' in basename: + if _is_expert_model_state_file(f): continue # Skip expert files (handled separately if AutoEP were supported) shutil.copy2(f, args.output_folder) diff --git a/deepspeed/module_inject/auto_ep_config.py b/deepspeed/module_inject/auto_ep_config.py index a0e6e3a9b36b..ca07156c3ae4 100644 --- a/deepspeed/module_inject/auto_ep_config.py +++ b/deepspeed/module_inject/auto_ep_config.py @@ -45,6 +45,7 @@ def parse_autoep_config(param_dict: dict) -> AutoEPConfig: config = AutoEPConfig() config.enabled = param_dict.get("enabled", False) config.autoep_size = param_dict.get("autoep_size", 1) + config.expert_tensor_parallel_size = param_dict.get("expert_tensor_parallel_size", 1) config.preset_model = param_dict.get("preset_model", None) config.moe_layer_pattern = param_dict.get("moe_layer_pattern", None) config.expert_pattern = param_dict.get("expert_pattern", None) @@ -108,6 +109,10 @@ def validate_autoep_config( f"(tensor_parallel.autotp_size={tp_size}). Disable AutoTP for this run; " "AutoEP+AutoTP support is planned as follow-up work.") + if config.expert_tensor_parallel_size != 1: + raise ValueError("AutoEP only supports expert_parallel.expert_tensor_parallel_size=1 in this release; " + "expert tensor parallelism is planned as follow-up work.") + # ep_size must divide the stage size (world_size / pp_size) stage_size = world_size // pp_size if stage_size % config.autoep_size != 0: diff --git a/deepspeed/module_inject/auto_ep_layer.py b/deepspeed/module_inject/auto_ep_layer.py index f4bba73d3d9e..8834eaadb4d4 100644 --- a/deepspeed/module_inject/auto_ep_layer.py +++ b/deepspeed/module_inject/auto_ep_layer.py @@ -456,16 +456,21 @@ def __init__( for param in self.experts.parameters(): param.allreduce = False param.group_name = self.ep_group_name + param.ds_zero_placement_family = "autoep_expert" + param.ds_zero_partition_group_name = self.ep_group_name # Mark shared expert and router params for global DP reduction for param in self.router.parameters(): param.allreduce = True + param.ds_zero_placement_family = "replicated" if self.shared_experts is not None: for param in self.shared_experts.parameters(): param.allreduce = True + param.ds_zero_placement_family = "replicated" if self.shared_experts_gate is not None: for param in self.shared_experts_gate.parameters(): param.allreduce = True + param.ds_zero_placement_family = "replicated" # Load balancing buffers self.load_balance_coeff = resolved_config.load_balance_coeff diff --git a/deepspeed/module_inject/auto_ep_presets/base.py b/deepspeed/module_inject/auto_ep_presets/base.py index 342e6ff1abb5..c023498109c9 100644 --- a/deepspeed/module_inject/auto_ep_presets/base.py +++ b/deepspeed/module_inject/auto_ep_presets/base.py @@ -98,6 +98,7 @@ class AutoEPConfig: enabled: bool = False autoep_size: int = 1 + expert_tensor_parallel_size: int = 1 preset_model: str | None = None moe_layer_pattern: str | None = None expert_pattern: str | None = None diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 7708999fcdf7..aa063879ac1e 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1599,6 +1599,72 @@ def _configure_distributed_model(self, model): if not (self.amp_enabled() or is_zero_init_model): self._broadcast_model() + def _validate_zero3_moe_compatibility(self): + if not self.has_moe_layers: + return + + try: + from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer as _AutoEPMoELayer + except ImportError: + _AutoEPMoELayer = None + + autoep_layers = [] + native_moe_layers = [] + for name, module in self.module.named_modules(): + if isinstance(module, MoE): + native_moe_layers.append(name) + elif _AutoEPMoELayer is not None and isinstance(module, _AutoEPMoELayer): + autoep_layers.append(name) + + if native_moe_layers: + raise AssertionError("Native DeepSpeed MoE is not supported with ZeRO Stage 3. " + "Use AutoEP or choose ZeRO stage 1/2.") + if not autoep_layers: + raise AssertionError("MoE not supported with Stage 3") + autotp_size = self.autotp_size() + if autotp_size not in (0, 1): + raise AssertionError("AutoEP with ZeRO Stage 3 does not support AutoTP yet " + f"(tensor_parallel.autotp_size={autotp_size}).") + if self.sequence_parallel_size != 1: + raise AssertionError("AutoEP with ZeRO Stage 3 does not support sequence parallelism yet " + f"(sequence_parallel_size={self.sequence_parallel_size}).") + if self.zero_quantized_gradients(): + raise AssertionError("AutoEP with ZeRO Stage 3 does not support zero_quantized_gradients or LoCo " + "quantized gradients yet.") + mics_shard_size = getattr(self._config, "mics_shard_size", 0) + if mics_shard_size > 0: + raise AssertionError("AutoEP with ZeRO Stage 3 does not support MiCS yet " + f"(mics_shard_size={mics_shard_size}).") + hpz_partition_size = getattr(getattr(self._config, "zero_config", None), "zero_hpz_partition_size", 1) + if hpz_partition_size > 1: + raise AssertionError("AutoEP with ZeRO Stage 3 does not support hpZeRO secondary tensor groups yet " + f"(zero_optimization.zero_hpz_partition_size={hpz_partition_size}).") + + expert_tp_size = getattr(self._config.expert_parallel_config, "expert_tensor_parallel_size", 1) + if expert_tp_size != 1: + raise AssertionError("AutoEP with ZeRO Stage 3 only supports " + "expert_parallel.expert_tensor_parallel_size=1.") + + def _resolve_zero3_param_placement(self): + for name, param in self.module.named_parameters(): + family = getattr(param, "ds_zero_placement_family", "replicated") + if family == "autoep_expert": + group_name = getattr(param, "ds_zero_partition_group_name", getattr(param, "group_name", None)) + if group_name is None: + raise AssertionError(f"AutoEP expert parameter '{name}' is missing a ZeRO partition group name.") + partition_group = groups._get_expert_data_parallel_group(group_name) + elif family == "replicated": + group_name = None + partition_group = self.seq_data_parallel_group + else: + raise AssertionError(f"Parameter '{name}' has unsupported ZeRO placement family '{family}'.") + + param.ds_zero_placement_family = family + param.ds_zero_partition_group_name = group_name + param.ds_zero_partition_process_group = partition_group + param.ds_zero_partition_rank = dist.get_rank(group=partition_group) + param.ds_zero_partition_world_size = dist.get_world_size(group=partition_group) + # check if parameters are duplicated in optimizer param_groups def _check_for_duplicates(self, optimizer): for name, param in self.module.named_parameters(): @@ -2062,7 +2128,9 @@ def _configure_zero_optimizer(self, optimizer): check_grad_overflow=check_grad_overflow) elif zero_stage == ZeroStageEnum.weights: - assert not self.has_moe_layers, "MoE not supported with Stage 3" + self._validate_zero3_moe_compatibility() + if self.has_moe_layers: + self._resolve_zero3_param_placement() if isinstance(optimizer, DummyOptim): log_dist("Creating ZeRO Offload", ranks=[0]) zero_param_parallel_group = groups._get_zero_param_intra_parallel_group() @@ -3567,8 +3635,17 @@ def load_moe_state_dict(checkpoint_path, moe_layer_id += 1 - def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False): - if fetch_z3_params: + def load_module_state_dict(self, + checkpoint, + strict=True, + custom_load_fn=None, + fetch_z3_params=False, + z3_params_to_fetch=None): + if z3_params_to_fetch is not None: + params_to_fetch = [ + p for p in z3_params_to_fetch if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE + ] + elif fetch_z3_params: params_to_fetch = [ p for p in self.module.parameters() if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE @@ -3789,6 +3866,12 @@ def _load_checkpoint(self, return None, None fetch_z3_params = False + z3_params_to_fetch = None + if self.has_moe_layers and self.zero_optimization_partition_weights() and (load_module_only + or not load_optimizer_states): + raise NotImplementedError("AutoEP with ZeRO Stage 3 does not support module-only checkpoint load or " + "load_optimizer_states=False yet. Use same-topology checkpoint load with " + "optimizer states for this release.") if self.zero_optimization_partition_weights() and not load_optimizer_states: checkpoint['module'] = get_fp32_state_dict_from_zero_checkpoint(load_dir) fetch_z3_params = True @@ -3817,11 +3900,22 @@ def _load_checkpoint(self, num_experts=self.num_experts, checkpoint_engine=self.checkpoint_engine, autoep_layers=autoep_layers) + if self.zero_optimization_partition_weights(): + z3_params_to_fetch = [] + try: + from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer as _AutoEPMoELayer + except ImportError: + _AutoEPMoELayer = None + if _AutoEPMoELayer is not None: + for _, module in self.module.named_modules(): + if isinstance(module, _AutoEPMoELayer): + z3_params_to_fetch.extend(module.experts.parameters()) if not self.load_universal_checkpoint(): self.load_module_state_dict(checkpoint=checkpoint, strict=load_module_strict, custom_load_fn=custom_load_fn, - fetch_z3_params=fetch_z3_params) + fetch_z3_params=fetch_z3_params, + z3_params_to_fetch=z3_params_to_fetch) self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size'] @@ -4262,26 +4356,25 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa f"multiple groups: {sorted(autoep_group_names)}. " f"All AutoEPMoELayer instances must use the same ep_size.") - # Gate file writes behind writer guard - if not self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank): - moe_layer_id += 1 - continue - - # Slice fused 3D tensors into per-expert state dicts - for local_expert_id in range(num_local_experts): - global_expert_id = expp_rank * num_local_experts + local_expert_id - expert_state_dict = {} - for wname in ('w1', 'w2', 'w3'): - fused_key = f"{module_prefix}experts.{wname}" - param = getattr(module.experts, wname) - expert_state_dict[f"{fused_key}.{global_expert_id}"] = ( - param[local_expert_id].clone().detach()) - - moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu) - saveable = expert_state_dict - if self.checkpoint_engine.preserves_storage_sharing(): - saveable = clone_tensors_for_torch_save(expert_state_dict) - self.checkpoint_engine.save(saveable, moe_save_path) + expert_params = [getattr(module.experts, wname) for wname in ('w1', 'w2', 'w3')] + with deepspeed.zero.GatheredParameters(expert_params): + if self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank): + # Slice fused 3D tensors into per-expert state dicts. + for local_expert_id in range(num_local_experts): + global_expert_id = expp_rank * num_local_experts + local_expert_id + expert_state_dict = {} + for wname in ('w1', 'w2', 'w3'): + fused_key = f"{module_prefix}experts.{wname}" + param = getattr(module.experts, wname) + expert_state_dict[f"{fused_key}.{global_expert_id}"] = ( + param[local_expert_id].clone().detach()) + + moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, + self.mpu) + saveable = expert_state_dict + if self.checkpoint_engine.preserves_storage_sharing(): + saveable = clone_tensors_for_torch_save(expert_state_dict) + self.checkpoint_engine.save(saveable, moe_save_path) moe_layer_id += 1 @@ -4291,25 +4384,28 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa expp_rank = groups._get_expert_parallel_rank(largest_group_name) exp_dp_rank = groups._get_expert_data_parallel_rank(largest_group_name) - # In the case of E + D parallelism, only the - # first expert parallel group should save the expert weights - # since each expert parallel group is a copy of the model's experts - if not self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank): + # In the case of E + D parallelism, only the first expert data-parallel + # rank writes expert/EP optimizer files because the expert weights are + # replicated across expert-data-parallel ranks. ZeRO-3 model-state + # files are different: each ZeRO partition rank must write its own + # zero_pp_rank_*_model_states.pt file so load can discover a checkpoint + # on every rank. + is_expert_dp_writer = self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank) + if is_expert_dp_writer: + optimizer_state = { + 'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None + } + # TODO: why use BufferedWriter not the path + file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank) + saveable_state_dict = optimizer_state + if self.checkpoint_engine.preserves_storage_sharing(): + saveable_state_dict = clone_tensors_for_torch_save(optimizer_state) + self.checkpoint_engine.save(saveable_state_dict, file_path) + elif not self.zero_optimization_partition_weights(): return - # Save optimizer states. They are different across each exp parallel rank. - optimizer_state = { - 'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None - } - # TODO: why use BufferedWriter not the path - file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank) - saveable_state_dict = optimizer_state - if self.checkpoint_engine.preserves_storage_sharing(): - saveable_state_dict = clone_tensors_for_torch_save(optimizer_state) - self.checkpoint_engine.save(saveable_state_dict, file_path) - # Load flow uses below saved file for model parameters, RNG and more - if groups._get_data_parallel_rank() == 0: + if self.zero_optimization_partition_weights() or groups._get_data_parallel_rank() == 0: # Get non-moe parameters # Classes DeepSpeedEngine and PipelineEngine have different behavior for method module_state_dict. # DeepSpeedEngine returns the state dict, where PipelineEngine saves the state dict and returns None. diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 46fbb4ebaf05..90a152091c1b 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -1325,7 +1325,12 @@ def module_state_dict(self, exclude_frozen_parameters=False): exclude_frozen_params=exclude_frozen_parameters) return None - def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False): + def load_module_state_dict(self, + checkpoint, + strict=True, + custom_load_fn=None, + fetch_z3_params=False, + z3_params_to_fetch=None): """Override hack to instead use a directory path. This is important because pipeline models checkpoint by layer instead of rank. @@ -1339,7 +1344,10 @@ def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, f assert custom_load_fn is None, "custom_load_fn not supported w. pipeline parallelism" state_dict = checkpoint if self.has_moe_layers else checkpoint['module'] if (state_dict is not None) and (not isinstance(state_dict, str)): - super().load_module_state_dict(state_dict, strict) + super().load_module_state_dict(state_dict, + strict, + fetch_z3_params=fetch_z3_params, + z3_params_to_fetch=z3_params_to_fetch) return self.module.load_state_dir(load_dir=self._curr_ckpt_path, diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 2b23c0b340ee..dd0e2745efea 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -1119,11 +1119,11 @@ def _update_persist_config(self, ds_config): def _zero_init_param(self, param): self._convert_to_deepspeed_param(param) - if dist.get_world_group() == self.get_dp_process_group(): - dist.broadcast(param.data, 0, self.get_dp_process_group()) + partition_group = self.get_partition_dp_group(param) + if dist.get_world_group() == partition_group: + dist.broadcast(param.data, 0, partition_group) else: - dist.broadcast(param.data, dist.get_global_rank(self.get_dp_process_group(), 0), - self.get_dp_process_group()) + dist.broadcast(param.data, dist.get_global_rank(partition_group, 0), partition_group) param.partition() def _convert_to_zero_parameters(self, param_list): @@ -1205,7 +1205,9 @@ def _convert_to_deepspeed_param(self, param): param.is_external_param = False # The group that the parameter is scattered across. - param.ds_process_group = self.ds_process_group + param.ds_process_group = getattr(param, "ds_zero_partition_process_group", self.ds_process_group) + param.ds_zero_partition_rank = dist.get_rank(group=param.ds_process_group) + param.ds_zero_partition_world_size = dist.get_world_size(group=param.ds_process_group) param.ds_enable_sanity_checks = self.enable_sanity_checks # Stores the secondary partitioned copy of the tensor @@ -1357,7 +1359,7 @@ def _all_gather_coalesced(params, world_size, rank_in_group, use_secondary_tenso start_param = 0 for param in params: param.data = flat_tensor.narrow(0, start_param, param.ds_numel).view(param.ds_shape) - start = start_param + param.ds_tensor.ds_numel * self.get_partition_rank() + start = start_param + param.ds_tensor.ds_numel * rank_in_group flat_tensor.narrow(0, start, param.ds_tensor.ds_numel).copy_(param.ds_tensor) start_param += param.ds_numel @@ -1442,11 +1444,23 @@ def _all_gather_coalesced(params, world_size, rank_in_group, use_secondary_tenso def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = False, quantize: bool = False) -> AllGatherCoalescedHandle: + params = list(params) + if len(params) == 0: + return _no_gather_coalesced(params) # fetches from nvme if the partition is not available and in nvme self._ensure_availability_of_partitioned_params(params) - if self.num_partitions == 1: + params_by_group = {} + for param in params: + params_by_group.setdefault(id(self.get_partition_dp_group(param)), []).append(param) + if len(params_by_group) > 1: + return MultipleAllGatherHandles([ + all_gather_coalesced(group_params, safe_mode=safe_mode, quantize=quantize) + for group_params in params_by_group.values() + ]) + + if all(self._partition_world_size(param) == 1 for param in params): return _no_gather_coalesced(params) for param in params: @@ -1455,9 +1469,9 @@ def all_gather_coalesced(params: Iterable[Parameter], param.ds_status = ZeroParamStatus.INFLIGHT #use appropriate all gather process group - ds_process_group = self.ds_process_group - rank_in_group = self.rank - world_size = self.dp_world_size + ds_process_group = self.get_partition_dp_group(params[0]) + rank_in_group = self._partition_rank(params[0]) + world_size = self._partition_world_size(params[0]) use_secondary_tensor = params[0].ds_secondary_tensor is not None if self.zero_param_process_group and use_secondary_tensor: ds_process_group = self.zero_param_process_group #intragroup @@ -1581,12 +1595,19 @@ def _aligned_size(self, param): return param.ds_numel + self._padding_size(param) def _padding_size(self, param): - remainder = param.ds_numel % self.num_partitions - return (self.num_partitions - remainder) if remainder else 0 + partition_world_size = self._partition_world_size(param) + remainder = param.ds_numel % partition_world_size + return (partition_world_size - remainder) if remainder else 0 def _partition_numel(self, param): return param.ds_tensor.ds_numel + def _partition_world_size(self, param): + return dist.get_world_size(group=self.get_partition_dp_group(param)) + + def _partition_rank(self, param): + return dist.get_rank(group=self.get_partition_dp_group(param)) + def _ensure_availability_of_partitioned_params(self, params): swap_in_list = [] swap_in_flight = [] @@ -1620,21 +1641,25 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None): all_gather_list.append(param) # note: param_list may contain params that are already in flight / aviailable. So we need to use all_gather_list if not async_op: - if self.allgather_sequential or len(all_gather_list) == 1: - ret_value = self._allgather_params_sequential(all_gather_list, hierarchy=hierarchy) - else: - all_gather_quantize_list = [] - all_gather_nonquantize_list = [] - for param in all_gather_list: - if hasattr(param.ds_tensor, - "ds_quant_scale") or (hasattr(param, "ds_secondary_tensor") - and hasattr(param.ds_secondary_tensor, "ds_quant_scale")): - all_gather_quantize_list.append(param) - else: - all_gather_nonquantize_list.append(param) - # _allgather_params_coalesced always return None - self._allgather_params_coalesced(all_gather_nonquantize_list, hierarchy, quantize=False) - self._allgather_params_coalesced(all_gather_quantize_list, hierarchy, quantize=True) + all_gather_by_group = {} + for param in all_gather_list: + all_gather_by_group.setdefault(id(self.get_partition_dp_group(param)), []).append(param) + for group_params in all_gather_by_group.values(): + if self.allgather_sequential or len(group_params) == 1: + ret_value = self._allgather_params_sequential(group_params, hierarchy=hierarchy) + else: + all_gather_quantize_list = [] + all_gather_nonquantize_list = [] + for param in group_params: + if hasattr(param.ds_tensor, + "ds_quant_scale") or (hasattr(param, "ds_secondary_tensor") + and hasattr(param.ds_secondary_tensor, "ds_quant_scale")): + all_gather_quantize_list.append(param) + else: + all_gather_nonquantize_list.append(param) + # _allgather_params_coalesced always return None + self._allgather_params_coalesced(all_gather_nonquantize_list, hierarchy, quantize=False) + self._allgather_params_coalesced(all_gather_quantize_list, hierarchy, quantize=True) for param in all_gather_list: param.ds_status = ZeroParamStatus.AVAILABLE return None @@ -1693,7 +1718,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False, free_data return tensor_size = self._aligned_size(param) - partition_size = tensor_size // self.num_partitions + partition_size = tensor_size // self._partition_world_size(param) if param.ds_tensor is None: final_location = None if self.remote_device == OffloadDeviceEnum.nvme and self.param_swapper.swappable_tensor( @@ -1728,7 +1753,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False, free_data param.ds_tensor.final_location = final_location param.ds_numel_aligned = tensor_size - start = partition_size * self.get_partition_rank() + start = partition_size * self._partition_rank(param) end = start + partition_size one_dim_param = param.contiguous().view(-1) @@ -1844,7 +1869,9 @@ def _allgather_param(self, param, async_op=False, hierarchy=0): partition_size = param.ds_tensor.ds_numel - tensor_size = partition_size * self.num_partitions + partition_world_size = self._partition_world_size(param) + partition_rank = self._partition_rank(param) + tensor_size = partition_size * partition_world_size aligned_param_size = self._aligned_size(param) assert tensor_size == aligned_param_size, f'param id {param.ds_id} aligned size {aligned_param_size} does not match tensor size {tensor_size}' @@ -1879,14 +1906,14 @@ def _allgather_param(self, param, async_op=False, hierarchy=0): async_op=async_op) else: partitions = [] - for i in range(self.num_partitions): + for i in range(partition_world_size): partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size)) - if i == dist.get_rank(group=self.get_partition_dp_group(param)): + if i == partition_rank: partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True) handle = dist.all_gather(partitions, - partitions[self.get_partition_rank()], + partitions[partition_rank], group=self.get_partition_dp_group(param), async_op=async_op) @@ -1901,7 +1928,7 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0, quantize=False): if len(param_list) == 0: return - if self.num_partitions == 1: + if all(self._partition_world_size(param) == 1 for param in param_list): handle = _no_gather_coalesced(param_list) handle.wait() return None @@ -1922,15 +1949,16 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0, quantize=False): allgather_params = [] if quantize: allgather_quantize_scale = [] - for psize in partition_sizes: - tensor_size = psize * self.num_partitions + for param, psize in zip(param_list, partition_sizes): + tensor_size = psize * self._partition_world_size(param) flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device).view(-1) flat_tensor.requires_grad = False allgather_params.append(flat_tensor) if quantize: for psize in quantize_scale_sizes: - tensor_size = psize * self.num_partitions + param = param_list[len(allgather_quantize_scale)] + tensor_size = psize * self._partition_world_size(param) flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.ds_quant_scale.dtype, device=self.local_device).view(-1) @@ -1957,7 +1985,8 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0, quantize=False): launch_quantize_handles.append(quantize_handle) else: output_list = [] - for i in range(self.num_partitions): + partition_world_size = self._partition_world_size(param) + for i in range(partition_world_size): psize = partition_sizes[param_idx] partition = allgather_params[param_idx].narrow(0, i * psize, psize) output_list.append(partition) @@ -1969,7 +1998,7 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0, quantize=False): h = dist.all_gather(output_list, input_tensor, group=self.get_partition_dp_group(param), async_op=True) if quantize: output_scale_list = [] - for i in range(self.num_partitions): + for i in range(partition_world_size): psize = quantize_scale_sizes[param_idx] partition = allgather_quantize_scale[param_idx].narrow(0, i * psize, psize) output_scale_list.append(partition) @@ -2005,7 +2034,9 @@ def _allgather_params_sequential(self, param_list, hierarchy=0): for param in param_list: partition_size = param.ds_tensor.ds_numel - tensor_size = partition_size * self.num_partitions + partition_world_size = self._partition_world_size(param) + partition_rank = self._partition_rank(param) + tensor_size = partition_size * partition_world_size flat_tensor = torch.empty(tensor_size, dtype=param.ds_tensor.dtype, device=self.local_device) flat_tensor.requires_grad = False @@ -2016,18 +2047,18 @@ def _allgather_params_sequential(self, param_list, hierarchy=0): async_op=False) else: partitions = [] - for i in range(self.num_partitions): + for i in range(partition_world_size): partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size)) - if i == self.get_partition_rank(): + if i == partition_rank: partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True) dist.all_gather(partitions, - partitions[self.get_partition_rank()], + partitions[partition_rank], group=self.get_partition_dp_group(param), async_op=False) if hasattr(param.ds_tensor, 'ds_quant_scale'): scale_size = param.ds_tensor.ds_quant_scale.numel() - scale_tensor_size = scale_size * self.num_partitions + scale_tensor_size = scale_size * partition_world_size flat_scale_tensor = torch.empty(scale_tensor_size, dtype=param.ds_tensor.ds_quant_scale.dtype, device=self.local_device) @@ -2039,12 +2070,12 @@ def _allgather_params_sequential(self, param_list, hierarchy=0): async_op=False) else: scale_partitions = [] - for i in range(self.num_partitions): + for i in range(partition_world_size): scale_partitions.append(flat_scale_tensor.narrow(0, scale_size * i, scale_size)) - if i == self.get_partition_rank(): + if i == partition_rank: scale_partitions[i].data.copy_(param.ds_tensor.ds_quant_scale.data, non_blocking=True) dist.all_gather(scale_partitions, - scale_partitions[self.get_partition_rank()], + scale_partitions[partition_rank], group=self.get_partition_dp_group(param), async_op=False) flat_tensor = self.quantizer_module.dequantize(flat_tensor, flat_scale_tensor) @@ -2076,7 +2107,7 @@ def _reduce_scatter_gradients(self, param_list): # For these ranks the output of reduce scatter is a separate buffer and needs # to be copied in partition_size = param.ds_tensor.ds_numel - start = self.get_partition_rank() * partition_size + start = self._partition_rank(param) * partition_size end = start + partition_size #print_rank_0("REduce scatter was executed for param {param.ds_id}") if start < param.ds_numel < end: @@ -2088,10 +2119,10 @@ def _reduce_scatter_gradient(self, param): partition_size = param.ds_tensor.ds_numel #output = torch.empty(partition_size, dtype=param.dtype, device=param.device) - total_size = partition_size * self.num_partitions + partition_world_size = self._partition_world_size(param) input_list = [] - for i in range(self.num_partitions): + for i in range(partition_world_size): start = i * partition_size end = start + partition_size @@ -2313,12 +2344,16 @@ def load(module: nn.Module, prefix=""): ) # remove the duplicates to prevent racing condition, we must also make sure the order is the same on all ranks otherwise we'll get deadlocks self.enable_sanity_checks = getattr(self.params[0], "ds_enable_sanity_checks", False) self.src_rank = None + self.src_rank_by_group = {} if modifier_rank is not None: - if self.params[0].ds_process_group == dist.get_world_group(): - self.src_rank = modifier_rank - else: - # A group was specified; convert DP rank to global rank - self.src_rank = dist.get_global_rank(self.params[0].ds_process_group, modifier_rank) + for p in self.params: + if p.ds_process_group == dist.get_world_group(): + src_rank = modifier_rank + else: + # A group was specified; convert DP rank to global rank + src_rank = dist.get_global_rank(p.ds_process_group, modifier_rank) + self.src_rank_by_group[id(p.ds_process_group)] = src_rank + self.src_rank = self.src_rank_by_group[id(self.params[0].ds_process_group)] self.fwd_module = fwd_module if self.fwd_module is not None: # is a no-op if already registered @@ -2370,7 +2405,12 @@ def __exit__(self, *exc): f"When using GatheredParameters with modifier_rank set, parameters must remain on " f"the accelerator device. If you don't need to broadcast updates, use modifier_rank=None.") - handles = [dist.broadcast(p.data, self.src_rank, group=p.ds_process_group, async_op=True) for p in self.params] + handles = [ + dist.broadcast(p.data, + self.src_rank_by_group[id(p.ds_process_group)], + group=p.ds_process_group, + async_op=True) for p in self.params + ] for h in handles: h.wait() self.params[0].partition(param_list=self.params, has_been_updated=True) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index fc81d8bc414f..9f408d0e4f70 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -117,17 +117,20 @@ class IPGBucketZ3: buffer_meta: Optional[torch.Tensor] = None params: List[torch.Tensor] = field(default_factory=list) elements: int = 0 + process_group: Optional[object] = None def clear(self): self.buffer = None self.buffer_meta = None self.params.clear() self.elements = 0 + self.process_group = None def clear_params(self): """Clear params and elements but keep buffer for reuse.""" self.params.clear() self.elements = 0 + self.process_group = None INITIAL_MICRO_STEP_ID = -1 @@ -411,6 +414,8 @@ def _enforce_optimizer_offload(): self.sub_group_size = sub_group_size self.sub_group_to_group_id = {} + self.sub_group_process_groups = [] + self.sub_group_partition_counts = [] # Trainable parameters self.trainable_param_groups = self._get_trainable_parameter_groups() @@ -577,21 +582,75 @@ def initialize_ds_offload( zero_module_granularity_threshold=zero_module_granularity_threshold, log_trace_cache_warnings=log_trace_cache_warnings) + def _get_param_partition_group(self, param): + return getattr(param, "ds_process_group", self.dp_process_group) + + def _get_param_partition_rank(self, param): + return dist.get_rank(group=self._get_param_partition_group(param)) + + def _get_param_partition_world_size(self, param): + return dist.get_world_size(group=self._get_param_partition_group(param)) + + def _get_sub_group_process_group(self, sub_group_id): + return self.sub_group_process_groups[sub_group_id] + + def _get_sub_group_partition_count(self, sub_group_id): + return self.sub_group_partition_counts[sub_group_id] + + def _get_sub_group_partition_rank(self, sub_group_id): + return dist.get_rank(group=self._get_sub_group_process_group(sub_group_id)) + + def _get_sub_group_partition_ranks(self, sub_group_id): + process_group = self._get_sub_group_process_group(sub_group_id) + world_size = dist.get_world_size(group=process_group) + if process_group == dist.get_world_group(): + return list(range(world_size)) + return [dist.get_global_rank(process_group, rank) for rank in range(world_size)] + + def _assert_same_partition_group(self, params): + if not params: + return + process_group = self._get_param_partition_group(params[0]) + if any(self._get_param_partition_group(param) != process_group for param in params): + raise RuntimeError("ZeRO Stage 3 parameter bucket mixes partition process groups.") + + def _autoep_expert_parallel_group(self, params): + families = set(getattr(param, "ds_zero_placement_family", "replicated") for param in params) + if families != {"autoep_expert"}: + return None + + group_names = set(getattr(param, "ds_zero_partition_group_name", None) for param in params) + if len(group_names) != 1: + raise RuntimeError("AutoEP expert ZeRO Stage 3 subgroup mixes expert-parallel group names.") + + group_name = next(iter(group_names)) + if group_name is None: + raise RuntimeError("AutoEP expert ZeRO Stage 3 subgroup is missing an expert-parallel group name.") + return groups._get_expert_parallel_group(group_name) + def _get_trainable_parameter_groups(self): param_groups = [] PARAMS_KEY = "params" - for param_group in self.optimizer.param_groups: - trainable_params = [p for p in param_group[PARAMS_KEY] if p.requires_grad] - if len(trainable_params) == 0: + for source_param_group_id, param_group in enumerate(self.optimizer.param_groups): + trainable_params_by_group = collections.OrderedDict() + for param in param_group[PARAMS_KEY]: + if not param.requires_grad: + continue + process_group = getattr(param, "ds_process_group", self.dp_process_group) + trainable_params_by_group.setdefault(id(process_group), (process_group, []))[1].append(param) + + if len(trainable_params_by_group) == 0: continue - trainable_param_group = {} - for key in param_group.keys(): - if key == PARAMS_KEY: - trainable_param_group[PARAMS_KEY] = trainable_params - else: - trainable_param_group[key] = param_group[key] - param_groups.append(trainable_param_group) + for _, trainable_params in trainable_params_by_group.values(): + trainable_param_group = {} + for key in param_group.keys(): + if key == PARAMS_KEY: + trainable_param_group[PARAMS_KEY] = trainable_params + else: + trainable_param_group[key] = param_group[key] + trainable_param_group["_ds_source_param_group_id"] = source_param_group_id + param_groups.append(trainable_param_group) return param_groups @@ -779,17 +838,23 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): for param_group_idx, param_group in enumerate(param_groups): for sub_group in param_group: sub_group_idx = len(self.fp16_groups) + self._assert_same_partition_group(sub_group) + sub_group_process_group = self._get_param_partition_group(sub_group[0]) + sub_group_partition_count = dist.get_world_size(group=sub_group_process_group) # record sub group and partitions self.fp16_groups.append(sub_group) self.fp16_partitioned_groups.append([param.ds_tensor for param in sub_group]) + self.sub_group_process_groups.append(sub_group_process_group) + self.sub_group_partition_counts.append(sub_group_partition_count) if self.zenflow: for param in sub_group: param.group_id = param_group_idx # record sub group -> group mapping - self.sub_group_to_group_id[sub_group_idx] = param_group_idx + self.sub_group_to_group_id[sub_group_idx] = fp16_param_groups[param_group_idx].get( + "_ds_source_param_group_id", param_group_idx) # record total elements of parameter partitions in sub group self.fp16_partitioned_groups_flat_numel.append(sum(p.partition_numel() for p in sub_group)) @@ -798,8 +863,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): self.fp16_partitioned_groups_flat_id.append([p.ds_id for p in sub_group]) # record padding required to align group to world size (only applies to last rank) - rank_requires_padding = dist.get_rank( - self.dp_process_group) == dist.get_world_size(self.dp_process_group) - 1 + rank_requires_padding = dist.get_rank(sub_group_process_group) == sub_group_partition_count - 1 self.groups_padding.append([p.padding_size() if rank_requires_padding else 0 for p in sub_group]) # move parameters to flattened buffer @@ -1220,10 +1284,8 @@ def get_first_param_index(self, group_id, param_group, partition_id): return None def initialize_gradient_partitioning_data_structures(self): - - total_partitions = dist.get_world_size(group=self.dp_process_group) - for i, param_group in enumerate(self.fp16_groups): + total_partitions = self._get_sub_group_partition_count(i) self.param_to_partition_ids[i] = {} self.is_partition_reduced[i] = {} @@ -1368,6 +1430,10 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param): # empty, while reduction_list will have that garbage data. comm_dtype = self.get_param_comm_dtype(param) bucket = self.ipg_buckets[comm_dtype] + param_process_group = self._get_param_partition_group(param) + if bucket.params and bucket.process_group != param_process_group: + self.report_ipg_memory_usage("In ipg_remove_grads before cross-group reduce_ipg_grads", param.ds_numel) + self.__reduce_and_partition_ipg_grads(comm_dtype) if bucket.elements + param.ds_numel > self.reduce_bucket_size and bucket.elements > 0: self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.ds_numel) self.__reduce_and_partition_ipg_grads(comm_dtype) @@ -1400,6 +1466,7 @@ def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: param.grad.data = new_grad_tensor bucket.params.append(param) + bucket.process_group = self._get_param_partition_group(param) bucket.elements += param.grad.numel() @instrument_w_nvtx @@ -1411,6 +1478,8 @@ def __reduce_and_partition_ipg_grads(self, communication_data_type: torch.dtype) if not params_in_bucket: return + self._assert_same_partition_group(params_in_bucket) + for param in params_in_bucket: if param.grad.numel() != param.ds_numel: raise RuntimeError(f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter " @@ -1444,6 +1513,7 @@ def __reduce_and_partition_ipg_grads(self, communication_data_type: torch.dtype) params_in_bucket.clear() bucket.elements = 0 + bucket.process_group = None if not get_accelerator().handles_memory_backpressure(): event = get_accelerator().Event() @@ -1508,8 +1578,9 @@ def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, b gathered_params_momentums = self._partitioned_buffers_all_gather(params, momentum_buffer, communication_data_type) - world_sz = dist.get_world_size(self.dp_process_group) - rank = dist.get_rank(self.dp_process_group) + process_group = self._get_sub_group_process_group(i) + world_sz = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) grads_pad = [param.grad for param in params] + [torch.empty_like(params[-1].grad)] * ( (world_sz - len(params) % world_sz) % world_sz) gathered_momentums_pad = gathered_params_momentums + [torch.empty_like(gathered_params_momentums[-1])] * ( @@ -1525,10 +1596,12 @@ def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, b g.data.copy_(update, non_blocking=False) grad_handle = dist.all_gather(grads_pad[base_i:base_i + world_sz], grads_pad[base_i + rank], + group=process_group, async_op=True) grad_handles.append(grad_handle) momentum_handle = dist.all_gather(gathered_momentums_pad[base_i:base_i + world_sz], gathered_momentums_pad[base_i + rank], + group=process_group, async_op=True) momentum_handles.append(momentum_handle) @@ -1571,17 +1644,21 @@ def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, b @instrument_w_nvtx def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor, communication_data_type: torch.dtype) -> List[Tensor]: + bucket = self.ipg_buckets[communication_data_type] + params_in_bucket = bucket.params + self._assert_same_partition_group(params_in_bucket) + process_group = bucket.process_group dtype = buffer_to_reduce.dtype if communication_data_type != dtype: buffer_to_reduce = buffer_to_reduce.to(communication_data_type) if self.postscale_gradients and self.gradient_predivide_factor != 1.0: buffer_to_reduce = buffer_to_reduce.div_(self.gradient_predivide_factor) - world_sz = dist.get_world_size(self.dp_process_group) - rank = dist.get_rank(self.dp_process_group) + world_sz = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) buffer_to_reduce.div_(world_sz / float(self.sequence_parallel_size)) - dist.all_reduce(buffer_to_reduce, group=self.dp_process_group) + dist.all_reduce(buffer_to_reduce, group=process_group) if self.postscale_gradients and self.gradient_predivide_factor != world_sz: buffer_to_reduce = buffer_to_reduce.mul(self.gradient_predivide_factor) @@ -1592,7 +1669,7 @@ def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor, grad_partitions = [] grad_offset_in_buffer = 0 self._apply_distributed_muon_update(communication_data_type, buffer_to_reduce) - for param in self.ipg_buckets[communication_data_type].params: + for param in params_in_bucket: grad = param.grad chunk_sz = math.ceil(grad.numel() / world_sz) @@ -1617,6 +1694,8 @@ def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor, def __avg_scatter_grads(self, params_to_reduce: List[Parameter], communication_data_type: torch.dtype) -> List[Tensor]: """average gradients and scatter partitions across ranks""" + self._assert_same_partition_group(params_to_reduce) + process_group = self._get_param_partition_group(params_to_reduce[0]) full_grads_for_rank = [p.grad for p in params_to_reduce] if communication_data_type != self.dtype: @@ -1634,10 +1713,10 @@ def __avg_scatter_grads(self, params_to_reduce: List[Parameter], if self.zeropp_loco_param is not None else all_to_all_quant_reduce( full_grads_for_rank, self.all2all_process_group)) else: - grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, self.dp_process_group) + grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, process_group) if self.postscale_gradients and self.gradient_predivide_factor != 1.0 and self.gradient_predivide_factor != dist.get_world_size( - self.dp_process_group): + process_group): grad_partitions_for_rank = [g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank] if communication_data_type != self.gradient_accumulation_dtype: @@ -1681,6 +1760,8 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tenso param.grad = None def complete_grad_norm_calculation_for_cpu_offload(self, params): + self._assert_same_partition_group(params) + process_group = self._get_param_partition_group(params[0]) total_norm = 0.0 norm_type = 2.0 for p in params: @@ -1693,9 +1774,12 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): # Sum across all model parallel GPUs. total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group) + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=process_group) self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) + autoep_ep_group = self._autoep_expert_parallel_group(params) + if autoep_ep_group is not None: + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=autoep_ep_group) total_norm = total_norm_cuda[0]**(1. / norm_type) @@ -1710,7 +1794,7 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L buffers = [] for param, grad_partition in zip(params_to_release, grad_partitions): - contains_real_data = param.partition_numel() * dist.get_rank(self.dp_process_group) < param.ds_numel + contains_real_data = param.partition_numel() * self._get_param_partition_rank(param) < param.ds_numel if not contains_real_data: # this grad partition is empty - don't need to do anything param.grad = None @@ -1785,25 +1869,28 @@ def _partitioned_buffers_all_gather(self, params: List[Parameter], buffers_to_al for param, buffer in zip(params, buffers_to_allgather)), \ "params and buffers_to_allgather must have the same numel" + self._assert_same_partition_group(params) + process_group = self._get_param_partition_group(params[0]) + partition_count = dist.get_world_size(group=process_group) coalesced_buffer = instrument_w_nvtx(torch.cat)(buffers_to_allgather) buffer_numel = coalesced_buffer.numel() - reduce_buffer = torch.empty(self.partition_count * buffer_numel, + reduce_buffer = torch.empty(partition_count * buffer_numel, dtype=communication_data_type, device=params[0].device) - rearrange_buffer = torch.empty(self.partition_count * buffer_numel, + rearrange_buffer = torch.empty(partition_count * buffer_numel, dtype=communication_data_type, device=params[0].device) - my_rank = dist.get_rank(group=self.dp_process_group) + my_rank = dist.get_rank(group=process_group) partition = reduce_buffer.narrow(0, buffer_numel * my_rank, buffer_numel) partition.data.copy_(coalesced_buffer.data, non_blocking=False) - dist.all_gather_into_tensor(reduce_buffer, partition, group=self.dp_process_group) + dist.all_gather_into_tensor(reduce_buffer, partition, group=process_group) param_partition_offsets = [0] rearranged_offset = 0 for idx, param in enumerate(params): param_partition_offsets.append(param_partition_offsets[idx] + param.partition_numel()) for idx, param in enumerate(params): num_elements = param.partition_numel() - for partition_idx in range(self.partition_count): + for partition_idx in range(partition_count): sliced = reduce_buffer.narrow(0, buffer_numel * partition_idx + param_partition_offsets[idx], num_elements) rearrange_buffer.narrow(0, rearranged_offset, num_elements).copy_(sliced.data, non_blocking=False) @@ -1811,8 +1898,7 @@ def _partitioned_buffers_all_gather(self, params: List[Parameter], buffers_to_al param_full_offsets = [0] for idx, param in enumerate(params): # the offset is the sum of the numel of all the partitions of the parameter including padding - param_full_offsets.append(param_full_offsets[idx] + - buffers_to_allgather[idx].numel() * self.partition_count) + param_full_offsets.append(param_full_offsets[idx] + buffers_to_allgather[idx].numel() * partition_count) output = [] for idx, param in enumerate(params): output.append(rearrange_buffer.narrow(0, param_full_offsets[idx], param.ds_numel).view(param.ds_shape)) @@ -1921,8 +2007,10 @@ def set_none_gradients_to_zero(self, i, partition_id): ######################Reduction Related Methods############################## - def allreduce_bucket(self, bucket, rank=None, log=None): + def allreduce_bucket(self, bucket, rank=None, log=None, group=None): rank = None + if group is None: + group = self.dp_process_group tensor = self.flatten(bucket) tensor_to_allreduce = tensor @@ -1935,10 +2023,10 @@ def allreduce_bucket(self, bucket, rank=None, log=None): if communication_data_type != tensor.dtype: tensor_to_allreduce = tensor.to(communication_data_type) - tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) + tensor_to_allreduce.div_(dist.get_world_size(group=group) / float(self.sequence_parallel_size)) # "All Reducing" - dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + dist.all_reduce(tensor_to_allreduce, group=group) if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: tensor.copy_(tensor_to_allreduce) @@ -1946,24 +2034,26 @@ def allreduce_bucket(self, bucket, rank=None, log=None): return tensor # if rank is specified do a reduction instead of an allreduce - def allreduce_and_copy(self, small_bucket, rank=None, log=None): + def allreduce_and_copy(self, small_bucket, rank=None, log=None, group=None): + if group is None: + group = self.dp_process_group with get_accelerator().stream(self.reduction_stream): - allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) - if rank is None or rank == dist.get_rank(group=self.dp_process_group): + allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log, group=group) + if rank is None or rank == dist.get_rank(group=group): for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): buf.copy_(synced) - def allreduce_no_retain(self, bucket, numel_per_bucket=500000000, rank=None, log=None): + def allreduce_no_retain(self, bucket, numel_per_bucket=500000000, rank=None, log=None, group=None): small_bucket = [] numel = 0 for tensor in bucket: small_bucket.append(tensor) numel = numel + tensor.numel() if numel > numel_per_bucket: - self.allreduce_and_copy(small_bucket, rank=rank, log=None) + self.allreduce_and_copy(small_bucket, rank=rank, log=None, group=group) small_bucket = [] if len(small_bucket) > 0: - self.allreduce_and_copy(small_bucket, rank=rank, log=log) + self.allreduce_and_copy(small_bucket, rank=rank, log=log, group=group) ############################################################################# ############################################################################# @@ -1971,11 +2061,12 @@ def allreduce_no_retain(self, bucket, numel_per_bucket=500000000, rank=None, log # views the tensor as multiple partitions and returns # those partitions - def get_data_parallel_partitions(self, tensor): + def get_data_parallel_partitions(self, tensor, group=None): + if group is None: + group = self.dp_process_group partitions = [] - dp = dist.get_world_size(group=self.dp_process_group) - dp_id = dist.get_rank(group=self.dp_process_group) + dp = dist.get_world_size(group=group) total_num_elements = tensor.numel() @@ -2091,14 +2182,19 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): Returns: Total norm of the parameters (viewed as a single vector). """ + self._assert_same_partition_group(params) + process_group = self._get_param_partition_group(params[0]) norm_type = float(norm_type) if norm_type == inf: total_norm = max(g.data.abs().max() for g in gradients) total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group) + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=process_group) # Take max across all GPUs. self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX) + autoep_ep_group = self._autoep_expert_parallel_group(params) + if autoep_ep_group is not None: + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=autoep_ep_group) total_norm = total_norm_cuda[0] else: # if dist.get_rank() == 0: @@ -2116,9 +2212,12 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): else: total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2)) - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group) + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=process_group) self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) + autoep_ep_group = self._autoep_expert_parallel_group(params) + if autoep_ep_group is not None: + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=autoep_ep_group) total_norm = total_norm_cuda**(1. / norm_type) @@ -2216,7 +2315,7 @@ def _get_norm_groups(self): @instrument_w_nvtx def _prepare_fp32_grad_for_sub_group(self, sub_group_id): - partition_id = dist.get_rank(group=self.dp_process_group) + partition_id = dist.get_rank(group=self._get_sub_group_process_group(sub_group_id)) single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to( self.fp32_partitioned_groups_flat[sub_group_id].dtype) @@ -2638,13 +2737,15 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: return grad_dict def _fp32_state_allgather(self, param, fp32_state_partition): - reduce_buffer = torch.empty(self.partition_count * fp32_state_partition.numel(), + process_group = self._get_param_partition_group(param) + partition_count = dist.get_world_size(group=process_group) + reduce_buffer = torch.empty(partition_count * fp32_state_partition.numel(), dtype=self.master_weights_and_grads_dtype, device=param.device) - my_rank = dist.get_rank(group=self.dp_process_group) + my_rank = dist.get_rank(group=process_group) partition = reduce_buffer.narrow(0, fp32_state_partition.numel() * my_rank, fp32_state_partition.numel()) partition.data.copy_(fp32_state_partition.data, non_blocking=False) - dist.all_gather_into_tensor(reduce_buffer, partition, group=self.dp_process_group) + dist.all_gather_into_tensor(reduce_buffer, partition, group=process_group) return reduce_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape) def _get_fp32_grad_state_partition(self, param, release_swap_buffers): @@ -2688,7 +2789,7 @@ def set_fp32_grad_for_param(self, value, param): fp32_grad, group_idx = self._get_fp32_grad_state_partition(param=param, release_swap_buffers=False) # import pdb; pdb.set_trace() - my_rank = dist.get_rank(group=self.dp_process_group) + my_rank = self._get_param_partition_rank(param) value_partition = value.flatten().narrow(0, fp32_grad.numel() * my_rank, fp32_grad.numel()) fp32_grad.data.copy_(value_partition.data) @@ -2740,7 +2841,7 @@ def set_full_hp_param(self, value, param, optim_state_key=None): optim_state_key=optim_state_key) # print(f'{dist.get_rank()=} {fp32_opt_state_partition.shape=} -------- {value.shape=}') # import pdb; pdb.set_trace() - my_rank = dist.get_rank(group=self.dp_process_group) + my_rank = self._get_param_partition_rank(param) value_partition = value.flatten().narrow(0, fp32_opt_state_partition.numel() * my_rank, fp32_opt_state_partition.numel()) @@ -2843,7 +2944,7 @@ def update_fp32_grad_for_param_vectorized(self, update_func, param_list): fp32_grad_full = self._fp32_state_allgather(param, fp32_grad_part) new_fp32_grad_full = update_func(fp32_grad_full, param) - my_rank = dist.get_rank(group=self.dp_process_group) + my_rank = self._get_param_partition_rank(param) value_partition = new_fp32_grad_full.flatten().narrow(0, fp32_grad_part.numel() * my_rank, fp32_grad_part.numel()) @@ -2958,6 +3059,24 @@ def _clear_fp32_optimizer_param_groups(self): for param_group in self.optimizer.param_groups: param_group['params'] = [] + def _zero3_partition_group_metadata(self): + metadata = [] + has_autoep_expert_group = False + for i, group in enumerate(self.fp16_groups): + families = sorted(set(getattr(param, "ds_zero_placement_family", "replicated") for param in group)) + group_names = sorted( + set(getattr(param, "ds_zero_partition_group_name", None) or "replicated" for param in group)) + has_autoep_expert_group = has_autoep_expert_group or "autoep_expert" in families + metadata.append({ + "sub_group": i, + "partition_count": self._get_sub_group_partition_count(i), + "partition_rank": self._get_sub_group_partition_rank(i), + "partition_ranks": self._get_sub_group_partition_ranks(i), + "families": families, + "group_names": group_names, + }) + return metadata if has_autoep_expert_group else None + def _rigid_state_dict(self): state_dict = {} state_dict[ZERO_STAGE] = ZeroStageEnum.weights @@ -2965,6 +3084,9 @@ def _rigid_state_dict(self): state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale state_dict['overflow'] = self.overflow state_dict[PARTITION_COUNT] = self.partition_count + zero3_partition_group_metadata = self._zero3_partition_group_metadata() + if zero3_partition_group_metadata is not None: + state_dict['ds_zero_partition_groups'] = zero3_partition_group_metadata self._set_fp32_optimizer_param_groups() state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict() @@ -3000,7 +3122,8 @@ def _restore_from_fp32_weights(self, all_state_dict): flat_local_partition = [] for i in range(len(self.fp32_partitioned_groups_flat)): merged_partitions = [sd['fp32_groups'][i] for sd in all_state_dict] - flat_local_partition.append(self._get_flattened_partition(merged_partitions)) + flat_local_partition.append( + self._get_flattened_partition(merged_partitions, group=self._get_sub_group_process_group(i))) for current, saved in zip(self.fp32_partitioned_groups_flat, flat_local_partition): current.data.copy_(saved.data) @@ -3016,9 +3139,11 @@ def refresh_fp32_params(self): self._restore_from_bit16_weights() # Extract flattened partition for current rank from all partitions - def _get_flattened_partition(self, all_partition_states): - partition_id = dist.get_rank(group=self.dp_process_group) - alignment = dist.get_world_size(group=self.dp_process_group) + def _get_flattened_partition(self, all_partition_states, group=None): + if group is None: + group = self.dp_process_group + partition_id = dist.get_rank(group=group) + alignment = dist.get_world_size(group=group) param_partitions = [[] for _ in range(len(all_partition_states[0]))] for i, partition in enumerate(all_partition_states): @@ -3028,7 +3153,7 @@ def _get_flattened_partition(self, all_partition_states): local_state_partitions = [] for param_index, param_slices in enumerate(param_partitions): flattened_merged_tensor = self.flatten_dense_tensors_aligned(param_slices, alignment) - new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor) + new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor, group=group) local_state_partitions.append(new_partitions[partition_id]) if torch.is_tensor(local_state_partitions[0]): @@ -3048,7 +3173,8 @@ def _restore_base_optimizer_state(self, all_state_dict): all_partition_group_states = [sd['base_optimizer_state'][i] for sd in all_state_dict] for key in all_partition_group_states[0].keys(): all_partition_states = [all_states[key] for all_states in all_partition_group_states] - partition_states[key] = self._get_flattened_partition(all_partition_states) + partition_states[key] = self._get_flattened_partition(all_partition_states, + group=self._get_sub_group_process_group(i)) base_optimizer_group_states.append(partition_states) for i, group in enumerate(self.optimizer.param_groups): @@ -3064,6 +3190,10 @@ def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): self.loss_scaler = state_dict[LOSS_SCALER] self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] self.overflow = state_dict['overflow'] + saved_partition_groups = state_dict.get('ds_zero_partition_groups') + if saved_partition_groups is not None and saved_partition_groups != self._zero3_partition_group_metadata(): + raise RuntimeError("AutoEP ZeRO-3 checkpoints can only be loaded with the same parameter placement " + "topology in this release.") if load_optimizer_states: self._set_fp32_optimizer_param_groups() diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index 7f96f718af1f..28780f2ba512 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -861,7 +861,7 @@ When a HuggingFace model provides a built-in `tp_plan` (via `model.config.base_m | Unused parameters in modules may be unexpected in static networks, but could be normal in dynamic networks. This controls whether or not training should terminate with an error message when unused parameters are detected. This is set to `True` by default, which means unused parameters are ignored and training continues. Now is just used in stage 2. | `True` | ### Expert Parallel (AutoEP) -Configure AutoEP expert parallelism for MoE models. AutoEP automatically detects MoE layers in HuggingFace models and replaces them with EP-enabled versions using TorchTitan's grouped GEMM kernels. Requires zero model code changes. Supports ZeRO stages 0, 1, and 2 (stage 3 is not supported). +Configure AutoEP expert parallelism for MoE models. AutoEP automatically detects MoE layers in HuggingFace models and replaces them with EP-enabled versions using TorchTitan's grouped GEMM kernels. Requires zero model code changes. Supports ZeRO stages 0, 1, 2, and constrained ZeRO Stage 3. ```json "expert_parallel": { "enabled": true, @@ -887,6 +887,12 @@ Configure AutoEP expert parallelism for MoE models. AutoEP automatically detects | -------------------------------------------------------------------------------------------------- | ------- | | Expert-parallel degree (number of ranks sharing expert computation). Must divide `world_size / pp_size`. `1` = all experts local (no AllToAll), useful for testing. | `1` | +***expert_tensor_parallel_size***: [integer] + +| Description | Default | +| -------------------------------------------------------------------------------------------------- | ------- | +| Reserved for expert tensor parallelism. AutoEP currently accepts only `1`; non-1 values are rejected. | `1` | + ***preset_model***: [string] | Description | Default | @@ -1070,7 +1076,8 @@ Use a built-in preset but override specific naming/weight fields for a fine-tune **Constraints:** - `autoep_size` must divide `num_experts` for all detected MoE layers - AutoEP currently cannot be combined with AutoTP (`tensor_parallel.autotp_size > 1`); support is planned as follow-up work -- ZeRO Stage 3 is not supported with AutoEP (assertion will fire) +- AutoEP with ZeRO Stage 3 is supported only without AutoTP, sequence parallelism, MiCS, hpZeRO secondary tensor groups, non-1 `expert_tensor_parallel_size`, or quantized gradients +- ZeRO Stage 3 supports same-topology AutoEP checkpoint save/load when optimizer state is loaded; module-only loads, optimizer-state-skipping loads, universal checkpoint conversion, and topology-changing loads are not supported ### Logging diff --git a/docs/_tutorials/universal-checkpointing.md b/docs/_tutorials/universal-checkpointing.md index 013a14431c45..e744fc655085 100644 --- a/docs/_tutorials/universal-checkpointing.md +++ b/docs/_tutorials/universal-checkpointing.md @@ -87,16 +87,18 @@ weights have an additional layout. With AutoEP enabled, DeepSpeed writes the routed expert weights (`w1`, `w2`, and `w3`) into per-expert files named like `layer__expert__mp_rank__model_states.pt`. The regular model checkpoint records AutoEP metadata in `ds_autoep_layers`; older -checkpoints may use the legacy `autoep_layers` key. Router, gate, shared-expert, -and other non-routed-expert parameters stay in the regular -`mp_rank_*_model_states.pt` files and use the standard Universal Checkpointing -path. +checkpoints may use the legacy `autoep_layers` key. For ZeRO Stage 1 and ZeRO +Stage 2 checkpoints, router, gate, shared-expert, and other +non-routed-expert parameters stay in the regular `mp_rank_*_model_states.pt` +files and use the standard Universal Checkpointing path. Use ZeRO Stage 1 or ZeRO Stage 2 for the current AutoEP Universal Checkpoint conversion path. ZeRO Stage 3 AutoEP Universal Checkpoint conversion is not -supported; when AutoEP metadata is present, the converter raises -`NotImplementedError` with the message that AutoEP currently requires ZeRO Stage -1 or 2. +supported; when AutoEP metadata is present in `zero_pp_rank_*_model_states.pt`, +the converter raises `NotImplementedError` and instructs users to use regular +same-topology ZeRO-3 checkpoint load for AutoEP checkpoints. AutoEP ZeRO Stage 3 +checkpoints do not support Universal conversion, topology-changing loads, +module-only loads, or optimizer-state-free loads. During conversion, `ds_to_universal.py` reads `ds_autoep_layers` or the legacy `autoep_layers` key, consolidates each AutoEP layer's routed expert files, and @@ -108,8 +110,9 @@ such as `exp_avg.pt` and `exp_avg_sq.pt` next to the converted parameter. Regular AutoEP checkpoint load requires the target run to use the same `autoep_size` as the save run. To change `autoep_size` for the same -AutoEP-detected model topology, convert the saved checkpoint to Universal format -and load the Universal checkpoint. +AutoEP-detected model topology, convert a ZeRO Stage 1 or ZeRO Stage 2 checkpoint +to Universal format and load the Universal checkpoint. For ZeRO Stage 3 AutoEP +checkpoints, use regular same-topology checkpoint load instead. In the Universal Checkpoint load path, AutoEP routed experts are restored from the `zero/` parameter layout rather than from the regular diff --git a/tests/unit/v1/moe/test_autoep_integration.py b/tests/unit/v1/moe/test_autoep_integration.py index f8be9304ff7b..5775f5e218da 100644 --- a/tests/unit/v1/moe/test_autoep_integration.py +++ b/tests/unit/v1/moe/test_autoep_integration.py @@ -6,7 +6,10 @@ import pytest import torch +import torch.nn as nn import deepspeed +from deepspeed import comm as dist +from deepspeed.moe.layer import MoE from unit.v1.moe.autoep_test_utils import ( MockMoETransformer, make_autoep_integration_config as _make_autoep_config, @@ -15,6 +18,16 @@ ) from unit.common import DistributedTest + +def _assert_global_grad_norm_consistent(engine): + norm_groups = engine.optimizer._get_norm_groups() + local_norm = torch.linalg.vector_norm(torch.stack(norm_groups)).detach().reshape(1) + gathered = [torch.zeros_like(local_norm) for _ in range(dist.get_world_size())] + dist.all_gather(gathered, local_norm) + for norm in gathered[1:]: + assert torch.allclose(norm, gathered[0], rtol=1e-4, atol=1e-4), [float(item.item()) for item in gathered] + + # --------------------------------------------------------------------------- # Test class: AutoEP integration (world_size=2) # --------------------------------------------------------------------------- @@ -57,16 +70,121 @@ def test_zero2_ep_2gpu(self): params_changed += 1 assert params_changed > 0, "No parameters changed after 3 training steps with ZeRO-2" - def test_zero3_ep_rejected_2gpu(self): - """EP with ZeRO-3 should trigger an assertion error. - - ZeRO Stage 3 is incompatible with MoE. The engine should raise - an AssertionError with the message 'MoE not supported with Stage 3'. - """ + def test_zero3_ep_train_step_and_placement_2gpu(self): + """EP with ZeRO-3 trains when AutoEP owns the MoE layers.""" _seed_everything(1234) model = MockMoETransformer() config = _make_autoep_config(zero_stage=3, ep_size=2) + engine, _, _, _ = deepspeed.initialize(model=model, config=config) + + from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer + autoep_layers = [m for _, m in engine.module.named_modules() if isinstance(m, AutoEPMoELayer)] + assert len(autoep_layers) == 2 + + for layer in autoep_layers: + for param in layer.experts.parameters(): + assert param.ds_zero_placement_family == "autoep_expert" + assert param.ds_zero_partition_group_name == layer.ep_group_name + assert param.ds_zero_partition_world_size == 1 + for param in layer.router.parameters(): + assert param.ds_zero_placement_family == "replicated" + assert param.ds_zero_partition_world_size == 2 + + losses, _ = _run_training_steps(engine, num_steps=1) + assert torch.isfinite(torch.tensor(losses[0])) + + def test_zero3_native_moe_rejected_2gpu(self): + + class NativeMoEModel(nn.Module): + + def __init__(self): + super().__init__() + self.moe = MoE(hidden_size=64, expert=nn.Linear(64, 64), num_experts=2, ep_size=2) + + def forward(self, x): + output, _, _ = self.moe(x) + return output + + config = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + }, + }, + "zero_optimization": { + "stage": 3, + }, + } + + with pytest.raises(AssertionError, match="Native DeepSpeed MoE"): + deepspeed.initialize(model=NativeMoEModel(), config=config) + + def test_zero3_ep_save_load_same_topology_2gpu(self, tmpdir): + _seed_everything(5678) + + model = MockMoETransformer() + config = _make_autoep_config(zero_stage=3, ep_size=2) + engine, _, _, _ = deepspeed.initialize(model=model, config=config) + _run_training_steps(engine, num_steps=1) + + save_dir = str(tmpdir) + engine.save_checkpoint(save_dir, tag="autoep-zero3") + + reloaded = MockMoETransformer() + reloaded_engine, _, _, _ = deepspeed.initialize(model=reloaded, config=config) + _, client_state = reloaded_engine.load_checkpoint(save_dir, tag="autoep-zero3") + assert client_state is not None + + module_only = MockMoETransformer() + module_only_engine, _, _, _ = deepspeed.initialize(model=module_only, config=config) + with pytest.raises(NotImplementedError, match="load_optimizer_states=False"): + module_only_engine.load_checkpoint(save_dir, tag="autoep-zero3", load_optimizer_states=False) + + losses, _ = _run_training_steps(reloaded_engine, num_steps=1) + assert torch.isfinite(torch.tensor(losses[0])) + + +class TestAutoEPZero3ReplicaGroups(DistributedTest): + world_size = 4 + + def test_zero3_ep_expert_replica_group_train_save_load_4gpu(self, tmpdir): + _seed_everything(9012) + + model = MockMoETransformer() + config = _make_autoep_config(zero_stage=3, ep_size=2) + config["gradient_clipping"] = 1.0 + engine, _, _, _ = deepspeed.initialize(model=model, config=config) - with pytest.raises(AssertionError, match="MoE not supported with Stage 3"): - deepspeed.initialize(model=model, config=config) + from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer + autoep_layers = [m for _, m in engine.module.named_modules() if isinstance(m, AutoEPMoELayer)] + assert len(autoep_layers) == 2 + + for layer in autoep_layers: + for param in layer.experts.parameters(): + assert param.ds_zero_placement_family == "autoep_expert" + assert param.ds_zero_partition_group_name == layer.ep_group_name + assert param.ds_zero_partition_world_size == 2 + for param in layer.router.parameters(): + assert param.ds_zero_placement_family == "replicated" + assert param.ds_zero_partition_world_size == 4 + + x = torch.randn(1, 8, 64, device=engine.device) + loss = engine(x).mean() + engine.backward(loss) + _assert_global_grad_norm_consistent(engine) + engine.step() + assert torch.isfinite(engine.optimizer._global_grad_norm) + + save_dir = str(tmpdir) + engine.save_checkpoint(save_dir, tag="autoep-zero3") + + reloaded = MockMoETransformer() + reloaded_engine, _, _, _ = deepspeed.initialize(model=reloaded, config=config) + _, client_state = reloaded_engine.load_checkpoint(save_dir, tag="autoep-zero3") + assert client_state is not None + + losses, _ = _run_training_steps(reloaded_engine, num_steps=1) + assert torch.isfinite(torch.tensor(losses[0])) diff --git a/tests/unit/v1/moe/test_autoep_unit.py b/tests/unit/v1/moe/test_autoep_unit.py index 3104f8f8a9ed..ade8f55c4d09 100644 --- a/tests/unit/v1/moe/test_autoep_unit.py +++ b/tests/unit/v1/moe/test_autoep_unit.py @@ -4,6 +4,7 @@ # DeepSpeed Team """Compact critical-path tests for AutoEP.""" +import inspect from types import SimpleNamespace import pytest @@ -11,6 +12,7 @@ import torch.nn as nn import deepspeed.runtime.engine as ds_engine +import deepspeed.runtime.zero.stage3 as zero_stage3 from deepspeed.module_inject.auto_ep import AutoEP, _resolve_route_scale from deepspeed.module_inject.auto_ep_config import ( AutoEPConfig, @@ -32,10 +34,12 @@ preset_name_for_hf_model_type, unsupported_preset_for_hf_model_type, ) +from deepspeed.moe.layer import MoE from deepspeed.moe.ep_experts import GroupedExperts from deepspeed.moe.ep_kernels import TokenReorderer from deepspeed.moe.ep_router import TokenChoiceTopKRouter from deepspeed.runtime.engine import DeepSpeedEngine +from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 from deepspeed.utils import groups from unit.v1.moe.autoep_test_utils import ( MockMoEBlock, @@ -202,6 +206,18 @@ def test_ep_size_validation_rejects_invalid_topology(self): with pytest.raises(ValueError, match="exceeds num_experts"): validate_autoep_post_detection(AutoEPConfig(enabled=True, autoep_size=16), [_make_spec(num_experts=8)]) + def test_expert_tensor_parallel_size_is_parsed_but_limited_to_one(self): + config = parse_autoep_config({ + "enabled": True, + "autoep_size": 2, + "expert_tensor_parallel_size": 1, + }) + assert config.expert_tensor_parallel_size == 1 + + config.expert_tensor_parallel_size = 2 + with pytest.raises(ValueError, match="expert_tensor_parallel_size=1"): + validate_autoep_config(config, world_size=4, pp_size=1, tp_size=1, sp_size=1) + def test_configure_expert_parallel_uses_engine_mpu_sequence_parallel_size(self, monkeypatch): class SequenceParallelMPU: @@ -285,6 +301,210 @@ def test_autoep_sequence_parallel_size_falls_back_to_groups_helper(self, monkeyp assert engine._autoep_sequence_parallel_world_size() == 3 + def test_zero3_compatibility_gate_rejects_native_moe(self): + engine = object.__new__(DeepSpeedEngine) + engine.__dict__["module"] = nn.Sequential(MoE(hidden_size=4, expert=nn.Linear(4, 4), num_experts=1)) + engine.has_moe_layers = True + engine.sequence_parallel_size = 1 + engine._config = SimpleNamespace( + tensor_parallel_config=SimpleNamespace(autotp_size=1), + expert_parallel_config=AutoEPConfig(enabled=True, autoep_size=1), + ) + + with pytest.raises(AssertionError, match="Native DeepSpeed MoE"): + engine._validate_zero3_moe_compatibility() + + def test_zero3_compatibility_gate_allows_constrained_autoep(self): + model = MockMoETransformer(num_layers=1) + replace_autoep_layers(model, "mixtral") + engine = object.__new__(DeepSpeedEngine) + engine.__dict__["module"] = model + engine.has_moe_layers = True + engine.sequence_parallel_size = 1 + engine.zero_quantized_gradients = lambda: False + engine._config = SimpleNamespace( + tensor_parallel_config=SimpleNamespace(autotp_size=1), + expert_parallel_config=AutoEPConfig(enabled=True, autoep_size=1), + ) + + engine._validate_zero3_moe_compatibility() + + def test_zero3_compatibility_gate_rejects_sequence_parallel(self): + model = MockMoETransformer(num_layers=1) + replace_autoep_layers(model, "mixtral") + engine = object.__new__(DeepSpeedEngine) + engine.__dict__["module"] = model + engine.has_moe_layers = True + engine.sequence_parallel_size = 2 + engine.zero_quantized_gradients = lambda: False + engine._config = SimpleNamespace( + tensor_parallel_config=SimpleNamespace(autotp_size=1), + expert_parallel_config=AutoEPConfig(enabled=True, autoep_size=1), + ) + + with pytest.raises(AssertionError, match="sequence parallelism"): + engine._validate_zero3_moe_compatibility() + + def test_zero3_compatibility_gate_rejects_active_autotp(self): + model = MockMoETransformer(num_layers=1) + replace_autoep_layers(model, "mixtral") + engine = object.__new__(DeepSpeedEngine) + engine.__dict__["module"] = model + engine.has_moe_layers = True + engine.sequence_parallel_size = 1 + engine.zero_quantized_gradients = lambda: False + engine._config = SimpleNamespace( + tensor_parallel_config=SimpleNamespace(autotp_size=2), + expert_parallel_config=AutoEPConfig(enabled=True, autoep_size=1), + ) + + with pytest.raises(AssertionError, match="AutoTP"): + engine._validate_zero3_moe_compatibility() + + def test_zero3_compatibility_gate_rejects_quantized_gradients(self): + model = MockMoETransformer(num_layers=1) + replace_autoep_layers(model, "mixtral") + engine = object.__new__(DeepSpeedEngine) + engine.__dict__["module"] = model + engine.has_moe_layers = True + engine.sequence_parallel_size = 1 + engine.zero_quantized_gradients = lambda: True + engine._config = SimpleNamespace( + tensor_parallel_config=SimpleNamespace(autotp_size=1), + expert_parallel_config=AutoEPConfig(enabled=True, autoep_size=1), + ) + + with pytest.raises(AssertionError, match="zero_quantized_gradients"): + engine._validate_zero3_moe_compatibility() + + def test_zero3_compatibility_gate_rejects_mics(self): + model = MockMoETransformer(num_layers=1) + replace_autoep_layers(model, "mixtral") + engine = object.__new__(DeepSpeedEngine) + engine.__dict__["module"] = model + engine.has_moe_layers = True + engine.sequence_parallel_size = 1 + engine.zero_quantized_gradients = lambda: False + engine._config = SimpleNamespace( + mics_shard_size=2, + zero_config=SimpleNamespace(zero_hpz_partition_size=1), + tensor_parallel_config=SimpleNamespace(autotp_size=1), + expert_parallel_config=AutoEPConfig(enabled=True, autoep_size=1), + ) + + with pytest.raises(AssertionError, match="MiCS"): + engine._validate_zero3_moe_compatibility() + + def test_zero3_compatibility_gate_rejects_hpzero(self): + model = MockMoETransformer(num_layers=1) + replace_autoep_layers(model, "mixtral") + engine = object.__new__(DeepSpeedEngine) + engine.__dict__["module"] = model + engine.has_moe_layers = True + engine.sequence_parallel_size = 1 + engine.zero_quantized_gradients = lambda: False + engine._config = SimpleNamespace( + mics_shard_size=0, + zero_config=SimpleNamespace(zero_hpz_partition_size=2), + tensor_parallel_config=SimpleNamespace(autotp_size=1), + expert_parallel_config=AutoEPConfig(enabled=True, autoep_size=1), + ) + + with pytest.raises(AssertionError, match="hpZeRO"): + engine._validate_zero3_moe_compatibility() + + def test_autoep_layer_marks_zero3_param_placement_families(self): + model = MockMoETransformer(num_layers=1) + replace_autoep_layers(model, "mixtral") + autoep_layer = next(module for module in model.modules() if isinstance(module, AutoEPMoELayer)) + + for param in autoep_layer.experts.parameters(): + assert param.ds_zero_placement_family == "autoep_expert" + assert param.ds_zero_partition_group_name == autoep_layer.ep_group_name + + for param in autoep_layer.router.parameters(): + assert param.ds_zero_placement_family == "replicated" + + def test_zero3_checkpoint_metadata_includes_partition_group_ranks(self): + optimizer = object.__new__(DeepSpeedZeroOptimizer_Stage3) + param = nn.Parameter(torch.empty(1)) + param.ds_zero_placement_family = "autoep_expert" + param.ds_zero_partition_group_name = "ep_size_2" + optimizer.fp16_groups = [[param]] + optimizer._get_sub_group_partition_count = lambda _: 2 + optimizer._get_sub_group_partition_rank = lambda _: 1 + optimizer._get_sub_group_partition_ranks = lambda _: [1, 3] + + metadata = optimizer._zero3_partition_group_metadata() + + assert metadata == [{ + "sub_group": 0, + "partition_count": 2, + "partition_rank": 1, + "partition_ranks": [1, 3], + "families": ["autoep_expert"], + "group_names": ["ep_size_2"], + }] + + param.ds_zero_placement_family = "replicated" + param.ds_zero_partition_group_name = None + assert optimizer._zero3_partition_group_metadata() is None + + def test_zero3_cpu_offload_grad_norm_reduces_autoep_expert_parallel_group(self, monkeypatch): + optimizer = object.__new__(DeepSpeedZeroOptimizer_Stage3) + param = nn.Parameter(torch.empty(1)) + param.ds_zero_placement_family = "autoep_expert" + param.ds_zero_partition_group_name = "ep_size_2" + optimizer.model_parallel_rank = 0 + optimizer.norm_for_param_grads = {7: 3.0} + optimizer.get_param_id = lambda _: 7 + optimizer._assert_same_partition_group = lambda _: None + optimizer._get_param_partition_group = lambda _: "expert_data_parallel" + optimizer._model_parallel_all_reduce = lambda tensor, op: None + optimizer._autoep_expert_parallel_group = lambda _: "expert_parallel" + calls = [] + + def fake_all_reduce(tensor, op=None, group=None): + calls.append(group) + + class FakeAccelerator: + + def FloatTensor(self, values): + return torch.FloatTensor(values) + + monkeypatch.setattr(zero_stage3, "get_accelerator", lambda: FakeAccelerator()) + monkeypatch.setattr(zero_stage3.dist, "all_reduce", fake_all_reduce) + + norm = optimizer.complete_grad_norm_calculation_for_cpu_offload([param]) + + assert calls == ["expert_data_parallel", "expert_parallel"] + assert torch.isfinite(norm) + + def test_pipeline_load_module_state_dict_accepts_autoep_zero3_fetch_kwarg(self): + from deepspeed.runtime.pipe.engine import PipelineEngine + + signature = inspect.signature(PipelineEngine.load_module_state_dict) + + assert "z3_params_to_fetch" in signature.parameters + + def test_universal_converter_rejects_zero3_autoep_model_state(self, tmp_path): + from deepspeed.checkpoint.constants import AUTOEP_LAYERS_KEY + from deepspeed.checkpoint.ds_to_universal import ( + _get_zero3_model_state_files, + _raise_if_stage3_autoep_universal_conversion, + ) + + zero3_model_file = tmp_path / "zero_pp_rank_0_mp_rank_00_model_states.pt" + expert_file = tmp_path / "layer_0_expert_0_mp_rank_00_model_states.pt" + torch.save({AUTOEP_LAYERS_KEY: [{"moe_layer_id": 0}]}, zero3_model_file) + torch.save({"expert": torch.empty(1)}, expert_file) + + model_files = _get_zero3_model_state_files(str(tmp_path)) + + assert model_files == [str(zero3_model_file)] + with pytest.raises(NotImplementedError, match="same-topology ZeRO-3 checkpoint load"): + _raise_if_stage3_autoep_universal_conversion(model_files) + def test_preset_registry_core_contracts(self): assert set(PRESET_MODELS) == {"mixtral", "qwen3_moe", "qwen3_5_moe", "deepseek_v2", "deepseek_v3"} assert preset_name_for_hf_model_type("mixtral") == "mixtral" From 4232ea36981a5b0734d3aaaa1c97c5ad3d525fc1 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sun, 24 May 2026 13:30:06 -0700 Subject: [PATCH 02/19] Fix ZeRO-3 multi-gather wait forwarding Signed-off-by: Masahiro Tanaka --- .../runtime/zero/partition_parameters.py | 2 +- tests/unit/runtime/zero/test_zero_context.py | 38 ++++++++++++++++++- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index dd0e2745efea..59aa3f211ea5 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -789,7 +789,7 @@ def __init__(self, handles: List[Union[AllGatherHandle, AllGatherCoalescedHandle def wait(self, handle_dependency=True) -> None: for handle in self.handles: - handle.wait(handle_dependency) + handle.wait(handle_dependency=handle_dependency) class AllReduceCoalescedHandle: diff --git a/tests/unit/runtime/zero/test_zero_context.py b/tests/unit/runtime/zero/test_zero_context.py index 189502445bf3..02d7dd23fce1 100644 --- a/tests/unit/runtime/zero/test_zero_context.py +++ b/tests/unit/runtime/zero/test_zero_context.py @@ -8,9 +8,10 @@ import torch import pytest import deepspeed -from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero.partition_parameters import (MultipleAllGatherHandles, ZeroParamStatus, + partitioned_param_data_shape) from unit.common import DistributedTest, preferred_dtype, reduce_boolean_flags from unit.simple_model import SimpleModel @@ -61,6 +62,41 @@ def forward(self, x): config["fp16"] = {"enabled": True, "loss_scale": 138.} +def test_multiple_all_gather_handles_wait_passes_dependency_by_keyword(): + + class PositionalWaitHandle: + + def __init__(self): + self.handle_dependency = None + + def wait(self, handle_dependency=True): + self.handle_dependency = handle_dependency + + class KeywordOnlyWaitHandle: + + def __init__(self): + self.handle_dependency = None + + def wait(self, *, handle_dependency=True): + self.handle_dependency = handle_dependency + + class KwargsWaitHandle: + + def __init__(self): + self.kwargs = None + + def wait(self, **kwargs): + self.kwargs = kwargs + + handles = [PositionalWaitHandle(), KeywordOnlyWaitHandle(), KwargsWaitHandle()] + + MultipleAllGatherHandles(handles).wait(handle_dependency=False) + + assert handles[0].handle_dependency is False + assert handles[1].handle_dependency is False + assert handles[2].kwargs == {"handle_dependency": False} + + class TestZeroGatheredParametersFree(DistributedTest): world_size = 1 From 807085c893a03c8377e9fca24de4b0ef9df61ff3 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sun, 24 May 2026 14:19:51 -0700 Subject: [PATCH 03/19] Gather AutoEP source ZeRO params Signed-off-by: Masahiro Tanaka --- deepspeed/module_inject/auto_ep.py | 47 ++++--- deepspeed/module_inject/auto_ep_layer.py | 86 +++++++------ deepspeed/moe/ep_repack.py | 110 +++++++++------- tests/unit/v1/moe/test_autoep_integration.py | 26 ++++ tests/unit/v1/moe/test_autoep_unit.py | 129 +++++++++++++++++++ 5 files changed, 300 insertions(+), 98 deletions(-) diff --git a/deepspeed/module_inject/auto_ep.py b/deepspeed/module_inject/auto_ep.py index 3ed6ea06e528..277ae833fc04 100644 --- a/deepspeed/module_inject/auto_ep.py +++ b/deepspeed/module_inject/auto_ep.py @@ -33,6 +33,7 @@ unsupported_preset_for_hf_model_type, ) from deepspeed.moe.fused_expert_layout import classify_fused_gate_up_layout +from deepspeed.runtime.zero.utils import is_zero_param from deepspeed.utils import logger @@ -81,6 +82,16 @@ def _raise_if_duplicate_moe_specs(specs: list[MoELayerSpec]) -> None: "AutoEP patterns so each MoE module matches exactly one preset.") +def _source_param_shape(param: torch.Tensor | nn.Parameter) -> torch.Size: + if is_zero_param(param): + return torch.Size(param.ds_shape) + return torch.Size(param.shape) + + +def _source_param_ndim(param: torch.Tensor | nn.Parameter) -> int: + return len(_source_param_shape(param)) + + def _has_3d_expert_params(module: nn.Module, preset: MoEModelPreset) -> bool: """Check if module stores expert weights as 3D parameter tensors (transformers 5.0.0+). @@ -92,7 +103,7 @@ def _has_3d_expert_params(module: nn.Module, preset: MoEModelPreset) -> bool: if param is None: return False if isinstance(param, nn.Parameter) or isinstance(param, torch.Tensor): - return param.ndim == 3 + return _source_param_ndim(param) == 3 return False @@ -146,7 +157,7 @@ def _detect_expert_storage(experts_module: nn.Module, preset: MoEModelPreset) -> return "module_list" # Check children for 3D params as fallback for name, param in experts_module.named_parameters(recurse=False): - if param.ndim == 3: + if _source_param_ndim(param) == 3: return "fused_3d" return "module_list" @@ -162,13 +173,15 @@ def _infer_hidden_and_ffn_size( w1_param = getattr(experts_module, preset.expert_w1, None) w2_param = getattr(experts_module, preset.expert_w2, None) if w1_param is not None and w2_param is not None: + w1_shape = _source_param_shape(w1_param) + w2_shape = _source_param_shape(w2_param) if preset.expert_w3 is None: - layout = classify_fused_gate_up_layout(tuple(w1_param.shape), tuple(w2_param.shape)) + layout = classify_fused_gate_up_layout(tuple(w1_shape), tuple(w2_shape)) if layout is None: raise ValueError("expert_w3=None expects fused gate+up weights with either " f"[E, 2*ffn, hidden]/[E, hidden, ffn] or [E, hidden, 2*ffn]/[E, ffn, hidden], " - f"but got {preset.expert_w1}={tuple(w1_param.shape)} and " - f"{preset.expert_w2}={tuple(w2_param.shape)}.") + f"but got {preset.expert_w1}={tuple(w1_shape)} and " + f"{preset.expert_w2}={tuple(w2_shape)}.") hidden_size = layout.hidden_size ffn_hidden_size = layout.ffn_hidden_size else: @@ -177,8 +190,8 @@ def _infer_hidden_and_ffn_size( if w3_param is None: raise ValueError(f"expert_w3='{preset.expert_w3}' is set but no such weight " f"exists on experts module.") - hidden_size = w1_param.shape[2] - ffn_hidden_size = w1_param.shape[1] + hidden_size = w1_shape[2] + ffn_hidden_size = w1_shape[1] return hidden_size, ffn_hidden_size elif storage == "module_list": # Legacy: individual expert modules @@ -195,8 +208,9 @@ def _infer_hidden_and_ffn_size( if isinstance(w1, nn.Linear): return w1.in_features, w1.out_features elif isinstance(w1, (nn.Parameter, torch.Tensor)): - if w1.ndim == 2: - return w1.shape[1], w1.shape[0] + w1_shape = _source_param_shape(w1) + if len(w1_shape) == 2: + return w1_shape[1], w1_shape[0] raise ValueError(f"Could not infer hidden_size/ffn_hidden_size from experts module " f"with storage={storage}, preset.expert_w1={preset.expert_w1}") @@ -327,12 +341,13 @@ def ep_parser(self) -> list[MoELayerSpec]: # Validate/derive from router weight shape router_weight = getattr(router_child, 'weight', None) - if router_weight is not None and router_weight.ndim == 2: - num_experts_from_weight = router_weight.shape[0] - hidden_from_weight = router_weight.shape[1] + router_weight_shape = _source_param_shape(router_weight) if router_weight is not None else None + if router_weight_shape is not None and len(router_weight_shape) == 2: + num_experts_from_weight = router_weight_shape[0] + hidden_from_weight = router_weight_shape[1] if num_experts is not None and num_experts != num_experts_from_weight: raise ValueError(f"Config num_experts={num_experts} mismatches router weight " - f"shape {router_weight.shape} (expected {num_experts_from_weight}) " + f"shape {router_weight_shape} (expected {num_experts_from_weight}) " f"in layer '{module_name}'") num_experts = num_experts_from_weight @@ -359,10 +374,10 @@ def ep_parser(self) -> list[MoELayerSpec]: continue # Cross-validate hidden_size with router - if router_weight is not None and router_weight.ndim == 2: - if hidden_size != router_weight.shape[1]: + if router_weight_shape is not None and len(router_weight_shape) == 2: + if hidden_size != router_weight_shape[1]: raise ValueError(f"hidden_size={hidden_size} from expert weights mismatches " - f"router weight dim={router_weight.shape[1]} in '{module_name}'") + f"router weight dim={router_weight_shape[1]} in '{module_name}'") # Validate top_k <= num_experts if top_k > num_experts: diff --git a/deepspeed/module_inject/auto_ep_layer.py b/deepspeed/module_inject/auto_ep_layer.py index 8834eaadb4d4..25a7b0b05261 100644 --- a/deepspeed/module_inject/auto_ep_layer.py +++ b/deepspeed/module_inject/auto_ep_layer.py @@ -25,7 +25,7 @@ from deepspeed.moe.ep_count import count_tokens_per_expert from deepspeed.moe.ep_experts import GroupedExperts from deepspeed.moe.ep_kernels import TokenReorderer -from deepspeed.moe.ep_repack import repack_expert_requires_grad_flags, repack_expert_weights +from deepspeed.moe.ep_repack import _gather_source_zero_params, repack_expert_requires_grad_flags, repack_expert_weights # --------------------------------------------------------------------------- # Named tuples @@ -69,9 +69,16 @@ def resolve_combine_impl( def _copy_parameter_data(target: nn.Parameter, source: torch.Tensor) -> None: + full_shape = torch.Size(getattr(source, "ds_shape", source.shape)) with torch.no_grad(): - target.data = torch.empty_like(source.data) - target.data.copy_(source.data) + source_data = source.data + if torch.Size(source_data.shape) != full_shape: + raise RuntimeError("AutoEP source parameter must be gathered before copying: " + f"expected full shape {tuple(full_shape)}, got {tuple(source_data.shape)}") + if (torch.Size(target.data.shape) != full_shape or target.data.dtype != source_data.dtype + or target.data.device != source_data.device): + target.data = torch.empty(full_shape, dtype=source_data.dtype, device=source_data.device) + target.data.copy_(source_data) def apply_scores_before_experts_if_enabled( @@ -373,43 +380,48 @@ def __init__( # Router: copy gate weights from source source_gate = getattr(source_module, spec.router_name) + source_gate_bias = getattr(source_gate, 'bias', None) + source_ecb = getattr(source_gate, 'e_score_correction_bias', None) + unsupported_router_biases = [ + getattr(source_gate, bias_name, None) for bias_name in spec.unsupported_router_bias_names + ] if not spec.supports_expert_bias and resolved_config.load_balance_coeff is not None: raise ValueError(f"AutoEP preset '{spec.model_family}' does not support load_balance_coeff/expert_bias " "yet. Set load_balance_coeff=None.") - for bias_name in spec.unsupported_router_bias_names: - router_bias = getattr(source_gate, bias_name, None) - if router_bias is None: - continue - if torch.is_tensor(router_bias) and torch.count_nonzero(router_bias.detach()).item() == 0: - continue - raise ValueError(f"AutoEP preset '{spec.model_family}' does not support nonzero router bias " - f"'{bias_name}' yet.") - self.router = TokenChoiceTopKRouter( - dim=spec.hidden_size, - num_experts=spec.num_experts, - num_expert_groups=spec.num_expert_groups, - num_limited_groups=spec.num_limited_groups, - top_k=spec.top_k, - score_func=spec.score_func, - route_norm=route_norm, - route_scale=spec.route_scale, - gate_bias=spec.gate_bias, - group_score_func=spec.group_score_func, - ) - # Copy gate weights - _copy_parameter_data(self.router.gate.weight, source_gate.weight) - self.router.gate.weight.requires_grad_(source_gate.weight.requires_grad) - if spec.gate_bias and getattr(source_gate, 'bias', None) is not None: - _copy_parameter_data(self.router.gate.bias, source_gate.bias) - self.router.gate.bias.requires_grad_(source_gate.bias.requires_grad) - - # Copy pre-trained score correction bias (DeepSeek-V3/Moonlight noaux_tc routing) - source_ecb = getattr(source_gate, 'e_score_correction_bias', None) - if source_ecb is not None and isinstance(source_ecb, nn.Parameter): - self.router.e_score_correction_bias = nn.Parameter(source_ecb.data.clone(), - requires_grad=source_ecb.requires_grad) - logger.info('AutoEP: copied e_score_correction_bias from source gate ' - '(shape=%s)', source_ecb.shape) + with _gather_source_zero_params([source_gate.weight, source_gate_bias, source_ecb, + *unsupported_router_biases]): + for bias_name, router_bias in zip(spec.unsupported_router_bias_names, unsupported_router_biases): + if router_bias is None: + continue + if torch.is_tensor(router_bias) and torch.count_nonzero(router_bias.detach()).item() == 0: + continue + raise ValueError(f"AutoEP preset '{spec.model_family}' does not support nonzero router bias " + f"'{bias_name}' yet.") + self.router = TokenChoiceTopKRouter( + dim=spec.hidden_size, + num_experts=spec.num_experts, + num_expert_groups=spec.num_expert_groups, + num_limited_groups=spec.num_limited_groups, + top_k=spec.top_k, + score_func=spec.score_func, + route_norm=route_norm, + route_scale=spec.route_scale, + gate_bias=spec.gate_bias, + group_score_func=spec.group_score_func, + ) + # Copy gate weights + _copy_parameter_data(self.router.gate.weight, source_gate.weight) + self.router.gate.weight.requires_grad_(source_gate.weight.requires_grad) + if spec.gate_bias and source_gate_bias is not None: + _copy_parameter_data(self.router.gate.bias, source_gate_bias) + self.router.gate.bias.requires_grad_(source_gate_bias.requires_grad) + + # Copy pre-trained score correction bias (DeepSeek-V3/Moonlight noaux_tc routing) + if source_ecb is not None and isinstance(source_ecb, nn.Parameter): + self.router.e_score_correction_bias = nn.Parameter(source_ecb.data.clone(), + requires_grad=source_ecb.requires_grad) + logger.info('AutoEP: copied e_score_correction_bias from source gate ' + '(shape=%s)', source_ecb.shape) # Alias router under the name OutputRecorder expects (layer_name if provided), # but only when OutputRecorder captures from the router child and the alias is safe. diff --git a/deepspeed/moe/ep_repack.py b/deepspeed/moe/ep_repack.py index d106a162050f..0ccdb018f039 100644 --- a/deepspeed/moe/ep_repack.py +++ b/deepspeed/moe/ep_repack.py @@ -10,11 +10,30 @@ from __future__ import annotations +from contextlib import contextmanager + import torch import torch.nn as nn from deepspeed.module_inject.auto_ep_config import MoELayerSpec from deepspeed.moe.fused_expert_layout import classify_fused_gate_up_layout +from deepspeed.runtime.zero import GatheredParameters +from deepspeed.runtime.zero.utils import is_zero_param + + +@contextmanager +def _gather_source_zero_params(params): + """Gather source ZeRO params while AutoEP reads full tensor values.""" + zero_params = [param for param in params if is_zero_param(param)] + if not zero_params: + yield + return + with GatheredParameters(zero_params, modifier_rank=None, enabled=True): + yield + + +def _source_data(param: torch.Tensor | nn.Parameter) -> torch.Tensor: + return param.data if torch.is_tensor(param) else param def repack_expert_weights( @@ -85,41 +104,42 @@ def _repack_fused_3d( w1_full = getattr(experts_source, spec.expert_w1_name) w2_full = getattr(experts_source, spec.expert_w2_name) - if isinstance(w1_full, nn.Parameter): - w1_full = w1_full.data - if isinstance(w2_full, nn.Parameter): - w2_full = w2_full.data - - # Slice to local experts - w1_local = w1_full[expert_start:expert_end].clone() - w2_local = w2_full[expert_start:expert_end].clone() - - if spec.expert_w3_name is None: - layout = classify_fused_gate_up_layout(tuple(w1_local.shape), tuple(w2_local.shape)) - if layout is None: - raise ValueError("Unsupported fused expert weight layout for AutoEP repacking: " - f"{spec.expert_w1_name}={tuple(w1_local.shape)}, " - f"{spec.expert_w2_name}={tuple(w2_local.shape)}") - - ffn_hidden = layout.ffn_hidden_size - if layout.layout == "gate_up_first": - w1 = w1_local[:, :ffn_hidden, :].contiguous() # [E_local, ffn, hidden] - w3 = w1_local[:, ffn_hidden:, :].contiguous() # [E_local, ffn, hidden] - w2 = w2_local.contiguous() # [E_local, hidden, ffn] + source_params = [w1_full, w2_full] + if spec.expert_w3_name is not None: + source_params.append(getattr(experts_source, spec.expert_w3_name)) + + with _gather_source_zero_params(source_params): + w1_full_data = _source_data(w1_full) + w2_full_data = _source_data(w2_full) + + # Slice to local experts + w1_local = w1_full_data[expert_start:expert_end].clone() + w2_local = w2_full_data[expert_start:expert_end].clone() + + if spec.expert_w3_name is None: + layout = classify_fused_gate_up_layout(tuple(w1_local.shape), tuple(w2_local.shape)) + if layout is None: + raise ValueError("Unsupported fused expert weight layout for AutoEP repacking: " + f"{spec.expert_w1_name}={tuple(w1_local.shape)}, " + f"{spec.expert_w2_name}={tuple(w2_local.shape)}") + + ffn_hidden = layout.ffn_hidden_size + if layout.layout == "gate_up_first": + w1 = w1_local[:, :ffn_hidden, :].contiguous() # [E_local, ffn, hidden] + w3 = w1_local[:, ffn_hidden:, :].contiguous() # [E_local, ffn, hidden] + w2 = w2_local.contiguous() # [E_local, hidden, ffn] + else: + w1 = w1_local[:, :, :ffn_hidden].transpose(1, 2).contiguous() # [E_local, ffn, hidden] + w3 = w1_local[:, :, ffn_hidden:].transpose(1, 2).contiguous() # [E_local, ffn, hidden] + w2 = w2_local.transpose(1, 2).contiguous() # [E_local, hidden, ffn] else: - w1 = w1_local[:, :, :ffn_hidden].transpose(1, 2).contiguous() # [E_local, ffn, hidden] - w3 = w1_local[:, :, ffn_hidden:].transpose(1, 2).contiguous() # [E_local, ffn, hidden] - w2 = w2_local.transpose(1, 2).contiguous() # [E_local, hidden, ffn] - else: - # Separate w1 (gate), w3 (up) - w3_full = getattr(experts_source, spec.expert_w3_name) - if isinstance(w3_full, nn.Parameter): - w3_full = w3_full.data - w3_local = w3_full[expert_start:expert_end].clone() + # Separate w1 (gate), w3 (up) + w3_full = getattr(experts_source, spec.expert_w3_name) + w3_local = _source_data(w3_full)[expert_start:expert_end].clone() - w1 = w1_local.contiguous() # [E_local, ffn, hidden] - w2 = w2_local.contiguous() # [E_local, hidden, ffn] - w3 = w3_local.contiguous() # [E_local, ffn, hidden] + w1 = w1_local.contiguous() # [E_local, ffn, hidden] + w2 = w2_local.contiguous() # [E_local, hidden, ffn] + w3 = w3_local.contiguous() # [E_local, ffn, hidden] return w1, w2, w3 @@ -155,23 +175,23 @@ def _repack_module_list( w2_list = [] w3_list = [] - for expert_idx in range(expert_start, expert_end): + for expert_idx in range(len(experts_source)): expert = experts_source[expert_idx] # Get weight tensors - handle both nn.Linear children and direct attributes w1_param = _get_expert_weight(expert, spec.expert_w1_name) w2_param = _get_expert_weight(expert, spec.expert_w2_name) - - # nn.Linear stores weight as [out_features, in_features] - # TorchTitan expects [ffn_hidden, hidden] for w1/w3 and [hidden, ffn_hidden] for w2 - # nn.Linear.weight is already [out, in] which matches TorchTitan's [ffn, hidden] for w1 - # No transpose needed - store as-is - w1_list.append(w1_param.data.clone()) - w2_list.append(w2_param.data.clone()) - - if spec.expert_w3_name is not None: - w3_param = _get_expert_weight(expert, spec.expert_w3_name) - w3_list.append(w3_param.data.clone()) + w3_param = _get_expert_weight(expert, spec.expert_w3_name) if spec.expert_w3_name is not None else None + + with _gather_source_zero_params([w1_param, w2_param, w3_param]): + if expert_start <= expert_idx < expert_end: + # nn.Linear stores weight as [out_features, in_features]. + # TorchTitan expects [ffn_hidden, hidden] for w1/w3 and [hidden, ffn_hidden] for w2. + # nn.Linear.weight is already [out, in], which matches TorchTitan's [ffn, hidden] for w1. + w1_list.append(w1_param.data.clone()) + w2_list.append(w2_param.data.clone()) + if w3_param is not None: + w3_list.append(w3_param.data.clone()) _require_consistent_dtype_device(w1_list, spec.expert_w1_name, expert_start, expert_end) _require_consistent_dtype_device(w2_list, spec.expert_w2_name, expert_start, expert_end) diff --git a/tests/unit/v1/moe/test_autoep_integration.py b/tests/unit/v1/moe/test_autoep_integration.py index 5775f5e218da..4ed77a794f67 100644 --- a/tests/unit/v1/moe/test_autoep_integration.py +++ b/tests/unit/v1/moe/test_autoep_integration.py @@ -150,6 +150,32 @@ def test_zero3_ep_save_load_same_topology_2gpu(self, tmpdir): class TestAutoEPZero3ReplicaGroups(DistributedTest): world_size = 4 + def test_zero3_ep_source_zero_init_expert_replica_placement_4gpu(self): + _seed_everything(3456) + + config = _make_autoep_config(zero_stage=3, ep_size=2) + with deepspeed.zero.Init(config_dict_or_path=config): + model = MockMoETransformer() + assert any(hasattr(param, "ds_id") for param in model.parameters()) + + engine, _, _, _ = deepspeed.initialize(model=model, config=config) + + from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer + autoep_layers = [m for _, m in engine.module.named_modules() if isinstance(m, AutoEPMoELayer)] + assert len(autoep_layers) == 2 + + for layer in autoep_layers: + for param in layer.experts.parameters(): + assert param.ds_zero_placement_family == "autoep_expert" + assert param.ds_zero_partition_group_name == layer.ep_group_name + assert param.ds_zero_partition_world_size == 2 + for param in layer.router.parameters(): + assert param.ds_zero_placement_family == "replicated" + assert param.ds_zero_partition_world_size == 4 + + losses, _ = _run_training_steps(engine, num_steps=1) + assert torch.isfinite(torch.tensor(losses[0])) + def test_zero3_ep_expert_replica_group_train_save_load_4gpu(self, tmpdir): _seed_everything(9012) diff --git a/tests/unit/v1/moe/test_autoep_unit.py b/tests/unit/v1/moe/test_autoep_unit.py index ade8f55c4d09..5044859dd82f 100644 --- a/tests/unit/v1/moe/test_autoep_unit.py +++ b/tests/unit/v1/moe/test_autoep_unit.py @@ -13,6 +13,7 @@ import deepspeed.runtime.engine as ds_engine import deepspeed.runtime.zero.stage3 as zero_stage3 +import deepspeed.moe.ep_repack as ep_repack from deepspeed.module_inject.auto_ep import AutoEP, _resolve_route_scale from deepspeed.module_inject.auto_ep_config import ( AutoEPConfig, @@ -94,6 +95,47 @@ def _assert_same_dtype_device(actual, expected): assert actual.device == expected.device +def _mark_fake_zero_param(param, full_data, partition_data=None, ds_id=0, name="param"): + param.ds_id = ds_id + param.ds_shape = torch.Size(full_data.shape) + param._autoep_test_full_data = full_data.detach().clone() + param._autoep_test_name = name + if partition_data is None: + partition_data = torch.zeros(1, dtype=full_data.dtype, device=full_data.device) + param.data = partition_data.detach().clone() + return param + + +class FakeGatheredParameters: + calls = [] + + def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True): + self.params = list(params) + self.modifier_rank = modifier_rank + self.enabled = enabled + self._saved_data = [] + FakeGatheredParameters.calls.append({ + "names": [getattr(param, "_autoep_test_name", f"param{param.ds_id}") for param in self.params], + "modifier_rank": + modifier_rank, + "enabled": + enabled, + }) + + def __enter__(self): + if not self.enabled: + return + for param in self.params: + self._saved_data.append((param, param.data)) + param.data = param._autoep_test_full_data.detach().clone() + + def __exit__(self, *exc): + if not self.enabled: + return + for param, data in self._saved_data: + param.data = data + + class MockSharedExpert(nn.Module): def __init__(self, hidden_size=64): @@ -648,6 +690,40 @@ def test_fused_replacement_preserves_frozen_experts_and_trainable_router(self): _assert_same_dtype_device(replaced.experts.w2, source.experts.down_proj) _assert_same_dtype_device(replaced.experts.w3, source.experts.gate_up_proj) + def test_zero_init_source_gathered_for_parser_router_and_fused_repack(self, monkeypatch): + FakeGatheredParameters.calls = [] + monkeypatch.setattr(ep_repack, "GatheredParameters", FakeGatheredParameters) + + model = MockMoETransformer(num_layers=1, num_experts=4, moe_every_n=1) + source = model.model.layers[0].mlp + expected_gate = source.gate.weight.detach().clone() + expected_gate_up = source.experts.gate_up_proj.detach().clone() + expected_down = source.experts.down_proj.detach().clone() + + _mark_fake_zero_param(source.gate.weight, expected_gate, ds_id=1, name="router.weight") + _mark_fake_zero_param(source.experts.gate_up_proj, expected_gate_up, ds_id=2, name="experts.gate_up_proj") + _mark_fake_zero_param(source.experts.down_proj, expected_down, ds_id=3, name="experts.down_proj") + + auto_ep = AutoEP(model, _runtime_config(enabled=True, autoep_size=1, preset_model="mixtral")) + specs = auto_ep.ep_parser() + assert len(specs) == 1 + assert specs[0].expert_storage == "fused_3d" + assert specs[0].num_experts == 4 + assert specs[0].hidden_size == 64 + + auto_ep.replace_moe_layer(specs[0], ep_size=1, ep_rank=0) + + replaced = model.model.layers[0].mlp + torch.testing.assert_close(replaced.router.gate.weight, expected_gate) + torch.testing.assert_close(replaced.experts.w1, expected_gate_up[:, :128, :]) + torch.testing.assert_close(replaced.experts.w3, expected_gate_up[:, 128:, :]) + torch.testing.assert_close(replaced.experts.w2, expected_down) + assert [call["names"] for call in FakeGatheredParameters.calls] == [ + ["router.weight"], + ["experts.gate_up_proj", "experts.down_proj"], + ] + assert all(call["modifier_rank"] is None for call in FakeGatheredParameters.calls) + def test_module_list_replacement_preserves_frozen_experts_and_trainable_router(self, monkeypatch): monkeypatch.setattr(get_preset_adapter("deepseek_v3"), "_installed_transformers_version", lambda: "5.0.0") model = MockDeepSeekV3Transformer(num_layers=1, num_experts=4).to(dtype=torch.bfloat16) @@ -678,6 +754,52 @@ def test_module_list_replacement_preserves_frozen_experts_and_trainable_router(s _assert_same_dtype_device(replaced.experts.w2, source.experts[0].down_proj.weight) _assert_same_dtype_device(replaced.experts.w3, source.experts[0].up_proj.weight) + def test_module_list_zero_source_gathers_all_experts_in_global_order(self, monkeypatch): + FakeGatheredParameters.calls = [] + monkeypatch.setattr(ep_repack, "GatheredParameters", FakeGatheredParameters) + monkeypatch.setattr(get_preset_adapter("deepseek_v3"), "_installed_transformers_version", lambda: "5.0.0") + + model = MockDeepSeekV3Transformer(num_layers=1, num_experts=4) + source = model.model.layers[0].mlp + for expert_idx, expert in enumerate(source.experts): + for offset, (suffix, param) in enumerate(( + ("w1", expert.gate_proj.weight), + ("w2", expert.down_proj.weight), + ("w3", expert.up_proj.weight), + )): + full_data = param.detach().clone() + _mark_fake_zero_param(param, + full_data, + ds_id=10 + 3 * expert_idx + offset, + name=f"e{expert_idx}.{suffix}") + + auto_ep = AutoEP(model, _runtime_config(enabled=True, autoep_size=2)) + spec = auto_ep.ep_parser()[0] + w1, w2, w3 = repack_expert_weights(source.experts, spec, ep_rank=1, ep_size=2) + + expected_w1 = torch.stack([ + source.experts[2].gate_proj.weight._autoep_test_full_data, + source.experts[3].gate_proj.weight._autoep_test_full_data + ]) + expected_w2 = torch.stack([ + source.experts[2].down_proj.weight._autoep_test_full_data, + source.experts[3].down_proj.weight._autoep_test_full_data + ]) + expected_w3 = torch.stack([ + source.experts[2].up_proj.weight._autoep_test_full_data, + source.experts[3].up_proj.weight._autoep_test_full_data + ]) + + torch.testing.assert_close(w1, expected_w1) + torch.testing.assert_close(w2, expected_w2) + torch.testing.assert_close(w3, expected_w3) + assert [call["names"] for call in FakeGatheredParameters.calls] == [ + ["e0.w1", "e0.w2", "e0.w3"], + ["e1.w1", "e1.w2", "e1.w3"], + ["e2.w1", "e2.w2", "e2.w3"], + ["e3.w1", "e3.w2", "e3.w3"], + ] + def test_module_list_mixed_expert_requires_grad_flags_are_rejected(self, monkeypatch): monkeypatch.setattr(get_preset_adapter("deepseek_v3"), "_installed_transformers_version", lambda: "5.0.0") model = MockDeepSeekV3Transformer(num_layers=1, num_experts=4) @@ -734,6 +856,8 @@ def test_qwen_adapter_guards(self, monkeypatch): AutoEP(model, _runtime_config(enabled=True, autoep_size=1))._resolve_presets() def test_deepseek_v3_detection_and_score_correction_bias_copy(self, monkeypatch): + FakeGatheredParameters.calls = [] + monkeypatch.setattr(ep_repack, "GatheredParameters", FakeGatheredParameters) monkeypatch.setattr(get_preset_adapter("deepseek_v3"), "_installed_transformers_version", lambda: "5.0.0") model = MockDeepSeekV3Transformer(num_layers=1, num_experts=8) auto_ep = AutoEP(model, _runtime_config(enabled=True, autoep_size=2)) @@ -747,6 +871,10 @@ def test_deepseek_v3_detection_and_score_correction_bias_copy(self, monkeypatch) source_bias = torch.arange(8, dtype=torch.float32) model.model.layers[0].mlp.gate.e_score_correction_bias = nn.Parameter(source_bias.clone()) + _mark_fake_zero_param(model.model.layers[0].mlp.gate.e_score_correction_bias, + source_bias, + ds_id=100, + name="router.e_score_correction_bias") auto_ep.replace_moe_layer(specs[0], ep_size=2, ep_rank=0) @@ -754,3 +882,4 @@ def test_deepseek_v3_detection_and_score_correction_bias_copy(self, monkeypatch) assert isinstance(replaced, AutoEPMoELayer) assert replaced.router.e_score_correction_bias is not None torch.testing.assert_close(replaced.router.e_score_correction_bias, source_bias) + assert ["router.e_score_correction_bias"] in [call["names"] for call in FakeGatheredParameters.calls] From 28e725ab331a3d30a97e405516d5a4934132d111 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 11 Jun 2026 12:38:39 -0700 Subject: [PATCH 04/19] Document constrained AutoEP ZeRO-3 support Signed-off-by: Masahiro Tanaka --- docs/code-docs/source/autoep.rst | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/docs/code-docs/source/autoep.rst b/docs/code-docs/source/autoep.rst index 2d2f4a2d9cbd..991591e4348a 100644 --- a/docs/code-docs/source/autoep.rst +++ b/docs/code-docs/source/autoep.rst @@ -47,7 +47,13 @@ Transformers build that exposes the matching config/model classes, - ``load_balance_coeff`` / expert-bias auxiliary-loss-free load balancing is not currently supported; non-null values are rejected. -**ZeRO compatibility:** Stages 0, 1, and 2. Stage 3 is not supported. +**ZeRO compatibility:** Stages 0, 1, and 2, plus constrained Stage 3 +support. Stage 3 requires AutoEP-managed MoE layers and does not support native +DeepSpeed MoE layers, AutoTP, tensor model parallelism from ``mpu``, sequence +parallelism, MiCS, hpZeRO secondary tensor groups, non-1 expert tensor +parallelism, or quantized gradients. Stage 3 checkpoint load is same-topology +only with optimizer state; module-only loads, optimizer-state-free loads, +Universal Checkpoint conversion, and topology changes are not supported. **Usage:** @@ -79,15 +85,16 @@ Transformers build that exposes the matching config/model classes, - AutoEP currently cannot be combined with AutoTP (``tensor_parallel.autotp_size > 1``) or tensor model parallelism from ``mpu``; support is planned as follow-up work. -- AutoEP currently supports ZeRO stages 0, 1, and 2 only. ZeRO stage 3 and its - partitioned-parameter get/set APIs are outside the scope of the current AutoEP - support. -- Checkpoint save/load requires matching ``autoep_size``. - To change ``autoep_size`` across runs for the same AutoEP-detected model - topology, convert the checkpoint to Universal Checkpoint format and load it - with ``checkpoint.load_universal``; see the +- AutoEP with ZeRO Stage 3 is supported only without sequence parallelism, + MiCS, hpZeRO secondary tensor groups, non-1 expert tensor parallelism, or + quantized gradients. +- Checkpoint save/load requires matching ``autoep_size``. To change + ``autoep_size`` across runs for the same AutoEP-detected model topology, + convert a ZeRO Stage 1 or ZeRO Stage 2 checkpoint to Universal Checkpoint + format and load it with ``checkpoint.load_universal``; see the `Universal Checkpointing tutorial `__ - for the detailed flow and constraints. + for the detailed flow and constraints. ZeRO Stage 3 AutoEP checkpoints must + be loaded with the same topology. - DeepSeek-V2 and DeepSeek-V3 AutoEP do not support load-balance expert bias yet. The built-in DeepSeek presets disable it by default; explicit non-null values fail. From 5770e0f9685e597304280e22fe57221b7dcd9501 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 11 Jun 2026 12:39:11 -0700 Subject: [PATCH 05/19] Fix AutoEP zero.Init unit test import Signed-off-by: Masahiro Tanaka --- tests/unit/v1/moe/test_autoep_unit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/v1/moe/test_autoep_unit.py b/tests/unit/v1/moe/test_autoep_unit.py index 5044859dd82f..517c9120119e 100644 --- a/tests/unit/v1/moe/test_autoep_unit.py +++ b/tests/unit/v1/moe/test_autoep_unit.py @@ -38,6 +38,7 @@ from deepspeed.moe.layer import MoE from deepspeed.moe.ep_experts import GroupedExperts from deepspeed.moe.ep_kernels import TokenReorderer +from deepspeed.moe.ep_repack import repack_expert_weights from deepspeed.moe.ep_router import TokenChoiceTopKRouter from deepspeed.runtime.engine import DeepSpeedEngine from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 From 573b52514ec92a58ef85a6d8a1f834226d19daad Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 11 Jun 2026 12:41:08 -0700 Subject: [PATCH 06/19] Add 8 GPU AutoEP zero.Init validation Signed-off-by: Masahiro Tanaka --- tests/unit/v1/moe/test_autoep_integration.py | 30 ++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/unit/v1/moe/test_autoep_integration.py b/tests/unit/v1/moe/test_autoep_integration.py index 4ed77a794f67..69e276d91a0e 100644 --- a/tests/unit/v1/moe/test_autoep_integration.py +++ b/tests/unit/v1/moe/test_autoep_integration.py @@ -214,3 +214,33 @@ def test_zero3_ep_expert_replica_group_train_save_load_4gpu(self, tmpdir): losses, _ = _run_training_steps(reloaded_engine, num_steps=1) assert torch.isfinite(torch.tensor(losses[0])) + + +class TestAutoEPZero3ReplicaGroups8GPU(DistributedTest): + world_size = 8 + + def test_zero3_ep_source_zero_init_expert_replica_placement_8gpu(self): + _seed_everything(4567) + + config = _make_autoep_config(zero_stage=3, ep_size=4) + with deepspeed.zero.Init(config_dict_or_path=config): + model = MockMoETransformer() + assert any(hasattr(param, "ds_id") for param in model.parameters()) + + engine, _, _, _ = deepspeed.initialize(model=model, config=config) + + from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer + autoep_layers = [m for _, m in engine.module.named_modules() if isinstance(m, AutoEPMoELayer)] + assert len(autoep_layers) == 2 + + for layer in autoep_layers: + for param in layer.experts.parameters(): + assert param.ds_zero_placement_family == "autoep_expert" + assert param.ds_zero_partition_group_name == layer.ep_group_name + assert param.ds_zero_partition_world_size == 2 + for param in layer.router.parameters(): + assert param.ds_zero_placement_family == "replicated" + assert param.ds_zero_partition_world_size == 8 + + losses, _ = _run_training_steps(engine, num_steps=1) + assert torch.isfinite(torch.tensor(losses[0])) From 8c05ea76e1f7865cf6f3c81f4690d5d2c168bd40 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 11 Jun 2026 13:03:53 -0700 Subject: [PATCH 07/19] Fix AutoEP ZeRO-3 expert gradient averaging Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/zero/stage3.py | 16 ++++++- tests/unit/v1/moe/test_autoep_unit.py | 66 +++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 9f408d0e4f70..254606393eaa 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -628,6 +628,11 @@ def _autoep_expert_parallel_group(self, params): raise RuntimeError("AutoEP expert ZeRO Stage 3 subgroup is missing an expert-parallel group name.") return groups._get_expert_parallel_group(group_name) + def _gradient_averaging_world_size(self, params, partition_world_size): + if self._autoep_expert_parallel_group(params) is None: + return partition_world_size + return dist.get_world_size(group=self.dp_process_group) + def _get_trainable_parameter_groups(self): param_groups = [] PARAMS_KEY = "params" @@ -1656,11 +1661,12 @@ def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor, world_sz = dist.get_world_size(process_group) rank = dist.get_rank(process_group) - buffer_to_reduce.div_(world_sz / float(self.sequence_parallel_size)) + averaging_world_sz = self._gradient_averaging_world_size(params_in_bucket, world_sz) + buffer_to_reduce.div_(averaging_world_sz / float(self.sequence_parallel_size)) dist.all_reduce(buffer_to_reduce, group=process_group) - if self.postscale_gradients and self.gradient_predivide_factor != world_sz: + if self.postscale_gradients and self.gradient_predivide_factor != averaging_world_sz: buffer_to_reduce = buffer_to_reduce.mul(self.gradient_predivide_factor) if communication_data_type != self.gradient_accumulation_dtype: @@ -1719,6 +1725,12 @@ def __avg_scatter_grads(self, params_to_reduce: List[Parameter], process_group): grad_partitions_for_rank = [g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank] + partition_world_size = dist.get_world_size(group=process_group) + averaging_world_size = self._gradient_averaging_world_size(params_to_reduce, partition_world_size) + if averaging_world_size != partition_world_size: + scale = partition_world_size / float(averaging_world_size) + grad_partitions_for_rank = [g.mul(scale) for g in grad_partitions_for_rank] + if communication_data_type != self.gradient_accumulation_dtype: grad_partitions_for_rank = [g.to(self.gradient_accumulation_dtype) for g in grad_partitions_for_rank] diff --git a/tests/unit/v1/moe/test_autoep_unit.py b/tests/unit/v1/moe/test_autoep_unit.py index 517c9120119e..dc6f1b01d2f5 100644 --- a/tests/unit/v1/moe/test_autoep_unit.py +++ b/tests/unit/v1/moe/test_autoep_unit.py @@ -523,6 +523,72 @@ def FloatTensor(self, values): assert calls == ["expert_data_parallel", "expert_parallel"] assert torch.isfinite(norm) + def test_zero3_autoep_reduce_scatter_grads_average_by_global_dp(self, monkeypatch): + optimizer = object.__new__(DeepSpeedZeroOptimizer_Stage3) + optimizer.dp_process_group = "global_data_parallel" + optimizer.dtype = torch.float32 + optimizer.gradient_accumulation_dtype = torch.float32 + optimizer.postscale_gradients = True + optimizer.gradient_predivide_factor = 1.0 + optimizer.all2all_process_group = None + optimizer._assert_same_partition_group = lambda _: None + optimizer._get_param_partition_group = lambda _: "expert_data_parallel" + optimizer._autoep_expert_parallel_group = lambda _: "expert_parallel" + param = nn.Parameter(torch.ones(4)) + param.grad = torch.ones(4) + + class FakeAccelerator: + + def device_count(self): + return 4 + + def fake_get_world_size(group=None): + return 2 if group == "expert_data_parallel" else 4 + + def fake_reduce_scatter(grads, process_group): + assert process_group == "expert_data_parallel" + return [torch.full((2, ), 8.0)] + + monkeypatch.setattr(zero_stage3, "get_accelerator", lambda: FakeAccelerator()) + monkeypatch.setattr(zero_stage3.dist, "get_world_size", fake_get_world_size) + monkeypatch.setattr(zero_stage3, "reduce_scatter_coalesced", fake_reduce_scatter) + + grad_partitions = optimizer._DeepSpeedZeroOptimizer_Stage3__avg_scatter_grads([param], torch.float32) + + torch.testing.assert_close(grad_partitions[0], torch.full((2, ), 4.0)) + + def test_zero3_autoep_contiguous_grads_average_by_global_dp(self, monkeypatch): + optimizer = object.__new__(DeepSpeedZeroOptimizer_Stage3) + optimizer.dp_process_group = "global_data_parallel" + optimizer.ipg_buckets = {torch.float32: SimpleNamespace(params=[], process_group="expert_data_parallel")} + optimizer.postscale_gradients = True + optimizer.gradient_predivide_factor = 1.0 + optimizer.sequence_parallel_size = 1 + optimizer.gradient_accumulation_dtype = torch.float32 + optimizer._assert_same_partition_group = lambda _: None + optimizer._autoep_expert_parallel_group = lambda _: "expert_parallel" + optimizer._apply_distributed_muon_update = lambda communication_data_type, buffer: None + param = nn.Parameter(torch.empty(2)) + param.grad = torch.zeros(2) + param.partition_numel = lambda: 1 + optimizer.ipg_buckets[torch.float32].params = [param] + + def fake_get_world_size(group=None): + return 2 if group == "expert_data_parallel" else 4 + + def fake_all_reduce(tensor, group=None): + assert group == "expert_data_parallel" + tensor.mul_(2) + + monkeypatch.setattr(zero_stage3.dist, "get_world_size", fake_get_world_size) + monkeypatch.setattr(zero_stage3.dist, "get_rank", lambda group=None: 0) + monkeypatch.setattr(zero_stage3.dist, "all_reduce", fake_all_reduce) + + grad_partitions = optimizer._DeepSpeedZeroOptimizer_Stage3__avg_scatter_contiguous_grads( + torch.tensor([4.0, 8.0]), torch.float32) + + torch.testing.assert_close(grad_partitions[0], torch.tensor([2.0])) + def test_pipeline_load_module_state_dict_accepts_autoep_zero3_fetch_kwarg(self): from deepspeed.runtime.pipe.engine import PipelineEngine From d7fac4a9546102bfc86f4b32c9b9c84b049983c6 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 11 Jun 2026 15:31:50 -0700 Subject: [PATCH 08/19] Add partition-native AutoEP ZeRO-3 checkpoints Signed-off-by: Masahiro Tanaka --- deepspeed/checkpoint/constants.py | 4 + deepspeed/checkpoint/ds_to_universal.py | 386 ++++++++++++++++++- deepspeed/runtime/engine.py | 346 +++++++++++++++-- deepspeed/runtime/pipe/engine.py | 6 +- deepspeed/runtime/zero/stage3.py | 56 ++- deepspeed/utils/zero_to_fp32.py | 31 +- tests/unit/v1/moe/test_autoep_checkpoint.py | 71 ++++ tests/unit/v1/moe/test_autoep_integration.py | 36 +- tests/unit/v1/moe/test_autoep_unit.py | 120 +++++- 9 files changed, 988 insertions(+), 68 deletions(-) diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index 0f83458a713d..14c6c6a18faa 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -93,6 +93,10 @@ ######################################### AUTOEP_LAYERS_KEY = 'ds_autoep_layers' AUTOEP_LAYERS_KEY_LEGACY = 'autoep_layers' +AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY = 'checkpoint_format' +AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT = 'zero3_partitioned' +AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY = 'checkpoint_format_version' +AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION = 1 ######################################### # Universal Checkpoint EP keys diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py index 7d8c9d728bad..6ba90a14fa67 100755 --- a/deepspeed/checkpoint/ds_to_universal.py +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -45,6 +45,12 @@ PARAMETER_WITH_SUB_PARAMS, AUTOEP_LAYERS_KEY, AUTOEP_LAYERS_KEY_LEGACY, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY, + AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT, + EP_IS_EXPERT_PARAM, + EP_NUM_EXPERTS, EXPERT_PARAMETER_PATTERNS, SubparamShape, ) @@ -152,23 +158,31 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D): fragment_mapping.start, fragment_mapping.numel) -def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index): +def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index, exclude_param_names=None): + exclude_param_names = exclude_param_names or set() state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False) + optim_sd = state_dict[OPTIMIZER_STATE_DICT] + partition_groups = optim_sd.get('ds_zero_partition_groups') or [] for idx, sub_group_shape in enumerate(param_shapes): flat_state = dict( - exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][idx]["exp_avg"], - exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][idx]["exp_avg_sq"], - fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][idx], + exp_avg=optim_sd['optimizer_state_dict']['state'][idx]["exp_avg"], + exp_avg_sq=optim_sd['optimizer_state_dict']['state'][idx]["exp_avg_sq"], + fp32=optim_sd['fp32_flat_groups'][idx], ) + partition_metadata = partition_groups[idx] if idx < len(partition_groups) else {} + partition_count = partition_metadata.get('partition_count', dp_degree) + partition_rank = partition_metadata.get('partition_rank', dp_index) offset = 0 for name, shape in sub_group_shape.items(): - unpartitioned_numel = shape.numel() - partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree) - padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel)) - for state_key in flat_state.keys(): - dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset, - padding_free_numel) + unpartitioned_numel = _shape_numel(shape) + partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, partition_count) + padding_free_numel = max(0, min(partitioned_numel, + unpartitioned_numel - partition_rank * partitioned_numel)) + if name not in exclude_param_names: + for state_key in flat_state.keys(): + dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset, + padding_free_numel) offset += partitioned_numel @@ -373,8 +387,13 @@ def _extract_zero_shard_files(args, ds_checkpoint, temp_dir): _do_parallel_work(do_work, _3d_range_list, args.num_extract_workers) -def _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir): - do_work = partial(extract_zero_shards_stage3, optim_files, param_shapes, dp_degree, temp_dir) +def _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir, exclude_param_names=None): + do_work = partial(extract_zero_shards_stage3, + optim_files, + param_shapes, + dp_degree, + temp_dir, + exclude_param_names=exclude_param_names) _do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers) @@ -406,6 +425,303 @@ def _zero_partitioned_param_info(unpartitioned_numel, world_size): return partitioned_numel, padding_numel +def _shape_numel(shape): + if hasattr(shape, "numel"): + return shape.numel() + return math.prod(shape) + + +def _zero3_rank_from_file(path): + match = re.search(r'(?:bf16_)?zero_pp_rank_([0-9]+)_mp_rank_', os.path.basename(path)) + if match is None: + raise ValueError(f"Cannot parse ZeRO rank from checkpoint file name: {path}") + return int(match.group(1)) + + +def _get_autoep_metadata(model_state): + autoep_metadata = model_state.get(AUTOEP_LAYERS_KEY) + if autoep_metadata is None: + autoep_metadata = model_state.get(AUTOEP_LAYERS_KEY_LEGACY) + return autoep_metadata + + +def _uses_zero3_partitioned_autoep_metadata(autoep_metadata): + if not isinstance(autoep_metadata, list): + return False + _validate_zero3_partitioned_autoep_metadata(autoep_metadata, require_partitioned=False) + return any( + isinstance(entry, dict) + and entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) == AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT + for entry in autoep_metadata) + + +def _validate_zero3_partitioned_autoep_metadata(autoep_metadata, require_partitioned=True): + if not isinstance(autoep_metadata, list): + raise RuntimeError(f"ds_autoep_layers metadata is malformed: expected list, got " + f"{type(autoep_metadata).__name__}") + + required_fields = { + 'moe_layer_id', + 'module_path', + 'num_experts', + 'num_local_experts', + 'ep_size', + 'expert_key_prefix', + } + partitioned_fields = { + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY, + 'ep_rank', + 'expert_data_parallel_rank', + 'expert_data_parallel_world_size', + 'global_expert_start', + 'global_expert_end', + } + seen_layer_ids = set() + seen_prefixes = set() + partitioned_count = 0 + + for entry in autoep_metadata: + if not isinstance(entry, dict): + raise RuntimeError(f"ds_autoep_layers entry is malformed: expected dict, got {type(entry).__name__}") + missing = required_fields - entry.keys() + if missing: + raise RuntimeError(f"ds_autoep_layers entry is invalid: missing fields {sorted(missing)}") + layer_id = entry['moe_layer_id'] + if layer_id in seen_layer_ids: + raise RuntimeError(f"ds_autoep_layers metadata has duplicate moe_layer_id: {layer_id}") + seen_layer_ids.add(layer_id) + prefix = entry['expert_key_prefix'] + if prefix in seen_prefixes: + raise RuntimeError(f"ds_autoep_layers metadata has duplicate expert_key_prefix: {prefix}") + seen_prefixes.add(prefix) + + checkpoint_format = entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) + if checkpoint_format is None: + continue + if checkpoint_format != AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT: + continue + + missing = partitioned_fields - entry.keys() + if missing: + raise RuntimeError(f"AutoEP ZeRO-3 checkpoint metadata is invalid: missing fields {sorted(missing)}") + version = entry[AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY] + if version != AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION: + raise RuntimeError("Unsupported AutoEP ZeRO-3 checkpoint format version: " + f"{version}. This converter supports version " + f"{AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION}.") + if entry['num_local_experts'] * entry['ep_size'] != entry['num_experts']: + raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata is inconsistent: " + f"num_local_experts={entry['num_local_experts']}, ep_size={entry['ep_size']}, " + f"num_experts={entry['num_experts']}") + expected_start = entry['ep_rank'] * entry['num_local_experts'] + expected_end = expected_start + entry['num_local_experts'] + if entry['global_expert_start'] != expected_start or entry['global_expert_end'] != expected_end: + raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata has inconsistent global expert range: " + f"got [{entry['global_expert_start']}, {entry['global_expert_end']}), " + f"expected [{expected_start}, {expected_end})") + partitioned_count += 1 + + if require_partitioned and partitioned_count == 0: + raise RuntimeError("Expected AutoEP ZeRO-3 partition-native metadata but found no partitioned entries") + + +def _autoep_expert_param_info(autoep_metadata): + info = {} + if not isinstance(autoep_metadata, list): + return info + _validate_zero3_partitioned_autoep_metadata(autoep_metadata) + for entry in autoep_metadata: + if not isinstance(entry, dict): + continue + if entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) != AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT: + continue + prefix = entry.get('expert_key_prefix') + if not prefix: + continue + for wname in ('w1', 'w2', 'w3'): + info[f"{prefix}.{wname}"] = entry + return info + + +def _autoep_expert_param_names_by_rank(model_files): + expert_param_names = set() + metadata_by_rank = {} + for model_file in model_files: + rank = _zero3_rank_from_file(model_file) + model_state = torch.load(model_file, map_location=torch.device('cpu'), weights_only=False) + autoep_metadata = _get_autoep_metadata(model_state) + if autoep_metadata is not None: + metadata_by_rank[rank] = autoep_metadata + if _uses_zero3_partitioned_autoep_metadata(autoep_metadata): + expert_param_names.update(_autoep_expert_param_info(autoep_metadata)) + return expert_param_names, metadata_by_rank + + +def _rank_map_from_files(files, description): + rank_map = {} + for path in files: + rank = _zero3_rank_from_file(path) + if rank in rank_map: + raise RuntimeError(f"Duplicate ZeRO rank {rank} in {description} files: " + f"{rank_map[rank]} and {path}") + rank_map[rank] = path + return rank_map + + +def _validate_zero3_model_optim_rank_sets(model_files, optim_files): + model_rank_map = _rank_map_from_files(model_files, "model-state") + optim_rank_map = _rank_map_from_files(optim_files, "optimizer-state") + model_ranks = set(model_rank_map) + optim_ranks = set(optim_rank_map) + if model_ranks != optim_ranks: + raise RuntimeError("ZeRO-3 checkpoint model/optimizer rank sets do not match: " + f"model_only={sorted(model_ranks - optim_ranks)}, " + f"optim_only={sorted(optim_ranks - model_ranks)}") + if not model_ranks: + raise RuntimeError("ZeRO-3 checkpoint has no model/optimizer rank files") + return model_rank_map, optim_rank_map + + +def _validate_autoep_expert_shapes(model_states_by_rank, metadata_by_rank): + for rank, autoep_metadata in metadata_by_rank.items(): + if not _uses_zero3_partitioned_autoep_metadata(autoep_metadata): + continue + expert_info = _autoep_expert_param_info(autoep_metadata) + param_shapes = model_states_by_rank[rank][PARAM_SHAPES] + zero_shape_names = {name for sub_group_shape in param_shapes for name in sub_group_shape} + missing = set(expert_info) - zero_shape_names + if missing: + raise RuntimeError(f"AutoEP expert parameters are missing from rank {rank} ZeRO param_shapes: " + f"{sorted(missing)}") + frozen_shapes = model_states_by_rank[rank].get('frozen_param_shapes') or {} + frozen_experts = set(expert_info).intersection(frozen_shapes) + if frozen_experts: + raise RuntimeError("AutoEP frozen expert parameters cannot be converted from the ZeRO-3 " + f"partition-native format yet: {sorted(frozen_experts)}") + + +def _save_zero3_autoep_universal_tensor(output_dir, param_name, state_key, tensor, num_experts): + param_dir = os.path.join(output_dir, "zero", param_name) + os.makedirs(param_dir, exist_ok=True) + _save_checkpoint( + os.path.join(param_dir, f"{state_key}.pt"), + { + PARAM: tensor, + CAT_DIM: 0, + EP_IS_EXPERT_PARAM: True, + EP_NUM_EXPERTS: num_experts, + }, + ) + + +def _consolidate_zero3_autoep_expert_states(output_dir, model_files, optim_files): + model_rank_map, optim_rank_map = _validate_zero3_model_optim_rank_sets(model_files, optim_files) + model_states_by_rank = { + rank: torch.load(model_file, map_location=torch.device('cpu'), weights_only=False) + for rank, model_file in model_rank_map.items() + } + optim_states_by_rank = { + rank: torch.load(optim_file, map_location=torch.device('cpu'), weights_only=False) + for rank, optim_file in optim_rank_map.items() + } + metadata_by_rank = { + rank: _get_autoep_metadata(model_state) + for rank, model_state in model_states_by_rank.items() if _get_autoep_metadata(model_state) is not None + } + _validate_autoep_expert_shapes(model_states_by_rank, metadata_by_rank) + + expert_fragments = {} + num_experts_by_param = {} + expected_dp_world_by_param_rank = {} + expected_ep_ranks_by_param = {} + + for rank, model_state in model_states_by_rank.items(): + optim_state = optim_states_by_rank.get(rank) + if optim_state is None: + raise FileNotFoundError(f"Missing ZeRO optimizer checkpoint for rank {rank}") + + autoep_metadata = _get_autoep_metadata(model_state) + if not _uses_zero3_partitioned_autoep_metadata(autoep_metadata): + continue + + expert_info = _autoep_expert_param_info(autoep_metadata) + param_shapes = model_state[PARAM_SHAPES] + zero_optim_state = optim_state[OPTIMIZER_STATE_DICT] + partition_groups = zero_optim_state.get('ds_zero_partition_groups') or [] + + for sub_group_id, sub_group_shape in enumerate(param_shapes): + optimizer_sub_state = zero_optim_state['optimizer_state_dict']['state'][sub_group_id] + flat_state = { + 'fp32': zero_optim_state['fp32_flat_groups'][sub_group_id], + 'exp_avg': optimizer_sub_state.get('exp_avg'), + 'exp_avg_sq': optimizer_sub_state.get('exp_avg_sq'), + } + partition_metadata = partition_groups[sub_group_id] if sub_group_id < len(partition_groups) else {} + partition_count = partition_metadata.get('partition_count', len(model_states_by_rank)) + partition_rank = partition_metadata.get('partition_rank', rank) + + offset = 0 + for param_name, shape in sub_group_shape.items(): + unpartitioned_numel = _shape_numel(shape) + partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, partition_count) + padding_free_numel = max( + 0, min(partitioned_numel, unpartitioned_numel - partition_rank * partitioned_numel)) + + layer_info = expert_info.get(param_name) + if layer_info is not None: + ep_rank = layer_info['ep_rank'] + num_experts_by_param[param_name] = layer_info['num_experts'] + expected_dp_world_by_param_rank[(param_name, + ep_rank)] = layer_info['expert_data_parallel_world_size'] + expected_ep_ranks_by_param[param_name] = set(range(layer_info['ep_size'])) + for state_key, flat_tensor in flat_state.items(): + if flat_tensor is None: + raise RuntimeError(f"Missing optimizer state '{state_key}' for AutoEP expert " + f"parameter {param_name} on ZeRO rank {rank}") + fragment = flat_tensor.narrow(0, offset, padding_free_numel).clone() + key = (param_name, state_key, ep_rank) + expert_fragments.setdefault(key, []).append((partition_rank, fragment, shape)) + + offset += partitioned_numel + + grouped_by_param = {} + for (param_name, state_key, ep_rank), fragments in expert_fragments.items(): + grouped_by_param.setdefault((param_name, state_key), {})[ep_rank] = fragments + + for (param_name, state_key), ep_rank_fragments in grouped_by_param.items(): + missing_ep_ranks = expected_ep_ranks_by_param[param_name] - set(ep_rank_fragments) + if missing_ep_ranks: + raise RuntimeError(f"Missing AutoEP universal fragments for {param_name}/{state_key} EP ranks: " + f"{sorted(missing_ep_ranks)}") + ep_tensors = [] + for ep_rank in sorted(ep_rank_fragments): + fragments = sorted(ep_rank_fragments[ep_rank], key=lambda item: item[0]) + expected_dp_world = expected_dp_world_by_param_rank[(param_name, ep_rank)] + partition_ranks = [partition_rank for partition_rank, _, _ in fragments] + if len(partition_ranks) != len(set(partition_ranks)): + raise RuntimeError(f"Duplicate AutoEP expert-DP partition ranks for {param_name}/{state_key} " + f"EP rank {ep_rank}: {partition_ranks}") + if set(partition_ranks) != set(range(expected_dp_world)): + raise RuntimeError(f"Incomplete AutoEP expert-DP fragments for {param_name}/{state_key} " + f"EP rank {ep_rank}: got {sorted(partition_ranks)}, " + f"expected {list(range(expected_dp_world))}") + shape = fragments[0][2] + if any(tuple(fragment_shape) != tuple(shape) for _, _, fragment_shape in fragments): + raise RuntimeError(f"Inconsistent AutoEP expert fragment shapes for {param_name}/{state_key} " + f"EP rank {ep_rank}") + full_flat = torch.cat([fragment for _, fragment, _ in fragments], dim=0)[:_shape_numel(shape)] + ep_tensors.append(full_flat.view(shape)) + + if not ep_tensors: + continue + full_expert_tensor = torch.cat(ep_tensors, dim=0) + if full_expert_tensor.shape[0] != num_experts_by_param[param_name]: + raise RuntimeError(f"AutoEP universal tensor for {param_name}/{state_key} has wrong expert dimension: " + f"got {full_expert_tensor.shape[0]}, expected {num_experts_by_param[param_name]}") + _save_zero3_autoep_universal_tensor(output_dir, param_name, state_key, full_expert_tensor, + num_experts_by_param[param_name]) + + def _parse_model_states_stage3(files): return torch.load(files[0], map_location=torch.device('cpu'), weights_only=False)[PARAM_SHAPES] @@ -435,6 +751,10 @@ def _get_optim_files(checkpoint_dir): return _get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") +def _filter_zero3_optim_files(optim_files): + return [f for f in optim_files if re.match(r'(?:bf16_)?zero_pp_rank_', os.path.basename(f))] + + def _get_model_state_files(checkpoint_dir): return _get_checkpoint_files(checkpoint_dir, "*_model_states.pt") @@ -509,7 +829,10 @@ def main(args): print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}') optim_files = _get_optim_files(args.input_folder) - zero_stage = _get_zero_stage(optim_files) + zero3_optim_files = _filter_zero3_optim_files(optim_files) + zero_stage = _get_zero_stage(zero3_optim_files or optim_files) + if zero_stage > 2 and zero3_optim_files: + optim_files = zero3_optim_files if zero_stage <= 2: ds_checkpoint = DeepSpeedCheckpoint(args.input_folder) @@ -588,19 +911,39 @@ def main(args): else: # Stage 3 path model_files = _get_zero3_model_state_files(args.input_folder) - _raise_if_stage3_autoep_universal_conversion(model_files) + autoep_expert_param_names, autoep_metadata_by_rank = _autoep_expert_param_names_by_rank(model_files) + has_autoep_metadata = any(metadata is not None for metadata in autoep_metadata_by_rank.values()) + has_zero3_partitioned_autoep = any( + _uses_zero3_partitioned_autoep_metadata(metadata) for metadata in autoep_metadata_by_rank.values()) + if has_autoep_metadata and not has_zero3_partitioned_autoep: + raise NotImplementedError("Stage 3 universal checkpoint conversion for AutoEP requires the " + "partition-native AutoEP ZeRO-3 checkpoint format.") + if not has_zero3_partitioned_autoep: + autoep_expert_param_names = set() + else: + _validate_zero3_model_optim_rank_sets(model_files, optim_files) param_shapes = _parse_model_states_stage3(model_files) dp_degree = len(model_files) temp_dir = os.path.join(args.output_folder, 'tmp') print('*** 1. Extracting ZeRO fragments') - _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir) + _extract_zero_shard_files_stage3(args, + optim_files, + param_shapes, + dp_degree, + temp_dir, + exclude_param_names=autoep_expert_param_names) print('*** 2. Merging slices .....') param_keys = {key for sub_group_shapes in param_shapes for key in sub_group_shapes.keys()} + param_keys -= autoep_expert_param_names _merge_zero3_slice_files(args, param_keys, dp_degree, temp_dir) + if has_zero3_partitioned_autoep: + print('*** 2.5. Consolidating AutoEP ZeRO-3 expert states') + _consolidate_zero3_autoep_expert_states(args.output_folder, model_files, optim_files) + print('*** 3. Saving common optimizer states') _save_optimizer_state_stage3(args, optim_files) @@ -610,8 +953,17 @@ def main(args): # Copy *model_states files into output folder, filtering out expert files for f in glob.glob(os.path.join(args.input_folder, '*model_states.pt')): if _is_expert_model_state_file(f): - continue # Skip expert files (handled separately if AutoEP were supported) - shutil.copy2(f, args.output_folder) + continue + if has_zero3_partitioned_autoep: + model_state = torch.load(f, map_location=torch.device('cpu'), weights_only=False) + autoep_metadata = _get_autoep_metadata(model_state) + if UNIVERSAL_CHECKPOINT_INFO not in model_state: + model_state[UNIVERSAL_CHECKPOINT_INFO] = {} + model_state[UNIVERSAL_CHECKPOINT_INFO][EXPERT_PARAMETER_PATTERNS] = [r'.*\.experts\.w[123]$'] + model_state[UNIVERSAL_CHECKPOINT_INFO][AUTOEP_LAYERS_KEY] = autoep_metadata + torch.save(model_state, os.path.join(args.output_folder, os.path.basename(f))) + else: + shutil.copy2(f, args.output_folder) # Update latest to output folder checkpoint_root_folder, step_folder = os.path.split(args.output_folder) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index aa063879ac1e..5abb33e2d374 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -70,7 +70,18 @@ WEIGHT_QUANTIZE_ROUNDING, \ WEIGHT_QUANTIZE_VERBOSE, \ WEIGHT_QUANTIZE_KERNEL -from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS, UNIVERSAL_CHECKPOINT_INFO +from deepspeed.checkpoint.constants import ( + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY, + AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT, + EXPERT_PARAMETER_PATTERNS, + FROZEN_PARAM_FRAGMENTS, + OPTIMIZER_STATE_DICT, + UNIVERSAL_CHECKPOINT_INFO, + UNIVERSAL_CHECKPOINT_VERSION_KEY, + UNIVERSAL_CHECKPOINT_VERSION_VALUE, +) from deepspeed.checkpoint.utils import clone_tensors_for_torch_save from deepspeed.checkpoint.ds_to_universal import dp_index_to_str from deepspeed.runtime.sparse_tensor import SparseTensor @@ -3640,7 +3651,8 @@ def load_module_state_dict(self, strict=True, custom_load_fn=None, fetch_z3_params=False, - z3_params_to_fetch=None): + z3_params_to_fetch=None, + allowed_missing_keys=None): if z3_params_to_fetch is not None: params_to_fetch = [ p for p in z3_params_to_fetch if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE @@ -3658,9 +3670,17 @@ def load_module_state_dict(self, if custom_load_fn: custom_load_fn(src=module_state_dict, dst=self.module) else: - self.module.load_state_dict( + load_result = self.module.load_state_dict( module_state_dict, # TODO - strict=strict) + strict=strict and allowed_missing_keys is None) + if allowed_missing_keys is not None: + missing_keys = set(load_result.missing_keys) + unexpected_keys = set(load_result.unexpected_keys) + unexpected_missing = missing_keys - set(allowed_missing_keys) + if unexpected_missing or unexpected_keys: + raise RuntimeError("Checkpoint module state did not match the model outside AutoEP expert " + f"parameters: missing={sorted(unexpected_missing)}, " + f"unexpected={sorted(unexpected_keys)}") if checkpoint.get(FROZEN_PARAM_FRAGMENTS, None) is not None: saved_frozen_params = checkpoint[FROZEN_PARAM_FRAGMENTS] @@ -3813,8 +3833,13 @@ def load_checkpoint(self, load_zero_checkpoint = load_path is not None and self.zero_optimization() if load_zero_checkpoint and not self.zero_nvme_offload_optimizer(): - if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint(): - success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) + autoep_zero3_partition_native_load = self.has_moe_layers and self.zero_optimization_partition_weights() + if ((load_optimizer_states and not load_module_only) or self.load_universal_checkpoint() + or autoep_zero3_partition_native_load): + success = self._load_zero_checkpoint(load_dir, + tag, + load_optimizer_states=load_optimizer_states + and not load_module_only) else: success = False if not success: @@ -3843,6 +3868,149 @@ def load_checkpoint(self, return load_path, client_states + @staticmethod + def _uses_autoep_zero3_partitioned_experts(autoep_layers): + if not isinstance(autoep_layers, list): + return False + DeepSpeedEngine._validate_autoep_zero3_partitioned_metadata(autoep_layers, require_partitioned=False) + return any( + isinstance(entry, dict) + and entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) == AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT + for entry in autoep_layers) + + @staticmethod + def _validate_autoep_zero3_partitioned_metadata(autoep_layers, model=None, require_partitioned=True): + if not isinstance(autoep_layers, list): + raise RuntimeError(f"ds_autoep_layers metadata is malformed: expected list, got " + f"{type(autoep_layers).__name__}") + + try: + from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer as _AutoEPMoELayer + except ImportError: + _AutoEPMoELayer = None + + live_autoep_modules = {} + if _AutoEPMoELayer is not None and model is not None: + live_autoep_modules = { + module_name: module + for module_name, module in model.named_modules() if isinstance(module, _AutoEPMoELayer) + } + + required_fields = { + 'moe_layer_id', + 'module_path', + 'num_experts', + 'num_local_experts', + 'ep_size', + 'expert_key_prefix', + } + partitioned_fields = { + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY, + 'ep_group_name', + 'ep_rank', + 'expert_data_parallel_rank', + 'expert_data_parallel_world_size', + 'global_expert_start', + 'global_expert_end', + } + seen_layer_ids = set() + seen_prefixes = set() + partitioned_count = 0 + + for entry in autoep_layers: + if not isinstance(entry, dict): + raise RuntimeError(f"ds_autoep_layers entry is malformed: expected dict, got " + f"{type(entry).__name__}") + missing = required_fields - entry.keys() + if missing: + raise RuntimeError(f"ds_autoep_layers entry is invalid: missing fields {sorted(missing)}") + + layer_id = entry['moe_layer_id'] + if layer_id in seen_layer_ids: + raise RuntimeError(f"ds_autoep_layers metadata has duplicate moe_layer_id: {layer_id}") + seen_layer_ids.add(layer_id) + + prefix = entry['expert_key_prefix'] + if prefix in seen_prefixes: + raise RuntimeError(f"ds_autoep_layers metadata has duplicate expert_key_prefix: {prefix}") + seen_prefixes.add(prefix) + + checkpoint_format = entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) + if checkpoint_format is None: + continue + if checkpoint_format != AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT: + continue + + missing = partitioned_fields - entry.keys() + if missing: + raise RuntimeError(f"AutoEP ZeRO-3 checkpoint metadata is invalid: missing fields {sorted(missing)}") + if entry[AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY] != AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION: + raise RuntimeError("Unsupported AutoEP ZeRO-3 checkpoint format version: " + f"{entry[AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY]}. " + f"This DeepSpeed build supports version " + f"{AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION}.") + + num_experts = entry['num_experts'] + num_local_experts = entry['num_local_experts'] + ep_size = entry['ep_size'] + if num_local_experts * ep_size != num_experts: + raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata is inconsistent: " + f"num_local_experts={num_local_experts}, ep_size={ep_size}, " + f"num_experts={num_experts}") + + expected_start = entry['ep_rank'] * num_local_experts + expected_end = expected_start + num_local_experts + if entry['global_expert_start'] != expected_start or entry['global_expert_end'] != expected_end: + raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata has inconsistent global expert range: " + f"got [{entry['global_expert_start']}, {entry['global_expert_end']}), " + f"expected [{expected_start}, {expected_end})") + + module_path = entry['module_path'] + if live_autoep_modules and module_path not in live_autoep_modules: + raise RuntimeError(f"AutoEP ZeRO-3 checkpoint metadata references missing module: {module_path}") + if live_autoep_modules: + expected_prefix = f"{module_path}.experts" if module_path else "experts" + if prefix != expected_prefix: + raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata has unexpected expert key prefix: " + f"got {prefix}, expected {expected_prefix}") + + partitioned_count += 1 + + if require_partitioned and partitioned_count == 0: + raise RuntimeError("AutoEP ZeRO-3 partition-native checkpoint metadata was expected but no " + "partitioned AutoEP layer entries were found") + + @staticmethod + def _autoep_expert_parameter_names(autoep_layers, model): + names = set() + if isinstance(autoep_layers, list): + DeepSpeedEngine._validate_autoep_zero3_partitioned_metadata(autoep_layers, model=model) + for entry in autoep_layers: + if not isinstance(entry, dict): + continue + if entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) != AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT: + continue + prefix = entry.get('expert_key_prefix') + if prefix: + names.update(f"{prefix}.{wname}" for wname in ('w1', 'w2', 'w3')) + + if names: + return names + + try: + from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer as _AutoEPMoELayer + except ImportError: + _AutoEPMoELayer = None + if _AutoEPMoELayer is None or model is None: + return names + + for module_name, module in model.named_modules(): + if not isinstance(module, _AutoEPMoELayer): + continue + module_prefix = f"{module_name}." if module_name else "" + names.update(f"{module_prefix}experts.{wname}" for wname in ('w1', 'w2', 'w3')) + return names + def _load_checkpoint(self, load_dir, tag, @@ -3867,12 +4035,9 @@ def _load_checkpoint(self, fetch_z3_params = False z3_params_to_fetch = None - if self.has_moe_layers and self.zero_optimization_partition_weights() and (load_module_only - or not load_optimizer_states): - raise NotImplementedError("AutoEP with ZeRO Stage 3 does not support module-only checkpoint load or " - "load_optimizer_states=False yet. Use same-topology checkpoint load with " - "optimizer states for this release.") - if self.zero_optimization_partition_weights() and not load_optimizer_states: + autoep_partitioned_experts = False + allowed_missing_keys = None + if self.zero_optimization_partition_weights() and not load_optimizer_states and not self.has_moe_layers: checkpoint['module'] = get_fp32_state_dict_from_zero_checkpoint(load_dir) fetch_z3_params = True @@ -3891,15 +4056,23 @@ def _load_checkpoint(self, autoep_layers = checkpoint.get(AUTOEP_LAYERS_KEY) if autoep_layers is None: autoep_layers = checkpoint.get(AUTOEP_LAYERS_KEY_LEGACY) - DeepSpeedEngine.load_moe_state_dict(load_dir, - tag, - state_dict=checkpoint['module'], - old_moe_load=old_moe_load, - model=self.module, - mpu=self.mpu, - num_experts=self.num_experts, - checkpoint_engine=self.checkpoint_engine, - autoep_layers=autoep_layers) + autoep_partitioned_experts = (self.zero_optimization_partition_weights() + and DeepSpeedEngine._uses_autoep_zero3_partitioned_experts(autoep_layers)) + if autoep_partitioned_experts: + allowed_missing_keys = DeepSpeedEngine._autoep_expert_parameter_names(autoep_layers, self.module) + if not allowed_missing_keys: + raise RuntimeError("AutoEP ZeRO-3 partition-native checkpoint metadata did not identify any " + "live expert parameters to restore from ZeRO shards.") + else: + DeepSpeedEngine.load_moe_state_dict(load_dir, + tag, + state_dict=checkpoint['module'], + old_moe_load=old_moe_load, + model=self.module, + mpu=self.mpu, + num_experts=self.num_experts, + checkpoint_engine=self.checkpoint_engine, + autoep_layers=autoep_layers) if self.zero_optimization_partition_weights(): z3_params_to_fetch = [] try: @@ -3915,7 +4088,8 @@ def _load_checkpoint(self, strict=load_module_strict, custom_load_fn=custom_load_fn, fetch_z3_params=fetch_z3_params, - z3_params_to_fetch=z3_params_to_fetch) + z3_params_to_fetch=z3_params_to_fetch, + allowed_missing_keys=allowed_missing_keys) self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size'] @@ -4284,8 +4458,14 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa autoep_layer_info = [] autoep_group_names = set() moe_layer_id = 0 + found_native_moe = False + found_autoep = False for n_module, module in self.module.named_modules(): if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0: + found_native_moe = True + if self.zero_optimization_partition_weights() and found_autoep: + raise RuntimeError("AutoEP with ZeRO Stage 3 checkpointing does not support models that also " + "contain native DeepSpeed MoE layers.") group_name = module.expert_group_name num_local_experts = module.num_local_experts expp_rank = groups._get_expert_parallel_rank(group_name) @@ -4335,20 +4515,60 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa moe_layer_id += 1 elif _AutoEPMoELayer is not None and isinstance(module, _AutoEPMoELayer): + found_autoep = True + if self.zero_optimization_partition_weights() and found_native_moe: + raise RuntimeError("AutoEP with ZeRO Stage 3 checkpointing does not support models that also " + "contain native DeepSpeed MoE layers.") + if self.zero_optimization_partition_weights() and self.zero_nvme_offload_optimizer(): + raise RuntimeError("AutoEP with ZeRO Stage 3 checkpointing does not support NVMe optimizer " + "swapping yet because expert state is restored from ZeRO optimizer shards.") group_name = module.ep_group_name num_local_experts = module.num_local_experts expp_rank = groups._get_expert_parallel_rank(group_name) exp_dp_rank = groups._get_expert_data_parallel_rank(group_name) module_prefix = f"{n_module}." if n_module else "" + expert_params = [getattr(module.experts, wname) for wname in ('w1', 'w2', 'w3')] + if self.zero_optimization_partition_weights(): + frozen_expert_names = [ + f"{module_prefix}experts.{wname}" for wname, param in zip(('w1', 'w2', 'w3'), expert_params) + if not param.requires_grad + ] + if frozen_expert_names: + raise RuntimeError("AutoEP with ZeRO Stage 3 checkpointing does not support frozen expert " + "parameters yet because frozen fragments are not stored in ZeRO optimizer " + f"shards: {frozen_expert_names}") # Collect metadata on ALL ranks (before writer guard) autoep_layer_info.append({ - 'moe_layer_id': moe_layer_id, - 'module_path': n_module, - 'num_experts': module.num_experts, - 'num_local_experts': num_local_experts, - 'ep_size': module.ep_size, - 'expert_key_prefix': f"{module_prefix}experts", + 'moe_layer_id': + moe_layer_id, + 'module_path': + n_module, + 'num_experts': + module.num_experts, + 'num_local_experts': + num_local_experts, + 'ep_size': + module.ep_size, + 'expert_key_prefix': + f"{module_prefix}experts", + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY: + AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT + if self.zero_optimization_partition_weights() else 'per_expert_files', + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY: + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION if self.zero_optimization_partition_weights() else None, + 'ep_group_name': + group_name, + 'ep_rank': + expp_rank, + 'expert_data_parallel_rank': + exp_dp_rank, + 'expert_data_parallel_world_size': + groups._get_expert_data_parallel_world_size(group_name), + 'global_expert_start': + expp_rank * num_local_experts, + 'global_expert_end': + expp_rank * num_local_experts + num_local_experts, }) autoep_group_names.add(group_name) if len(autoep_group_names) > 1: @@ -4356,7 +4576,10 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa f"multiple groups: {sorted(autoep_group_names)}. " f"All AutoEPMoELayer instances must use the same ep_size.") - expert_params = [getattr(module.experts, wname) for wname in ('w1', 'w2', 'w3')] + if self.zero_optimization_partition_weights(): + moe_layer_id += 1 + continue + with deepspeed.zero.GatheredParameters(expert_params): if self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank): # Slice fused 3D tensors into per-expert state dicts. @@ -4391,7 +4614,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa # zero_pp_rank_*_model_states.pt file so load can discover a checkpoint # on every rank. is_expert_dp_writer = self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank) - if is_expert_dp_writer: + if is_expert_dp_writer and not self.zero_optimization_partition_weights(): optimizer_state = { 'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None } @@ -4413,10 +4636,47 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa model_state_dict = self._get_non_moe_state_dict( DeepSpeedEngine.module_state_dict(self, exclude_frozen_parameters=exclude_frozen_parameters)) - # TODO: update num experts info,.. in checkpoint + zero_optimizer_state = self.zero_optimization() + save_frozen_param = self.zero_optimization_partition_gradients() and not exclude_frozen_parameters + zero_param_shapes = self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None + if autoep_layer_info and self.zero_optimization_partition_weights(): + DeepSpeedEngine._validate_autoep_zero3_partitioned_metadata(autoep_layer_info, model=self.module) + expert_param_names = DeepSpeedEngine._autoep_expert_parameter_names(autoep_layer_info, self.module) + zero_shape_names = { + name + for param_group_shapes in (zero_param_shapes or []) + for name in param_group_shapes.keys() + } + missing_expert_shapes = expert_param_names - zero_shape_names + if missing_expert_shapes: + raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata references expert parameters that are " + f"missing from ZeRO param_shapes: {sorted(missing_expert_shapes)}") + universal_checkpoint_info = getattr(self.module, UNIVERSAL_CHECKPOINT_INFO, None) + if universal_checkpoint_info is not None: + universal_checkpoint_info = dict(universal_checkpoint_info) + elif autoep_layer_info: + universal_checkpoint_info = {} + if autoep_layer_info: + universal_checkpoint_info.setdefault(UNIVERSAL_CHECKPOINT_VERSION_KEY, + UNIVERSAL_CHECKPOINT_VERSION_VALUE) + universal_checkpoint_info[EXPERT_PARAMETER_PATTERNS] = [r'.*\.experts\.w[123]$'] + universal_checkpoint_info['ds_autoep_layers'] = autoep_layer_info + state = { 'module': model_state_dict, + 'buffer_names': + self._get_buffer_names(), + 'optimizer': + self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None, + 'param_shapes': + zero_param_shapes, + 'frozen_param_shapes': + self._get_zero_frozen_param_attributes(self._get_param_shape_func) if save_frozen_param else None, + 'shared_params': + self._get_shared_params() if self.optimizer and zero_optimizer_state else None, + 'frozen_param_fragments': + self._get_zero_frozen_param_attributes(self._get_param_fragment_func) if save_frozen_param else None, 'lr_scheduler': self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, 'data_sampler': @@ -4440,9 +4700,15 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa self.num_experts, 'ds_autoep_layers': autoep_layer_info if autoep_layer_info else None, + 'ds_config': + self.config, + 'ds_version': + version, } + if universal_checkpoint_info is not None: + state[UNIVERSAL_CHECKPOINT_INFO] = universal_checkpoint_info # Check for reserved-key collisions with client_state - reserved_keys = {'ds_autoep_layers', 'autoep_layers'} + reserved_keys = {'ds_autoep_layers', 'autoep_layers', UNIVERSAL_CHECKPOINT_INFO} collisions = reserved_keys.intersection(client_state.keys()) if collisions: raise KeyError(f"client_state contains reserved checkpoint keys: {sorted(collisions)}. " @@ -4718,6 +4984,20 @@ def _consolidated_16bit_state_dict(self, exclude_frozen_parameters=False): raise ValueError("consolidated_16bit_state_dict is only applicable to cases where weights are partitioned, " "including Zero Stage 3 and tensor parallelism.") + def _has_autoep_layers(self): + try: + from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer as _AutoEPMoELayer + except ImportError: + return False + return any(isinstance(module, _AutoEPMoELayer) for _, module in self.module.named_modules()) + + def _raise_if_autoep_zero3_consolidated_export(self, operation): + if self.zero_optimization_partition_weights() and self._has_autoep_layers(): + raise NotImplementedError(f"{operation} is not supported for AutoEP with ZeRO Stage 3 checkpoint " + "partitions. AutoEP expert parameters are partitioned over expert replica " + "groups, so global-DP consolidation would produce incomplete expert tensors. " + "Use ds_to_universal.py for an expert-aware checkpoint conversion.") + def _zero3_consolidated_16bit_state_dict(self, exclude_frozen_parameters=False): """ Get a full non-partitioned state_dict with fp16 weights on cpu. @@ -4732,6 +5012,7 @@ def _zero3_consolidated_16bit_state_dict(self, exclude_frozen_parameters=False): """ if not self.zero_optimization_partition_weights(): raise ValueError("this function requires ZeRO-3 mode") + self._raise_if_autoep_zero3_consolidated_export("_zero3_consolidated_16bit_state_dict") state_dict = OrderedDict() if dist.get_rank() == 0 else None shared_params = {} @@ -4811,6 +5092,7 @@ def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin", exclude_ path = os.path.join(save_dir, save_filename) if self.zero_optimization_partition_weights(): + self._raise_if_autoep_zero3_consolidated_export("save_16bit_model") if self.zero_gather_16bit_weights_on_model_save(): # consolidation is expensive in time and memory and therefore isn't a default state_dict = self._zero3_consolidated_16bit_state_dict( diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 90a152091c1b..e760c4b66a0e 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -1330,7 +1330,8 @@ def load_module_state_dict(self, strict=True, custom_load_fn=None, fetch_z3_params=False, - z3_params_to_fetch=None): + z3_params_to_fetch=None, + allowed_missing_keys=None): """Override hack to instead use a directory path. This is important because pipeline models checkpoint by layer instead of rank. @@ -1347,7 +1348,8 @@ def load_module_state_dict(self, super().load_module_state_dict(state_dict, strict, fetch_z3_params=fetch_z3_params, - z3_params_to_fetch=z3_params_to_fetch) + z3_params_to_fetch=z3_params_to_fetch, + allowed_missing_keys=allowed_missing_keys) return self.module.load_state_dir(load_dir=self._curr_ckpt_path, diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 254606393eaa..ac1341d55641 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -34,7 +34,16 @@ from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper -from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE, LOSS_SCALER +from deepspeed.checkpoint.constants import ( + EP_IS_EXPERT_PARAM, + EP_NUM_EXPERTS, + OPTIMIZER_STATE_DICT, + FP32_FLAT_GROUPS, + PARAM, + PARTITION_COUNT, + ZERO_STAGE, + LOSS_SCALER, +) from deepspeed.accelerator import get_accelerator from deepspeed.runtime.zero.muon.original_muon import muon_update from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam @@ -3328,7 +3337,8 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir): raise ValueError(f"failed to find optimizer param in named params") param_name = self.param_names[param] key_layer_state_partition = self.load_hp_checkpoint_state(os.path.join(checkpoint_dir, param_name), - key) + key, + param=param) key_tensor.narrow(0, offset, key_layer_state_partition.numel()).copy_(key_layer_state_partition) offset += key_layer_state_partition.numel() if key == "fp32": @@ -3377,14 +3387,23 @@ def _load_global_state_stage3(self, sd): self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale) self.overflow = sd.get('overflow', self.overflow) - def load_hp_checkpoint_state(self, folder, key): - rank = dist.get_rank(group=self.dp_process_group) + def load_hp_checkpoint_state(self, folder, key, param=None): + partition_group = self._get_param_partition_group(param) if param is not None else self.dp_process_group + rank = dist.get_rank(group=partition_group) # Load tensors from files and reshape them to flat vectors - loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1) + loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False) + if isinstance(loaded_checkpoint_state, dict): + if loaded_checkpoint_state.get(EP_IS_EXPERT_PARAM, False): + if param is None: + raise ValueError(f"AutoEP universal expert checkpoint state in {folder} requires a target param") + loaded_checkpoint_state = self._slice_autoep_universal_expert_param(loaded_checkpoint_state, param) + else: + loaded_checkpoint_state = loaded_checkpoint_state[PARAM] + loaded_checkpoint_state = loaded_checkpoint_state.view(-1) # Partition the loaded data according to the local rank - world_size = dist.get_world_size(group=self.dp_process_group) + world_size = dist.get_world_size(group=partition_group) unpartitioned_numel = loaded_checkpoint_state.numel() partitioned_numel = math.ceil(unpartitioned_numel / world_size) @@ -3396,6 +3415,31 @@ def load_hp_checkpoint_state(self, folder, key): return checkpoint_state_partition + def _slice_autoep_universal_expert_param(self, checkpoint_state, param): + full_expert_tensor = checkpoint_state[PARAM] + checkpoint_num_experts = checkpoint_state.get(EP_NUM_EXPERTS, full_expert_tensor.shape[0]) + group_name = getattr(param, "ds_zero_partition_group_name", None) + if group_name is None: + raise ValueError("AutoEP universal expert checkpoint target parameter is missing its EP group name") + ep_rank = groups._get_expert_parallel_rank(group_name) + ep_world_size = groups._get_expert_parallel_world_size(group_name) + if checkpoint_num_experts % ep_world_size != 0: + raise ValueError("AutoEP universal expert checkpoint tensor cannot be evenly split across the target " + f"EP topology: checkpoint_num_experts={checkpoint_num_experts}, " + f"target_ep_size={ep_world_size}") + local_expert_count = param.ds_shape[0] if hasattr(param, "ds_shape") else param.shape[0] + expected_local_expert_count = checkpoint_num_experts // ep_world_size + if local_expert_count != expected_local_expert_count: + raise ValueError("AutoEP universal expert checkpoint tensor is incompatible with target parameter " + f"shape: target_local_experts={local_expert_count}, " + f"checkpoint_local_experts={expected_local_expert_count}") + expert_offset = ep_rank * local_expert_count + if expert_offset + local_expert_count > full_expert_tensor.shape[0]: + raise ValueError("AutoEP universal expert checkpoint tensor is incompatible with target EP topology: " + f"ep_rank={ep_rank}, local_experts={local_expert_count}, " + f"checkpoint_shape={tuple(full_expert_tensor.shape)}") + return full_expert_tensor.narrow(0, expert_offset, local_expert_count).contiguous() + def reset_swap_buffers(self): timer_names = set() for sub_group_id, group in enumerate(self.fp16_groups): diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index 5995d6e6f04e..1c9d6591e97b 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -33,7 +33,9 @@ from deepspeed.utils import logger from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, - FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS, AUTOEP_LAYERS_KEY, + AUTOEP_LAYERS_KEY_LEGACY, AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY, + AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT) @dataclass @@ -99,6 +101,28 @@ def get_model_state_files(checkpoint_dir): return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") +def _has_autoep_zero3_partitioned_metadata(state_dict): + autoep_layers = state_dict.get(AUTOEP_LAYERS_KEY) + if autoep_layers is None: + autoep_layers = state_dict.get(AUTOEP_LAYERS_KEY_LEGACY) + if not isinstance(autoep_layers, list): + return False + return any( + isinstance(entry, dict) + and entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) == AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT + for entry in autoep_layers) + + +def _raise_if_autoep_zero3_partitioned_checkpoint(model_files): + for file in model_files: + state_dict = torch.load(file, map_location=device, weights_only=False) + if _has_autoep_zero3_partitioned_metadata(state_dict): + raise NotImplementedError("zero_to_fp32 does not support AutoEP ZeRO-3 partition-native checkpoints. " + "AutoEP expert parameters are partitioned over expert replica groups, so " + "global data-parallel consolidation would produce incomplete expert tensors. " + "Use ds_to_universal.py for expert-aware conversion.") + + def parse_model_states(files): zero_model_states = [] for file in files: @@ -195,12 +219,13 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_ """ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + model_files = get_model_state_files(ds_checkpoint_dir) + _raise_if_autoep_zero3_partitioned_checkpoint(model_files) + optim_files = get_optim_files(ds_checkpoint_dir) zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") - model_files = get_model_state_files(ds_checkpoint_dir) - zero_model_states = parse_model_states(model_files) print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') diff --git a/tests/unit/v1/moe/test_autoep_checkpoint.py b/tests/unit/v1/moe/test_autoep_checkpoint.py index 81010aea7d58..a8b901d2222b 100644 --- a/tests/unit/v1/moe/test_autoep_checkpoint.py +++ b/tests/unit/v1/moe/test_autoep_checkpoint.py @@ -5,17 +5,25 @@ """Compact AutoEP checkpoint tests.""" import os +from types import SimpleNamespace +import deepspeed import pytest import torch import torch.nn as nn +from deepspeed import comm as dist +from deepspeed.checkpoint.ds_to_universal import main as convert_to_universal from deepspeed.runtime.config import DeepSpeedConfig from unit.common import DistributedTest from unit.v1.moe.autoep_test_utils import ( + MockMoETransformer, UNSUPPORTED_LOAD_BALANCE_VALUES, assert_load_balance_coeff_rejection_message, init_autoep_engine, + make_autoep_integration_config, + run_training_steps, + seed_everything, ) @@ -102,3 +110,66 @@ def test_autoep_metadata_schema_validation(self): autoep_layers=[{ "moe_layer_id": 0 }]) + + +class TestAutoEPZero3UniversalCheckpoint(DistributedTest): + world_size = 2 + + def test_zero3_partition_native_universal_round_trip_same_topology(self, tmpdir): + seed_everything(2468) + + config = make_autoep_integration_config(zero_stage=3, ep_size=2) + engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(), config=config) + run_training_steps(engine, num_steps=1) + + save_dir = str(tmpdir) + tag = "autoep-zero3" + engine.save_checkpoint(save_dir, tag=tag) + + checkpoint_dir = os.path.join(save_dir, tag) + universal_dir = os.path.join(save_dir, f"{tag}_universal") + args = SimpleNamespace(input_folder=checkpoint_dir, + output_folder=universal_dir, + num_extract_workers=1, + num_merge_workers=1, + keep_temp_folder=False, + strict=True, + inject_missing_state=False) + + dist.barrier() + if dist.get_rank() == 0: + convert_to_universal(args) + dist.barrier() + + from deepspeed.checkpoint.constants import PARAM + from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer + for module_name, module in engine.module.named_modules(): + if not isinstance(module, AutoEPMoELayer): + continue + module_prefix = f"{module_name}." if module_name else "" + for wname in ("w1", "w2", "w3"): + param = getattr(module.experts, wname) + with deepspeed.zero.GatheredParameters([param]): + local_experts = param.detach().clone() + gathered = [torch.zeros_like(local_experts) for _ in range(dist.get_world_size())] + dist.all_gather(gathered, local_experts) + if dist.get_rank() == 0: + expected = torch.cat(gathered, dim=0).cpu() + universal = torch.load( + os.path.join(universal_dir, "zero", f"{module_prefix}experts.{wname}", "fp32.pt"), + map_location="cpu", + weights_only=False, + )[PARAM] + torch.testing.assert_close(universal, expected) + + universal_config = make_autoep_integration_config(zero_stage=3, ep_size=2) + universal_config["checkpoint"] = {"load_universal": True} + reloaded_engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(), config=universal_config) + reloaded_engine.load_checkpoint(save_dir, tag=f"{tag}_universal") + + for expected, restored in zip(engine.optimizer.fp16_partitioned_groups_flat, + reloaded_engine.optimizer.fp16_partitioned_groups_flat): + torch.testing.assert_close(restored, expected) + + losses, _ = run_training_steps(reloaded_engine, num_steps=1) + assert torch.isfinite(torch.tensor(losses[0])) diff --git a/tests/unit/v1/moe/test_autoep_integration.py b/tests/unit/v1/moe/test_autoep_integration.py index 69e276d91a0e..63770feb54af 100644 --- a/tests/unit/v1/moe/test_autoep_integration.py +++ b/tests/unit/v1/moe/test_autoep_integration.py @@ -4,6 +4,8 @@ # DeepSpeed Team """Integration tests for AutoEP (multi-GPU, requires distributed backend).""" +import os + import pytest import torch import torch.nn as nn @@ -132,6 +134,26 @@ def test_zero3_ep_save_load_same_topology_2gpu(self, tmpdir): save_dir = str(tmpdir) engine.save_checkpoint(save_dir, tag="autoep-zero3") + checkpoint_dir = os.path.join(save_dir, "autoep-zero3") + checkpoint_files = os.listdir(checkpoint_dir) + assert not any(name.startswith("layer_") and "_expert_" in name for name in checkpoint_files) + + model_state = torch.load(os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt"), + map_location="cpu", + weights_only=False) + from deepspeed.checkpoint.constants import ( + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY, + AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT, + PARAM_SHAPES, + ) + assert all(entry[AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY] == AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT + for entry in model_state["ds_autoep_layers"]) + assert all(entry[AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY] == AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION + for entry in model_state["ds_autoep_layers"]) + param_names = {name for group_shapes in model_state[PARAM_SHAPES] for name in group_shapes} + assert any(name.endswith("experts.w1") for name in param_names) reloaded = MockMoETransformer() reloaded_engine, _, _, _ = deepspeed.initialize(model=reloaded, config=config) @@ -140,8 +162,18 @@ def test_zero3_ep_save_load_same_topology_2gpu(self, tmpdir): module_only = MockMoETransformer() module_only_engine, _, _, _ = deepspeed.initialize(model=module_only, config=config) - with pytest.raises(NotImplementedError, match="load_optimizer_states=False"): - module_only_engine.load_checkpoint(save_dir, tag="autoep-zero3", load_optimizer_states=False) + module_only_engine.load_checkpoint(save_dir, tag="autoep-zero3", load_optimizer_states=False) + + module_only_flag = MockMoETransformer() + module_only_flag_engine, _, _, _ = deepspeed.initialize(model=module_only_flag, config=config) + module_only_flag_engine.load_checkpoint(save_dir, tag="autoep-zero3", load_module_only=True) + + for expected, restored in zip(engine.optimizer.fp16_partitioned_groups_flat, + module_only_engine.optimizer.fp16_partitioned_groups_flat): + torch.testing.assert_close(restored, expected) + for expected, restored in zip(engine.optimizer.fp16_partitioned_groups_flat, + module_only_flag_engine.optimizer.fp16_partitioned_groups_flat): + torch.testing.assert_close(restored, expected) losses, _ = _run_training_steps(reloaded_engine, num_steps=1) assert torch.isfinite(torch.tensor(losses[0])) diff --git a/tests/unit/v1/moe/test_autoep_unit.py b/tests/unit/v1/moe/test_autoep_unit.py index dc6f1b01d2f5..0bf18c88ef0a 100644 --- a/tests/unit/v1/moe/test_autoep_unit.py +++ b/tests/unit/v1/moe/test_autoep_unit.py @@ -5,6 +5,7 @@ """Compact critical-path tests for AutoEP.""" import inspect +from collections import OrderedDict from types import SimpleNamespace import pytest @@ -595,24 +596,131 @@ def test_pipeline_load_module_state_dict_accepts_autoep_zero3_fetch_kwarg(self): signature = inspect.signature(PipelineEngine.load_module_state_dict) assert "z3_params_to_fetch" in signature.parameters + assert "allowed_missing_keys" in signature.parameters - def test_universal_converter_rejects_zero3_autoep_model_state(self, tmp_path): - from deepspeed.checkpoint.constants import AUTOEP_LAYERS_KEY + def test_autoep_zero3_16bit_export_guard_directs_to_universal_conversion(self): + engine = object.__new__(DeepSpeedEngine) + engine.zero_optimization_partition_weights = lambda: True + engine._has_autoep_layers = lambda: True + + with pytest.raises(NotImplementedError, match="ds_to_universal.py"): + engine._raise_if_autoep_zero3_consolidated_export("save_16bit_model") + + def test_universal_converter_detects_zero3_partitioned_autoep_model_state(self, tmp_path): + from deepspeed.checkpoint.constants import ( + AUTOEP_LAYERS_KEY, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY, + AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT, + ) from deepspeed.checkpoint.ds_to_universal import ( + _autoep_expert_param_names_by_rank, _get_zero3_model_state_files, - _raise_if_stage3_autoep_universal_conversion, + _uses_zero3_partitioned_autoep_metadata, ) zero3_model_file = tmp_path / "zero_pp_rank_0_mp_rank_00_model_states.pt" expert_file = tmp_path / "layer_0_expert_0_mp_rank_00_model_states.pt" - torch.save({AUTOEP_LAYERS_KEY: [{"moe_layer_id": 0}]}, zero3_model_file) + metadata = [{ + "moe_layer_id": 0, + "module_path": "model.layers.0.mlp", + "num_experts": 4, + "num_local_experts": 2, + "ep_size": 2, + "expert_key_prefix": "model.layers.0.mlp.experts", + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY: AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY: AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION, + "ep_group_name": "ep_size_2", + "ep_rank": 0, + "expert_data_parallel_rank": 0, + "expert_data_parallel_world_size": 1, + "global_expert_start": 0, + "global_expert_end": 2, + }] + torch.save({AUTOEP_LAYERS_KEY: metadata}, zero3_model_file) torch.save({"expert": torch.empty(1)}, expert_file) model_files = _get_zero3_model_state_files(str(tmp_path)) + expert_param_names, metadata_by_rank = _autoep_expert_param_names_by_rank(model_files) assert model_files == [str(zero3_model_file)] - with pytest.raises(NotImplementedError, match="same-topology ZeRO-3 checkpoint load"): - _raise_if_stage3_autoep_universal_conversion(model_files) + assert expert_param_names == { + "model.layers.0.mlp.experts.w1", + "model.layers.0.mlp.experts.w2", + "model.layers.0.mlp.experts.w3", + } + assert _uses_zero3_partitioned_autoep_metadata(metadata_by_rank[0]) + + def test_universal_stage3_extract_accepts_tuple_param_shapes(self, tmp_path): + from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT + from deepspeed.checkpoint.ds_to_universal import extract_zero_shards_stage3 + + optim_file = tmp_path / "zero_pp_rank_0_mp_rank_00_optim_states.pt" + torch.save( + { + OPTIMIZER_STATE_DICT: { + "optimizer_state_dict": { + "state": [{ + "exp_avg": torch.arange(6, dtype=torch.float32), + "exp_avg_sq": torch.arange(6, dtype=torch.float32) + 10, + }] + }, + "fp32_flat_groups": [torch.arange(6, dtype=torch.float32) + 20], + } + }, + optim_file, + ) + + temp_dir = tmp_path / "tmp" + extract_zero_shards_stage3([str(optim_file)], [OrderedDict([("dense.weight", (2, 3))])], 1, str(temp_dir), 0) + + fp32_fragment = torch.load(temp_dir / "dense.weight" / "0" / "fp32.00", weights_only=False) + exp_avg_fragment = torch.load(temp_dir / "dense.weight" / "0" / "exp_avg.00", weights_only=False) + torch.testing.assert_close(fp32_fragment, torch.arange(6, dtype=torch.float32) + 20) + torch.testing.assert_close(exp_avg_fragment, torch.arange(6, dtype=torch.float32)) + + def test_zero_to_fp32_rejects_zero3_partitioned_autoep_checkpoint(self, tmp_path): + from deepspeed.checkpoint.constants import ( + AUTOEP_LAYERS_KEY, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY, + AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT, + BUFFER_NAMES, + PARAM_SHAPES, + ) + from deepspeed.utils.zero_to_fp32 import _raise_if_autoep_zero3_partitioned_checkpoint + + model_file = tmp_path / "zero_pp_rank_0_mp_rank_00_model_states.pt" + torch.save( + { + BUFFER_NAMES: [], + PARAM_SHAPES: [], + "module": {}, + "shared_params": {}, + AUTOEP_LAYERS_KEY: [{ + "moe_layer_id": 0, + "module_path": "model.layers.0.mlp", + "num_experts": 4, + "num_local_experts": 2, + "ep_size": 2, + "expert_key_prefix": "model.layers.0.mlp.experts", + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY: AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY: AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION, + "ep_group_name": "ep_size_2", + "ep_rank": 0, + "expert_data_parallel_rank": 0, + "expert_data_parallel_world_size": 1, + "global_expert_start": 0, + "global_expert_end": 2, + }], + }, + model_file, + ) + + with pytest.raises(NotImplementedError, match="ds_to_universal.py"): + _raise_if_autoep_zero3_partitioned_checkpoint([str(model_file)]) def test_preset_registry_core_contracts(self): assert set(PRESET_MODELS) == {"mixtral", "qwen3_moe", "qwen3_5_moe", "deepseek_v2", "deepseek_v3"} From 8f528ced81afa31ac689866cba0e6b2641d0ef9a Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 12 Jun 2026 00:05:20 -0700 Subject: [PATCH 09/19] Address review findings for partition-native AutoEP checkpoints - Honor load_module_strict=False in load_module_state_dict when AutoEP expert keys are allowed to be missing; the mismatch error now only fires for strict loads. - Add a 4-GPU universal checkpoint round trip with expert-DP world size 2 so converter consolidation and universal/module-only loads cover real partition shards instead of the degenerate single-rank case. - Check AutoEP partition-native metadata inside parse_model_states so zero_to_fp32 loads each model state file once and rejects unsupported checkpoints before the optimizer shard load. Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/engine.py | 4 +- deepspeed/utils/zero_to_fp32.py | 23 +++-- tests/unit/v1/moe/test_autoep_checkpoint.py | 97 +++++++++++++++++++++ tests/unit/v1/moe/test_autoep_unit.py | 20 +++++ 4 files changed, 134 insertions(+), 10 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 5abb33e2d374..6262b0bdb6cf 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3673,7 +3673,9 @@ def load_module_state_dict(self, load_result = self.module.load_state_dict( module_state_dict, # TODO strict=strict and allowed_missing_keys is None) - if allowed_missing_keys is not None: + # The expert-key allowance only tightens strict loads; a caller + # passing strict=False keeps the usual non-strict semantics. + if strict and allowed_missing_keys is not None: missing_keys = set(load_result.missing_keys) unexpected_keys = set(load_result.unexpected_keys) unexpected_missing = missing_keys - set(allowed_missing_keys) diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index 1c9d6591e97b..f40f1f0acbbf 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -113,20 +113,25 @@ def _has_autoep_zero3_partitioned_metadata(state_dict): for entry in autoep_layers) +def _raise_if_autoep_zero3_partitioned_state(state_dict): + if _has_autoep_zero3_partitioned_metadata(state_dict): + raise NotImplementedError("zero_to_fp32 does not support AutoEP ZeRO-3 partition-native checkpoints. " + "AutoEP expert parameters are partitioned over expert replica groups, so " + "global data-parallel consolidation would produce incomplete expert tensors. " + "Use ds_to_universal.py for expert-aware conversion.") + + def _raise_if_autoep_zero3_partitioned_checkpoint(model_files): for file in model_files: state_dict = torch.load(file, map_location=device, weights_only=False) - if _has_autoep_zero3_partitioned_metadata(state_dict): - raise NotImplementedError("zero_to_fp32 does not support AutoEP ZeRO-3 partition-native checkpoints. " - "AutoEP expert parameters are partitioned over expert replica groups, so " - "global data-parallel consolidation would produce incomplete expert tensors. " - "Use ds_to_universal.py for expert-aware conversion.") + _raise_if_autoep_zero3_partitioned_state(state_dict) def parse_model_states(files): zero_model_states = [] for file in files: state_dict = torch.load(file, map_location=device, weights_only=False) + _raise_if_autoep_zero3_partitioned_state(state_dict) if BUFFER_NAMES not in state_dict: raise ValueError(f"{file} is not a model state checkpoint") @@ -219,16 +224,16 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_ """ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + # parse_model_states rejects AutoEP ZeRO-3 partition-native checkpoints + # before the expensive optimizer-shard load below. model_files = get_model_state_files(ds_checkpoint_dir) - _raise_if_autoep_zero3_partitioned_checkpoint(model_files) + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') optim_files = get_optim_files(ds_checkpoint_dir) zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") - zero_model_states = parse_model_states(model_files) - print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') - if zero_stage <= 2: return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters) diff --git a/tests/unit/v1/moe/test_autoep_checkpoint.py b/tests/unit/v1/moe/test_autoep_checkpoint.py index a8b901d2222b..491dca697ea2 100644 --- a/tests/unit/v1/moe/test_autoep_checkpoint.py +++ b/tests/unit/v1/moe/test_autoep_checkpoint.py @@ -173,3 +173,100 @@ def test_zero3_partition_native_universal_round_trip_same_topology(self, tmpdir) losses, _ = run_training_steps(reloaded_engine, num_steps=1) assert torch.isfinite(torch.tensor(losses[0])) + + +class TestAutoEPZero3UniversalCheckpoint4GPU(DistributedTest): + world_size = 4 + + def test_zero3_partition_native_universal_round_trip_replica_groups_4gpu(self, tmpdir): + """Same round trip as the 2-GPU test, but with expert-DP world size 2 so + the converter consolidates multiple partition fragments per expert + parameter and the universal/module-only loads slice real shard offsets + instead of the degenerate world_size=1 case.""" + seed_everything(1357) + + config = make_autoep_integration_config(zero_stage=3, ep_size=2) + engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(), config=config) + run_training_steps(engine, num_steps=1) + + from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer + autoep_modules = [(name, module) for name, module in engine.module.named_modules() + if isinstance(module, AutoEPMoELayer)] + assert autoep_modules + for _, module in autoep_modules: + for param in module.experts.parameters(): + assert param.ds_zero_partition_world_size == 2 + + save_dir = str(tmpdir) + tag = "autoep-zero3-4gpu" + engine.save_checkpoint(save_dir, tag=tag) + + # Module-only restore must reassemble expert weights from two real + # partition shards per replica group. + module_only_engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(), + config=make_autoep_integration_config(zero_stage=3, + ep_size=2)) + module_only_engine.load_checkpoint(save_dir, tag=tag, load_optimizer_states=False) + for expected, restored in zip(engine.optimizer.fp16_partitioned_groups_flat, + module_only_engine.optimizer.fp16_partitioned_groups_flat): + torch.testing.assert_close(restored, expected) + + checkpoint_dir = os.path.join(save_dir, tag) + universal_dir = os.path.join(save_dir, f"{tag}_universal") + args = SimpleNamespace(input_folder=checkpoint_dir, + output_folder=universal_dir, + num_extract_workers=1, + num_merge_workers=1, + keep_temp_folder=False, + strict=True, + inject_missing_state=False) + + dist.barrier() + if dist.get_rank() == 0: + convert_to_universal(args) + dist.barrier() + + from deepspeed.checkpoint.constants import PARAM + world_size = dist.get_world_size() + for module_name, module in autoep_modules: + module_prefix = f"{module_name}." if module_name else "" + ep_rank_tensor = torch.tensor([module.ep_rank], dtype=torch.long, device=engine.device) + ep_ranks = [torch.zeros_like(ep_rank_tensor) for _ in range(world_size)] + dist.all_gather(ep_ranks, ep_rank_tensor) + ep_ranks = [int(t.item()) for t in ep_ranks] + for wname in ("w1", "w2", "w3"): + param = getattr(module.experts, wname) + with deepspeed.zero.GatheredParameters([param]): + local_experts = param.detach().clone() + gathered = [torch.zeros_like(local_experts) for _ in range(world_size)] + dist.all_gather(gathered, local_experts) + if dist.get_rank() == 0: + # Replicas within an EP rank must agree; keep one + # representative per EP rank in EP-rank order. + representative = {} + for global_rank, ep_rank in enumerate(ep_ranks): + if ep_rank in representative: + torch.testing.assert_close(gathered[global_rank], gathered[representative[ep_rank]]) + else: + representative[ep_rank] = global_rank + assert sorted(representative) == list(range(module.ep_size)) + expected = torch.cat([gathered[representative[ep_rank]] for ep_rank in range(module.ep_size)], + dim=0).cpu() + universal = torch.load( + os.path.join(universal_dir, "zero", f"{module_prefix}experts.{wname}", "fp32.pt"), + map_location="cpu", + weights_only=False, + )[PARAM] + torch.testing.assert_close(universal, expected) + + universal_config = make_autoep_integration_config(zero_stage=3, ep_size=2) + universal_config["checkpoint"] = {"load_universal": True} + reloaded_engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(), config=universal_config) + reloaded_engine.load_checkpoint(save_dir, tag=f"{tag}_universal") + + for expected, restored in zip(engine.optimizer.fp16_partitioned_groups_flat, + reloaded_engine.optimizer.fp16_partitioned_groups_flat): + torch.testing.assert_close(restored, expected) + + losses, _ = run_training_steps(reloaded_engine, num_steps=1) + assert torch.isfinite(torch.tensor(losses[0])) diff --git a/tests/unit/v1/moe/test_autoep_unit.py b/tests/unit/v1/moe/test_autoep_unit.py index 0bf18c88ef0a..d3fab039bb64 100644 --- a/tests/unit/v1/moe/test_autoep_unit.py +++ b/tests/unit/v1/moe/test_autoep_unit.py @@ -598,6 +598,19 @@ def test_pipeline_load_module_state_dict_accepts_autoep_zero3_fetch_kwarg(self): assert "z3_params_to_fetch" in signature.parameters assert "allowed_missing_keys" in signature.parameters + def test_load_module_state_dict_nonstrict_keeps_nonstrict_semantics_with_allowed_missing_keys(self): + engine = object.__new__(DeepSpeedEngine) + # bypass nn.Module.__setattr__, which requires Module.__init__ + object.__setattr__(engine, "module", nn.Linear(2, 2)) + checkpoint = {"module": {"unexpected_key": torch.zeros(1)}} + + # strict=False must keep the documented non-strict load semantics even + # when AutoEP expert keys are allowed to be missing. + engine.load_module_state_dict(checkpoint, strict=False, allowed_missing_keys=["weight"]) + + with pytest.raises(RuntimeError, match="outside AutoEP expert"): + engine.load_module_state_dict(checkpoint, strict=True, allowed_missing_keys=["weight"]) + def test_autoep_zero3_16bit_export_guard_directs_to_universal_conversion(self): engine = object.__new__(DeepSpeedEngine) engine.zero_optimization_partition_weights = lambda: True @@ -722,6 +735,13 @@ def test_zero_to_fp32_rejects_zero3_partitioned_autoep_checkpoint(self, tmp_path with pytest.raises(NotImplementedError, match="ds_to_universal.py"): _raise_if_autoep_zero3_partitioned_checkpoint([str(model_file)]) + # parse_model_states is the guard point used by + # _get_fp32_state_dict_from_zero_checkpoint, which loads each model + # state file only once. + from deepspeed.utils.zero_to_fp32 import parse_model_states + with pytest.raises(NotImplementedError, match="ds_to_universal.py"): + parse_model_states([str(model_file)]) + def test_preset_registry_core_contracts(self): assert set(PRESET_MODELS) == {"mixtral", "qwen3_moe", "qwen3_5_moe", "deepseek_v2", "deepseek_v3"} assert preset_name_for_hf_model_type("mixtral") == "mixtral" From b76fb5776341f9c19680c28492b644d0083173ae Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 12 Jun 2026 10:37:45 -0700 Subject: [PATCH 10/19] Fail fast on pre-partitioned AutoEP expert params with mismatched groups A parameter that is already ZeRO-partitioned when _resolve_zero3_param_placement runs, for example an AutoEPMoELayer wrapped directly in zero.Init, keeps the partition group fixed at conversion time. Recording the freshly resolved expert replica group in ds_zero_partition_* metadata for such a parameter would silently reduce-scatter different experts across the wrong ranks. Raise instead, and derive the metadata from the actual partition group so it always describes the real partitioning. Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/engine.py | 26 ++++++++++++++++++++++ tests/unit/v1/moe/test_autoep_unit.py | 31 +++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 6262b0bdb6cf..9b23bc5a47d6 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1656,6 +1656,14 @@ def _validate_zero3_moe_compatibility(self): raise AssertionError("AutoEP with ZeRO Stage 3 only supports " "expert_parallel.expert_tensor_parallel_size=1.") + @staticmethod + def _is_same_process_group(group_a, group_b): + if group_a is group_b: + return True + if group_a is None or group_b is None: + return False + return dist.get_all_ranks_from_group(group_a) == dist.get_all_ranks_from_group(group_b) + def _resolve_zero3_param_placement(self): for name, param in self.module.named_parameters(): family = getattr(param, "ds_zero_placement_family", "replicated") @@ -1670,6 +1678,24 @@ def _resolve_zero3_param_placement(self): else: raise AssertionError(f"Parameter '{name}' has unsupported ZeRO placement family '{family}'.") + if hasattr(param, "ds_id"): + # Already ZeRO-partitioned, e.g. converted under zero.Init. + # The partition group was fixed at conversion time and cannot + # be re-resolved here. An expert parameter partitioned over + # any other group would silently reduce-scatter different + # experts across the wrong ranks, so fail fast instead of + # recording placement metadata the partitioning does not match. + actual_group = getattr(param, "ds_process_group", None) + if family == "autoep_expert" and not self._is_same_process_group(actual_group, partition_group): + raise AssertionError(f"AutoEP expert parameter '{name}' was already ZeRO-partitioned over a " + "non-expert process group. Build the model so AutoEP expert parameters are " + "created by the engine transform instead of wrapping AutoEPMoELayer modules " + "directly in zero.Init.") + if actual_group is not None: + # Keep placement metadata consistent with the actual + # partitioning rather than the freshly resolved target. + partition_group = actual_group + param.ds_zero_placement_family = family param.ds_zero_partition_group_name = group_name param.ds_zero_partition_process_group = partition_group diff --git a/tests/unit/v1/moe/test_autoep_unit.py b/tests/unit/v1/moe/test_autoep_unit.py index d3fab039bb64..79e9cc04bf6b 100644 --- a/tests/unit/v1/moe/test_autoep_unit.py +++ b/tests/unit/v1/moe/test_autoep_unit.py @@ -611,6 +611,37 @@ def test_load_module_state_dict_nonstrict_keeps_nonstrict_semantics_with_allowed with pytest.raises(RuntimeError, match="outside AutoEP expert"): engine.load_module_state_dict(checkpoint, strict=True, allowed_missing_keys=["weight"]) + def test_resolve_zero3_param_placement_rejects_pre_partitioned_expert_on_wrong_group(self, monkeypatch): + engine = object.__new__(DeepSpeedEngine) + model = nn.Linear(2, 2, bias=False) + # bypass nn.Module.__setattr__, which requires Module.__init__ + object.__setattr__(engine, "module", model) + + expert_group = object() + other_group = object() + monkeypatch.setattr(ds_engine.groups, "_get_expert_data_parallel_group", lambda name: expert_group) + monkeypatch.setattr(ds_engine.dist, "get_rank", lambda group=None: 0) + monkeypatch.setattr(ds_engine.dist, "get_world_size", lambda group=None: 1) + monkeypatch.setattr(ds_engine.dist, + "get_all_ranks_from_group", + lambda group: [0] if group is expert_group else [0, 1], + raising=False) + + param = model.weight + param.ds_zero_placement_family = "autoep_expert" + param.ds_zero_partition_group_name = "ep_size_2" + param.ds_id = 0 + param.ds_process_group = other_group + + with pytest.raises(AssertionError, match="already ZeRO-partitioned over a non-expert process group"): + engine._resolve_zero3_param_placement() + + # A pre-partitioned expert param over the matching group is accepted + # and keeps metadata derived from its actual partition group. + param.ds_process_group = expert_group + engine._resolve_zero3_param_placement() + assert param.ds_zero_partition_process_group is expert_group + def test_autoep_zero3_16bit_export_guard_directs_to_universal_conversion(self): engine = object.__new__(DeepSpeedEngine) engine.zero_optimization_partition_weights = lambda: True From c177dbce9b845b3474f60e02528cf415860544b2 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 12 Jun 2026 12:59:03 -0700 Subject: [PATCH 11/19] Update AutoEP checkpoint docs for ZeRO-3 partition-native support The AutoEP and Universal Checkpointing docs still described the pre-redesign behavior: per-expert checkpoint files for every stage, no ZeRO-3 universal conversion, and no module-only or optimizer-state-free ZeRO-3 loads. Describe the partition-native ZeRO-3 layout and the supported conversion/load paths, and limit the remaining constraints to topology changes and zero_to_fp32 consolidation. Signed-off-by: Masahiro Tanaka --- docs/_pages/config-json.md | 2 +- docs/_tutorials/universal-checkpointing.md | 52 ++++++++++++++-------- docs/code-docs/source/autoep.rst | 17 ++++--- 3 files changed, 46 insertions(+), 25 deletions(-) diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index 28780f2ba512..0bd9d0238d8c 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -1077,7 +1077,7 @@ Use a built-in preset but override specific naming/weight fields for a fine-tune - `autoep_size` must divide `num_experts` for all detected MoE layers - AutoEP currently cannot be combined with AutoTP (`tensor_parallel.autotp_size > 1`); support is planned as follow-up work - AutoEP with ZeRO Stage 3 is supported only without AutoTP, sequence parallelism, MiCS, hpZeRO secondary tensor groups, non-1 `expert_tensor_parallel_size`, or quantized gradients -- ZeRO Stage 3 supports same-topology AutoEP checkpoint save/load when optimizer state is loaded; module-only loads, optimizer-state-skipping loads, universal checkpoint conversion, and topology-changing loads are not supported +- ZeRO Stage 3 saves AutoEP checkpoints partition-natively and supports same-topology save/load, module-only loads, optimizer-state-skipping loads, and universal checkpoint conversion with same-topology universal load; topology-changing loads (different `autoep_size` or data-parallel world size) are not supported yet ### Logging diff --git a/docs/_tutorials/universal-checkpointing.md b/docs/_tutorials/universal-checkpointing.md index e744fc655085..e5dee224721a 100644 --- a/docs/_tutorials/universal-checkpointing.md +++ b/docs/_tutorials/universal-checkpointing.md @@ -83,25 +83,37 @@ the model and topology you want to use for resumed training. ### AutoEP Requirements and Limitations AutoEP checkpoints are saved as regular DeepSpeed checkpoints, but routed expert -weights have an additional layout. With AutoEP enabled, DeepSpeed writes the -routed expert weights (`w1`, `w2`, and `w3`) into per-expert files named like -`layer__expert__mp_rank__model_states.pt`. -The regular model checkpoint records AutoEP metadata in `ds_autoep_layers`; older -checkpoints may use the legacy `autoep_layers` key. For ZeRO Stage 1 and ZeRO -Stage 2 checkpoints, router, gate, shared-expert, and other -non-routed-expert parameters stay in the regular `mp_rank_*_model_states.pt` -files and use the standard Universal Checkpointing path. - -Use ZeRO Stage 1 or ZeRO Stage 2 for the current AutoEP Universal Checkpoint -conversion path. ZeRO Stage 3 AutoEP Universal Checkpoint conversion is not -supported; when AutoEP metadata is present in `zero_pp_rank_*_model_states.pt`, -the converter raises `NotImplementedError` and instructs users to use regular -same-topology ZeRO-3 checkpoint load for AutoEP checkpoints. AutoEP ZeRO Stage 3 -checkpoints do not support Universal conversion, topology-changing loads, -module-only loads, or optimizer-state-free loads. +weights have an additional layout that depends on the ZeRO stage. With ZeRO +Stage 1 or ZeRO Stage 2, DeepSpeed writes the routed expert weights (`w1`, +`w2`, and `w3`) into per-expert files named like +`layer__expert__mp_rank__model_states.pt`, +while router, gate, shared-expert, and other non-routed-expert parameters stay +in the regular `mp_rank_*_model_states.pt` files and use the standard +Universal Checkpointing path. With ZeRO Stage 3, AutoEP checkpoints are +partition-native: no per-expert files are produced, and expert weights are +stored as ZeRO partition shards in the `zero_pp_rank_*_model_states.pt` and +optimizer shard files together with the recorded partition-group metadata. In +both layouts the regular model checkpoint records AutoEP metadata in +`ds_autoep_layers`; older checkpoints may use the legacy `autoep_layers` +key. + +Both ZeRO Stage 1/2 and ZeRO Stage 3 AutoEP checkpoints can be converted to +Universal Checkpoint format. For ZeRO Stage 3, `ds_to_universal.py` detects the +partition-native AutoEP metadata in `zero_pp_rank_*_model_states.pt`, +consolidates each expert parameter from its partition shards across the expert +replica group, and writes the same `zero/` parameter layout as the other +stages. ZeRO Stage 3 AutoEP also supports module-only loads +(`load_module_only=True`) and optimizer-state-free loads +(`load_optimizer_states=False`) from the partition shards. The remaining Stage +3 limitations are topology changes: universal load is currently supported for +the same topology only (same `autoep_size` and data-parallel world size), and +`zero_to_fp32.py` consolidation is not supported for partition-native AutoEP +checkpoints (the script raises `NotImplementedError`; use `ds_to_universal.py` +instead). During conversion, `ds_to_universal.py` reads `ds_autoep_layers` or the legacy -`autoep_layers` key, consolidates each AutoEP layer's routed expert files, and +`autoep_layers` key, consolidates each AutoEP layer's routed expert state (the +per-expert files for ZeRO Stage 1/2, the partition shards for ZeRO Stage 3), and writes full expert tensors to paths such as `zero/.w1/fp32.pt`. These files are tagged with `is_expert_param` and `ep_num_experts`, which are the load-time signals used for AutoEP expert resharding. When matching expert @@ -111,8 +123,10 @@ such as `exp_avg.pt` and `exp_avg_sq.pt` next to the converted parameter. Regular AutoEP checkpoint load requires the target run to use the same `autoep_size` as the save run. To change `autoep_size` for the same AutoEP-detected model topology, convert a ZeRO Stage 1 or ZeRO Stage 2 checkpoint -to Universal format and load the Universal checkpoint. For ZeRO Stage 3 AutoEP -checkpoints, use regular same-topology checkpoint load instead. +to Universal format and load the Universal checkpoint. ZeRO Stage 3 AutoEP +checkpoints can also be converted to and loaded from Universal format, but +currently only with the same topology, so a Stage 3 universal load does not +yet enable `autoep_size` or data-parallel world-size changes. In the Universal Checkpoint load path, AutoEP routed experts are restored from the `zero/` parameter layout rather than from the regular diff --git a/docs/code-docs/source/autoep.rst b/docs/code-docs/source/autoep.rst index 991591e4348a..87cdc025752a 100644 --- a/docs/code-docs/source/autoep.rst +++ b/docs/code-docs/source/autoep.rst @@ -51,9 +51,13 @@ Transformers build that exposes the matching config/model classes, support. Stage 3 requires AutoEP-managed MoE layers and does not support native DeepSpeed MoE layers, AutoTP, tensor model parallelism from ``mpu``, sequence parallelism, MiCS, hpZeRO secondary tensor groups, non-1 expert tensor -parallelism, or quantized gradients. Stage 3 checkpoint load is same-topology -only with optimizer state; module-only loads, optimizer-state-free loads, -Universal Checkpoint conversion, and topology changes are not supported. +parallelism, or quantized gradients. Stage 3 AutoEP checkpoints are saved +partition-natively in the ``zero_pp_rank_*`` shard files and support +same-topology load, module-only loads (``load_module_only``), +optimizer-state-free loads (``load_optimizer_states=False``), and Universal +Checkpoint conversion with same-topology universal load. Topology-changing +loads (different ``autoep_size`` or data-parallel world size) are not +supported for Stage 3 yet. **Usage:** @@ -93,8 +97,11 @@ Universal Checkpoint conversion, and topology changes are not supported. convert a ZeRO Stage 1 or ZeRO Stage 2 checkpoint to Universal Checkpoint format and load it with ``checkpoint.load_universal``; see the `Universal Checkpointing tutorial `__ - for the detailed flow and constraints. ZeRO Stage 3 AutoEP checkpoints must - be loaded with the same topology. + for the detailed flow and constraints. ZeRO Stage 3 AutoEP checkpoints can + also be converted to Universal Checkpoint format, but Stage 3 universal load + currently supports the same topology only, so both regular and universal + Stage 3 loads must keep the same ``autoep_size`` and data-parallel world + size. - DeepSeek-V2 and DeepSeek-V3 AutoEP do not support load-balance expert bias yet. The built-in DeepSeek presets disable it by default; explicit non-null values fail. From fd4a861cf99cf384c58debddb3105799e00446d9 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 12 Jun 2026 17:29:03 -0700 Subject: [PATCH 12/19] Validate AutoEP universal topology loads Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/engine.py | 6 + docs/_tutorials/universal-checkpointing.md | 37 +-- docs/code-docs/source/autoep.rst | 23 +- tests/unit/v1/moe/test_autoep_checkpoint.py | 292 +++++++++++++++++++- 4 files changed, 328 insertions(+), 30 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 9b23bc5a47d6..b6934e0373ee 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -4061,6 +4061,12 @@ def _load_checkpoint(self, if checkpoint is None: return None, None + if (self.load_universal_checkpoint() and self.zero_optimization_partition_weights() + and self._has_autoep_layers() and (load_module_only or not load_optimizer_states)): + raise NotImplementedError("AutoEP ZeRO-3 universal checkpoint load currently requires optimizer state. " + "Set load_optimizer_states=True and load_module_only=False; weights-only " + "AutoEP ZeRO-3 universal loads are not supported yet.") + fetch_z3_params = False z3_params_to_fetch = None autoep_partitioned_experts = False diff --git a/docs/_tutorials/universal-checkpointing.md b/docs/_tutorials/universal-checkpointing.md index e5dee224721a..9037f5dc7309 100644 --- a/docs/_tutorials/universal-checkpointing.md +++ b/docs/_tutorials/universal-checkpointing.md @@ -104,12 +104,14 @@ consolidates each expert parameter from its partition shards across the expert replica group, and writes the same `zero/` parameter layout as the other stages. ZeRO Stage 3 AutoEP also supports module-only loads (`load_module_only=True`) and optimizer-state-free loads -(`load_optimizer_states=False`) from the partition shards. The remaining Stage -3 limitations are topology changes: universal load is currently supported for -the same topology only (same `autoep_size` and data-parallel world size), and -`zero_to_fp32.py` consolidation is not supported for partition-native AutoEP -checkpoints (the script raises `NotImplementedError`; use `ds_to_universal.py` -instead). +(`load_optimizer_states=False`) from the partition shards. After conversion to +Universal Checkpoint format, ZeRO Stage 3 AutoEP can load optimizer-including +checkpoints at a different data-parallel world size, a different `autoep_size`, +or both, as long as the target AutoEP topology is valid for the same model +parameter names and expert count. Weights-only/module-only universal loads for +ZeRO Stage 3 AutoEP are not supported yet and fail fast. `zero_to_fp32.py` +consolidation is not supported for partition-native AutoEP checkpoints (the +script raises `NotImplementedError`; use `ds_to_universal.py` instead). During conversion, `ds_to_universal.py` reads `ds_autoep_layers` or the legacy `autoep_layers` key, consolidates each AutoEP layer's routed expert state (the @@ -121,19 +123,19 @@ optimizer shards are available, the converter also writes optimizer state files such as `exp_avg.pt` and `exp_avg_sq.pt` next to the converted parameter. Regular AutoEP checkpoint load requires the target run to use the same -`autoep_size` as the save run. To change `autoep_size` for the same -AutoEP-detected model topology, convert a ZeRO Stage 1 or ZeRO Stage 2 checkpoint -to Universal format and load the Universal checkpoint. ZeRO Stage 3 AutoEP -checkpoints can also be converted to and loaded from Universal format, but -currently only with the same topology, so a Stage 3 universal load does not -yet enable `autoep_size` or data-parallel world-size changes. +`autoep_size` as the save run. To change `autoep_size` or data-parallel world +size for the same AutoEP-detected model topology, convert the checkpoint to +Universal format and load the Universal checkpoint with optimizer state enabled. +For ZeRO Stage 3 AutoEP this path reslices routed expert parameters and their +Adam `fp32`, `exp_avg`, and `exp_avg_sq` states using the target runtime +topology. In the Universal Checkpoint load path, AutoEP routed experts are restored from the `zero/` parameter layout rather than from the regular `layer_*_expert_*_model_states.pt` files. The target run's AutoEP process group supplies the load-side expert-parallel rank and size. For each tagged expert tensor, the loader slices the saved expert dimension by `ep_rank` and `ep_size` -when `ep_size > 1`. +and then applies the target ZeRO partitioning group and padding. The target model still needs to expose matching AutoEP parameter names and compatible shapes, for example `.experts.w1`, @@ -145,10 +147,11 @@ valid before checkpoint loading: `autoep_size` must divide the target pipeline stage size (`world_size / pp_size`) and every detected target layer's expert count. -Topology changes are limited to `autoep_size` resharding for matching -AutoEP-managed expert parameters. For every AutoEP layer in the checkpoint, the -saved `ep_num_experts` must be divisible by the target `autoep_size` when the -target `ep_size > 1`. For example, an 8-expert checkpoint can load with target +Topology changes are limited to data-parallel world-size changes and +`autoep_size` resharding for matching AutoEP-managed expert parameters. For +every AutoEP layer in the checkpoint, the saved `ep_num_experts` must be +divisible by the target `autoep_size`. For example, an 8-expert checkpoint can +load with target `autoep_size` values of 1, 2, 4, or 8, but not 3. With `autoep_size=1`, the expert tensor is not sliced, but the target parameter must still have the compatible full expert shape. diff --git a/docs/code-docs/source/autoep.rst b/docs/code-docs/source/autoep.rst index 87cdc025752a..38190fee66bf 100644 --- a/docs/code-docs/source/autoep.rst +++ b/docs/code-docs/source/autoep.rst @@ -55,9 +55,11 @@ parallelism, or quantized gradients. Stage 3 AutoEP checkpoints are saved partition-natively in the ``zero_pp_rank_*`` shard files and support same-topology load, module-only loads (``load_module_only``), optimizer-state-free loads (``load_optimizer_states=False``), and Universal -Checkpoint conversion with same-topology universal load. Topology-changing -loads (different ``autoep_size`` or data-parallel world size) are not -supported for Stage 3 yet. +Checkpoint conversion. Optimizer-including Universal Checkpoint loads can +resume with a different data-parallel world size, a different ``autoep_size``, +or both, when the target ``autoep_size`` divides the model's expert count. +Weights-only/module-only Universal Checkpoint loads are not supported for +Stage 3 AutoEP yet. **Usage:** @@ -92,16 +94,13 @@ supported for Stage 3 yet. - AutoEP with ZeRO Stage 3 is supported only without sequence parallelism, MiCS, hpZeRO secondary tensor groups, non-1 expert tensor parallelism, or quantized gradients. -- Checkpoint save/load requires matching ``autoep_size``. To change - ``autoep_size`` across runs for the same AutoEP-detected model topology, - convert a ZeRO Stage 1 or ZeRO Stage 2 checkpoint to Universal Checkpoint - format and load it with ``checkpoint.load_universal``; see the +- Regular checkpoint save/load requires matching ``autoep_size``. To change + ``autoep_size`` or data-parallel world size across runs for the same + AutoEP-detected model topology, convert the checkpoint to Universal + Checkpoint format and load it with ``checkpoint.load_universal`` and + optimizer state enabled; see the `Universal Checkpointing tutorial `__ - for the detailed flow and constraints. ZeRO Stage 3 AutoEP checkpoints can - also be converted to Universal Checkpoint format, but Stage 3 universal load - currently supports the same topology only, so both regular and universal - Stage 3 loads must keep the same ``autoep_size`` and data-parallel world - size. + for the detailed flow and constraints. - DeepSeek-V2 and DeepSeek-V3 AutoEP do not support load-balance expert bias yet. The built-in DeepSeek presets disable it by default; explicit non-null values fail. diff --git a/tests/unit/v1/moe/test_autoep_checkpoint.py b/tests/unit/v1/moe/test_autoep_checkpoint.py index 491dca697ea2..d809b608e7bb 100644 --- a/tests/unit/v1/moe/test_autoep_checkpoint.py +++ b/tests/unit/v1/moe/test_autoep_checkpoint.py @@ -15,7 +15,7 @@ from deepspeed import comm as dist from deepspeed.checkpoint.ds_to_universal import main as convert_to_universal from deepspeed.runtime.config import DeepSpeedConfig -from unit.common import DistributedTest +from unit.common import DistributedFixture, DistributedTest from unit.v1.moe.autoep_test_utils import ( MockMoETransformer, UNSUPPORTED_LOAD_BALANCE_VALUES, @@ -26,6 +26,225 @@ seed_everything, ) +TOPOLOGY_TAG = "autoep-zero3-topology" +EXPERT_WEIGHT_NAMES = ("w1", "w2", "w3") +UNIVERSAL_STATE_KEYS = ("fp32", "exp_avg", "exp_avg_sq") + + +def _convert_checkpoint_to_universal(save_dir, tag): + checkpoint_dir = os.path.join(save_dir, tag) + universal_dir = os.path.join(save_dir, f"{tag}_universal") + args = SimpleNamespace(input_folder=checkpoint_dir, + output_folder=universal_dir, + num_extract_workers=1, + num_merge_workers=1, + keep_temp_folder=False, + strict=True, + inject_missing_state=False) + + dist.barrier() + if dist.get_rank() == 0: + convert_to_universal(args) + dist.barrier() + return universal_dir + + +def _load_universal_state(universal_dir, param_name, key): + from deepspeed.checkpoint.constants import PARAM + + return torch.load(os.path.join(universal_dir, "zero", param_name, f"{key}.pt"), + map_location="cpu", + weights_only=False)[PARAM] + + +def _load_universal_optimizer_step(universal_dir): + from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT + + state = torch.load(os.path.join(universal_dir, "zero", "optimizer_state.pt"), + map_location="cpu", + weights_only=False) + step = state[OPTIMIZER_STATE_DICT]["state"][0]["step"] + return int(step.item() if torch.is_tensor(step) else step) + + +def _assert_universal_expert_metadata(universal_dir, num_experts): + from deepspeed.checkpoint.constants import EP_IS_EXPERT_PARAM, EP_NUM_EXPERTS, PARAM + + found = 0 + nonzero_moments = {"exp_avg": False, "exp_avg_sq": False} + zero_dir = os.path.join(universal_dir, "zero") + for root, _, files in os.walk(zero_dir): + for key in UNIVERSAL_STATE_KEYS: + filename = f"{key}.pt" + if filename not in files: + continue + state = torch.load(os.path.join(root, filename), map_location="cpu", weights_only=False) + if not isinstance(state, dict) or not state.get(EP_IS_EXPERT_PARAM, False): + continue + found += 1 + assert state[EP_NUM_EXPERTS] == num_experts + assert state[PARAM].shape[0] == num_experts + if key in nonzero_moments and torch.count_nonzero(state[PARAM]).item() > 0: + nonzero_moments[key] = True + assert found > 0 + assert all(nonzero_moments.values()) + + +def _train_save_convert_autoep_zero3(tmpdir, *, tag, ep_size, num_experts=4): + seed_everything(8642 + ep_size + num_experts) + config = make_autoep_integration_config(zero_stage=3, ep_size=ep_size) + engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(num_experts=num_experts), config=config) + run_training_steps(engine, num_steps=3) + + save_dir = str(tmpdir) + engine.save_checkpoint(save_dir, tag=tag) + universal_dir = _convert_checkpoint_to_universal(save_dir, tag) + if dist.get_rank() == 0: + _assert_universal_expert_metadata(universal_dir, num_experts) + dist.barrier() + engine.destroy() + + +def _autoep_modules(engine): + from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer + + return [(name, module) for name, module in engine.module.named_modules() if isinstance(module, AutoEPMoELayer)] + + +def _expert_params(engine): + for module_name, module in _autoep_modules(engine): + module_prefix = f"{module_name}." if module_name else "" + for wname in EXPERT_WEIGHT_NAMES: + yield f"{module_prefix}experts.{wname}", module, getattr(module.experts, wname) + + +def _router_params(engine): + for module_name, module in _autoep_modules(engine): + module_prefix = f"{module_name}." if module_name else "" + for router_name, param in module.router.named_parameters(): + yield f"{module_prefix}router.{router_name}", param + + +def _gather_zero_param(param): + with deepspeed.zero.GatheredParameters([param]): + return param.detach().clone() + + +def _collect_by_ep_rank(local_tensor, ep_rank, ep_size, device): + local_tensor = local_tensor.contiguous() + gathered = [torch.zeros_like(local_tensor) for _ in range(dist.get_world_size())] + dist.all_gather(gathered, local_tensor) + + ep_rank_tensor = torch.tensor([ep_rank], dtype=torch.long, device=device) + ep_rank_tensors = [torch.zeros_like(ep_rank_tensor) for _ in range(dist.get_world_size())] + dist.all_gather(ep_rank_tensors, ep_rank_tensor) + ep_ranks = [int(t.item()) for t in ep_rank_tensors] + + if dist.get_rank() != 0: + return None + + representatives = {} + for global_rank, gathered_ep_rank in enumerate(ep_ranks): + if gathered_ep_rank in representatives: + torch.testing.assert_close(gathered[global_rank], + gathered[representatives[gathered_ep_rank]], + rtol=0, + atol=0) + else: + representatives[gathered_ep_rank] = global_rank + assert sorted(representatives) == list(range(ep_size)) + return torch.cat([gathered[representatives[rank]] for rank in range(ep_size)], dim=0).cpu() + + +def _zero_optimizer_param_state(engine, param, key): + zero_optimizer = engine.optimizer + for sub_group_id, fp16_group in enumerate(zero_optimizer.fp16_groups): + offset = 0 + for group_param in fp16_group: + partition_numel = group_param.partition_numel() + if group_param is param: + if key == "fp32": + flat_state = zero_optimizer.fp32_partitioned_groups_flat[sub_group_id] + else: + fp32_param = zero_optimizer.fp32_partitioned_groups_flat[sub_group_id] + flat_state = zero_optimizer.optimizer.state[fp32_param][key] + return flat_state.narrow(0, offset, partition_numel).detach().clone() + offset += partition_numel + param_name = engine.optimizer.param_names.get(param, "") + raise AssertionError(f"parameter {param_name} was not found in ZeRO fp16 groups") + + +def _gather_optimizer_state_for_param(engine, param, key): + local_partition = _zero_optimizer_param_state(engine, param, key).contiguous() + partition_group = getattr(param, "ds_process_group", dist.get_world_group()) + partition_world = dist.get_world_size(group=partition_group) + gathered = [torch.zeros_like(local_partition) for _ in range(partition_world)] + dist.all_gather(gathered, local_partition, group=partition_group) + full_flat = torch.cat(gathered, dim=0)[:param.ds_numel] + return full_flat.view(param.ds_shape).contiguous() + + +def _assert_router_params_match_universal(engine, universal_dir): + for param_name, param in _router_params(engine): + restored = _gather_zero_param(param).cpu() + expected = _load_universal_state(universal_dir, param_name, "fp32") + torch.testing.assert_close(restored, expected, rtol=0, atol=0) + + +def _assert_expert_params_match_universal(engine, universal_dir): + for param_name, module, param in _expert_params(engine): + local_experts = _gather_zero_param(param) + restored = _collect_by_ep_rank(local_experts, module.ep_rank, module.ep_size, engine.device) + if dist.get_rank() == 0: + expected = _load_universal_state(universal_dir, param_name, "fp32") + torch.testing.assert_close(restored, expected, rtol=0, atol=0) + + +def _assert_expert_optimizer_states_match_universal(engine, universal_dir): + nonzero_moments = {"exp_avg": False, "exp_avg_sq": False} + for param_name, module, param in _expert_params(engine): + for key in UNIVERSAL_STATE_KEYS: + local_state = _gather_optimizer_state_for_param(engine, param, key) + restored = _collect_by_ep_rank(local_state, module.ep_rank, module.ep_size, engine.device) + if dist.get_rank() == 0: + expected = _load_universal_state(universal_dir, param_name, key) + torch.testing.assert_close(restored, expected, rtol=0, atol=0) + if key in nonzero_moments and torch.count_nonzero(expected).item() > 0: + nonzero_moments[key] = True + if dist.get_rank() == 0: + assert all(nonzero_moments.values()) + dist.barrier() + + +def _assert_optimizer_step_restored(engine, universal_dir): + expected_step = _load_universal_optimizer_step(universal_dir) + steps = [] + for fp32_param in engine.optimizer.fp32_partitioned_groups_flat: + step = engine.optimizer.optimizer.state[fp32_param]["step"] + steps.append(int(step.item() if torch.is_tensor(step) else step)) + assert steps + assert expected_step > 0 + assert len(set(steps)) == 1 + assert steps[0] == expected_step + + +def _assert_topology_load_matches_universal(tmpdir, *, target_ep_size, num_experts=4, tag=TOPOLOGY_TAG): + save_dir = str(tmpdir) + universal_dir = os.path.join(save_dir, f"{tag}_universal") + config = make_autoep_integration_config(zero_stage=3, ep_size=target_ep_size) + config["checkpoint"] = {"load_universal": True} + engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(num_experts=num_experts), config=config) + engine.load_checkpoint(save_dir, tag=f"{tag}_universal") + + _assert_expert_params_match_universal(engine, universal_dir) + _assert_router_params_match_universal(engine, universal_dir) + _assert_expert_optimizer_states_match_universal(engine, universal_dir) + _assert_optimizer_step_restored(engine, universal_dir) + + losses, _ = run_training_steps(engine, num_steps=1) + assert torch.isfinite(torch.tensor(losses[0])) + engine.destroy() + @pytest.mark.parametrize("enabled", [True, False]) @pytest.mark.parametrize("include_key", [False, True]) @@ -174,6 +393,31 @@ def test_zero3_partition_native_universal_round_trip_same_topology(self, tmpdir) losses, _ = run_training_steps(reloaded_engine, num_steps=1) assert torch.isfinite(torch.tensor(losses[0])) + def test_zero3_universal_weights_only_loads_fail_fast(self, tmpdir): + seed_everything(6420) + + config = make_autoep_integration_config(zero_stage=3, ep_size=2) + engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(), config=config) + run_training_steps(engine, num_steps=2) + + save_dir = str(tmpdir) + tag = "autoep-zero3-universal-flags" + engine.save_checkpoint(save_dir, tag=tag) + _convert_checkpoint_to_universal(save_dir, tag) + + universal_config = make_autoep_integration_config(zero_stage=3, ep_size=2) + universal_config["checkpoint"] = {"load_universal": True} + no_optimizer_engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(), config=universal_config) + with pytest.raises(NotImplementedError, match="requires optimizer state"): + no_optimizer_engine.load_checkpoint(save_dir, tag=f"{tag}_universal", load_optimizer_states=False) + no_optimizer_engine.destroy() + + module_only_engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(), config=universal_config) + with pytest.raises(NotImplementedError, match="requires optimizer state"): + module_only_engine.load_checkpoint(save_dir, tag=f"{tag}_universal", load_module_only=True) + module_only_engine.destroy() + engine.destroy() + class TestAutoEPZero3UniversalCheckpoint4GPU(DistributedTest): world_size = 4 @@ -270,3 +514,49 @@ def test_zero3_partition_native_universal_round_trip_replica_groups_4gpu(self, t losses, _ = run_training_steps(reloaded_engine, num_steps=1) assert torch.isfinite(torch.tensor(losses[0])) + + +class _AutoEPTopologyBaselineWs4Ep2(DistributedFixture): + world_size = 4 + + def run(self, tmpdir): + _train_save_convert_autoep_zero3(tmpdir, tag=TOPOLOGY_TAG, ep_size=2) + + +@pytest.fixture(name="autoep_topology_baseline_ws4_ep2") +def autoep_topology_baseline_ws4_ep2_fixture(request): + _AutoEPTopologyBaselineWs4Ep2()(request) + + +class TestAutoEPZero3UniversalTopologyChange(DistributedTest): + world_size = 4 + + @pytest.mark.world_size(2) + def test_dp_world_size_4to2_fixed_ep_size(self, autoep_topology_baseline_ws4_ep2, tmpdir): + _assert_topology_load_matches_universal(tmpdir, target_ep_size=2) + + @pytest.mark.world_size(8) + def test_dp_world_size_4to8_fixed_ep_size(self, autoep_topology_baseline_ws4_ep2, tmpdir): + _assert_topology_load_matches_universal(tmpdir, target_ep_size=2) + + @pytest.mark.world_size(4) + def test_autoep_size_2to4_fixed_world_size(self, autoep_topology_baseline_ws4_ep2, tmpdir): + _assert_topology_load_matches_universal(tmpdir, target_ep_size=4) + + @pytest.mark.world_size(4) + def test_autoep_size_2to1_fixed_world_size(self, autoep_topology_baseline_ws4_ep2, tmpdir): + _assert_topology_load_matches_universal(tmpdir, target_ep_size=1) + + @pytest.mark.world_size(8) + def test_dp_world_size_4to8_and_autoep_size_2to4(self, autoep_topology_baseline_ws4_ep2, tmpdir): + _assert_topology_load_matches_universal(tmpdir, target_ep_size=4) + + @pytest.mark.world_size(4) + def test_universal_load_rejects_mismatched_target_expert_shape(self, autoep_topology_baseline_ws4_ep2, tmpdir): + save_dir = str(tmpdir) + config = make_autoep_integration_config(zero_stage=3, ep_size=2) + config["checkpoint"] = {"load_universal": True} + engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(num_experts=8), config=config) + with pytest.raises(ValueError, match="target_local_experts=4, checkpoint_local_experts=2"): + engine.load_checkpoint(save_dir, tag=f"{TOPOLOGY_TAG}_universal") + engine.destroy() From bac569d898851f1cbca72c0ae4ed3b68cc91a482 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 12 Jun 2026 17:31:35 -0700 Subject: [PATCH 13/19] Fix AutoEP topology fixture registration Signed-off-by: Masahiro Tanaka --- tests/unit/v1/moe/test_autoep_checkpoint.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/unit/v1/moe/test_autoep_checkpoint.py b/tests/unit/v1/moe/test_autoep_checkpoint.py index d809b608e7bb..960c9dfef41f 100644 --- a/tests/unit/v1/moe/test_autoep_checkpoint.py +++ b/tests/unit/v1/moe/test_autoep_checkpoint.py @@ -516,18 +516,13 @@ def test_zero3_partition_native_universal_round_trip_replica_groups_4gpu(self, t assert torch.isfinite(torch.tensor(losses[0])) -class _AutoEPTopologyBaselineWs4Ep2(DistributedFixture): +class autoep_topology_baseline_ws4_ep2(DistributedFixture): world_size = 4 def run(self, tmpdir): _train_save_convert_autoep_zero3(tmpdir, tag=TOPOLOGY_TAG, ep_size=2) -@pytest.fixture(name="autoep_topology_baseline_ws4_ep2") -def autoep_topology_baseline_ws4_ep2_fixture(request): - _AutoEPTopologyBaselineWs4Ep2()(request) - - class TestAutoEPZero3UniversalTopologyChange(DistributedTest): world_size = 4 From 6e7400d2c1518f1cbe022a83a9d5518a0d8bd936 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 12 Jun 2026 18:52:46 -0700 Subject: [PATCH 14/19] Register AutoEP topology baseline as an explicit pytest fixture The DistributedFixture subclass alone was not collected as a fixture, so all six topology-change tests failed in setup with 'fixture not found'. Wrap it in an explicit @pytest.fixture function. Signed-off-by: Masahiro Tanaka --- tests/unit/v1/moe/test_autoep_checkpoint.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/unit/v1/moe/test_autoep_checkpoint.py b/tests/unit/v1/moe/test_autoep_checkpoint.py index 960c9dfef41f..c12634b46cb1 100644 --- a/tests/unit/v1/moe/test_autoep_checkpoint.py +++ b/tests/unit/v1/moe/test_autoep_checkpoint.py @@ -516,13 +516,18 @@ def test_zero3_partition_native_universal_round_trip_replica_groups_4gpu(self, t assert torch.isfinite(torch.tensor(losses[0])) -class autoep_topology_baseline_ws4_ep2(DistributedFixture): +class _AutoEPTopologyBaselineWs4Ep2(DistributedFixture): world_size = 4 def run(self, tmpdir): _train_save_convert_autoep_zero3(tmpdir, tag=TOPOLOGY_TAG, ep_size=2) +@pytest.fixture +def autoep_topology_baseline_ws4_ep2(request): + _AutoEPTopologyBaselineWs4Ep2()(request) + + class TestAutoEPZero3UniversalTopologyChange(DistributedTest): world_size = 4 From fbb1bf91909aef1741d6a647ce39576c804eb45d Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 12 Jun 2026 20:19:54 -0700 Subject: [PATCH 15/19] Fix AutoEP topology universal router check Signed-off-by: Masahiro Tanaka --- tests/unit/v1/moe/test_autoep_checkpoint.py | 26 +++++++++++++++------ 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/unit/v1/moe/test_autoep_checkpoint.py b/tests/unit/v1/moe/test_autoep_checkpoint.py index c12634b46cb1..427c7277a22c 100644 --- a/tests/unit/v1/moe/test_autoep_checkpoint.py +++ b/tests/unit/v1/moe/test_autoep_checkpoint.py @@ -49,12 +49,24 @@ def _convert_checkpoint_to_universal(save_dir, tag): return universal_dir -def _load_universal_state(universal_dir, param_name, key): - from deepspeed.checkpoint.constants import PARAM - +def _load_universal_file(universal_dir, param_name, key): return torch.load(os.path.join(universal_dir, "zero", param_name, f"{key}.pt"), map_location="cpu", - weights_only=False)[PARAM] + weights_only=False) + + +def _load_universal_dense_state(universal_dir, param_name, key): + state = _load_universal_file(universal_dir, param_name, key) + assert torch.is_tensor(state), f"expected raw tensor state for dense ZeRO-3 parameter {param_name}/{key}" + return state + + +def _load_universal_expert_state(universal_dir, param_name, key): + from deepspeed.checkpoint.constants import PARAM + + state = _load_universal_file(universal_dir, param_name, key) + assert isinstance(state, dict), f"expected metadata dict for AutoEP expert parameter {param_name}/{key}" + return state[PARAM] def _load_universal_optimizer_step(universal_dir): @@ -187,7 +199,7 @@ def _gather_optimizer_state_for_param(engine, param, key): def _assert_router_params_match_universal(engine, universal_dir): for param_name, param in _router_params(engine): restored = _gather_zero_param(param).cpu() - expected = _load_universal_state(universal_dir, param_name, "fp32") + expected = _load_universal_dense_state(universal_dir, param_name, "fp32").view_as(restored) torch.testing.assert_close(restored, expected, rtol=0, atol=0) @@ -196,7 +208,7 @@ def _assert_expert_params_match_universal(engine, universal_dir): local_experts = _gather_zero_param(param) restored = _collect_by_ep_rank(local_experts, module.ep_rank, module.ep_size, engine.device) if dist.get_rank() == 0: - expected = _load_universal_state(universal_dir, param_name, "fp32") + expected = _load_universal_expert_state(universal_dir, param_name, "fp32") torch.testing.assert_close(restored, expected, rtol=0, atol=0) @@ -207,7 +219,7 @@ def _assert_expert_optimizer_states_match_universal(engine, universal_dir): local_state = _gather_optimizer_state_for_param(engine, param, key) restored = _collect_by_ep_rank(local_state, module.ep_rank, module.ep_size, engine.device) if dist.get_rank() == 0: - expected = _load_universal_state(universal_dir, param_name, key) + expected = _load_universal_expert_state(universal_dir, param_name, key) torch.testing.assert_close(restored, expected, rtol=0, atol=0) if key in nonzero_moments and torch.count_nonzero(expected).item() > 0: nonzero_moments[key] = True From 18569fc05c819aee6ae8cbb7e5c908e216590372 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 12 Jun 2026 20:42:01 -0700 Subject: [PATCH 16/19] Update config-json AutoEP ZeRO-3 universal topology-change note Bring the config.md constraint bullet in line with autoep.rst and the universal checkpointing tutorial: ZeRO-3 AutoEP optimizer-including universal loads can resume at a different data-parallel world size and/or autoep_size; only weights-only/module-only universal loads remain unsupported. Signed-off-by: Masahiro Tanaka --- docs/_pages/config-json.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index 0bd9d0238d8c..6ae2b55f1f1e 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -1077,7 +1077,7 @@ Use a built-in preset but override specific naming/weight fields for a fine-tune - `autoep_size` must divide `num_experts` for all detected MoE layers - AutoEP currently cannot be combined with AutoTP (`tensor_parallel.autotp_size > 1`); support is planned as follow-up work - AutoEP with ZeRO Stage 3 is supported only without AutoTP, sequence parallelism, MiCS, hpZeRO secondary tensor groups, non-1 `expert_tensor_parallel_size`, or quantized gradients -- ZeRO Stage 3 saves AutoEP checkpoints partition-natively and supports same-topology save/load, module-only loads, optimizer-state-skipping loads, and universal checkpoint conversion with same-topology universal load; topology-changing loads (different `autoep_size` or data-parallel world size) are not supported yet +- ZeRO Stage 3 saves AutoEP checkpoints partition-natively and supports same-topology save/load, module-only loads, optimizer-state-skipping loads, and universal checkpoint conversion. Optimizer-including universal loads can resume at a different data-parallel world size, a different `autoep_size`, or both (when the target `autoep_size` divides the expert count); weights-only/module-only universal loads for Stage 3 AutoEP are not supported yet ### Logging From 50f65163c3f808ddf5c591386f8c0cf08be72632 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 12 Jun 2026 21:03:41 -0700 Subject: [PATCH 17/19] Support AutoEP universal module-only loads Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/engine.py | 6 -- deepspeed/runtime/zero/stage3.py | 39 ++++++++- docs/_pages/config-json.md | 2 +- docs/_tutorials/universal-checkpointing.md | 25 +++--- docs/code-docs/source/autoep.rst | 8 +- tests/unit/v1/moe/test_autoep_checkpoint.py | 88 +++++++++++++++++---- 6 files changed, 127 insertions(+), 41 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index b6934e0373ee..9b23bc5a47d6 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -4061,12 +4061,6 @@ def _load_checkpoint(self, if checkpoint is None: return None, None - if (self.load_universal_checkpoint() and self.zero_optimization_partition_weights() - and self._has_autoep_layers() and (load_module_only or not load_optimizer_states)): - raise NotImplementedError("AutoEP ZeRO-3 universal checkpoint load currently requires optimizer state. " - "Set load_optimizer_states=True and load_module_only=False; weights-only " - "AutoEP ZeRO-3 universal loads are not supported yet.") - fetch_z3_params = False z3_params_to_fetch = None autoep_partitioned_experts = False diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index ac1341d55641..10953189dde1 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -3313,7 +3313,10 @@ def load_state_dict(self, # self.persistent_parameters[0].all_gather(self.persistent_parameters) # this will be done in checkpoint_event_epilogue() so remove it to prevent double all_gather def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights): - self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder) + if load_optimizer_states: + self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder) + else: + self.load_module_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder) def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir): """ Load optimizer and model states from the checkpoint directory. """ @@ -3382,6 +3385,40 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir): for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): partitioned_param.data = q.data + def load_module_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir): + """Load module parameter partitions from a ZeRO-3 universal checkpoint.""" + checkpoint_dir = os.path.join(checkpoint_dir, "zero") + for sub_group_id, fp16_group in enumerate(self.fp16_groups): + fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] + if fp16_param is None: + raise RuntimeError("ZeRO-3 universal module-only checkpoint load requires available parameter " + f"partitions for subgroup {sub_group_id}.") + module_param_partition = torch.zeros_like(fp16_param) + offset = 0 + for param in fp16_group: + if param not in self.param_names: + raise ValueError("failed to find optimizer param in named params") + param_name = self.param_names[param] + param_partition = self.load_hp_checkpoint_state(os.path.join(checkpoint_dir, param_name), + "fp32", + param=param) + numel = param_partition.numel() + module_param_partition.narrow(0, offset, numel).copy_( + param_partition.to(device=module_param_partition.device, dtype=module_param_partition.dtype)) + offset += numel + fp16_param.data.copy_(module_param_partition) + + # Keep fp32 master weights consistent for warm-start/fine-tuning cases + # that intentionally skipped optimizer state. + self.refresh_fp32_params() + + for sub_group_id in range(len(self.fp16_partitioned_groups_flat)): + updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): + partitioned_param.data = q.data + def _load_global_state_stage3(self, sd): self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler) self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale) diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index 6ae2b55f1f1e..6d395ae33370 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -1077,7 +1077,7 @@ Use a built-in preset but override specific naming/weight fields for a fine-tune - `autoep_size` must divide `num_experts` for all detected MoE layers - AutoEP currently cannot be combined with AutoTP (`tensor_parallel.autotp_size > 1`); support is planned as follow-up work - AutoEP with ZeRO Stage 3 is supported only without AutoTP, sequence parallelism, MiCS, hpZeRO secondary tensor groups, non-1 `expert_tensor_parallel_size`, or quantized gradients -- ZeRO Stage 3 saves AutoEP checkpoints partition-natively and supports same-topology save/load, module-only loads, optimizer-state-skipping loads, and universal checkpoint conversion. Optimizer-including universal loads can resume at a different data-parallel world size, a different `autoep_size`, or both (when the target `autoep_size` divides the expert count); weights-only/module-only universal loads for Stage 3 AutoEP are not supported yet +- ZeRO Stage 3 saves AutoEP checkpoints partition-natively and supports same-topology save/load, module-only loads, optimizer-state-skipping loads, and universal checkpoint conversion. Universal loads can resume at a different data-parallel world size, a different `autoep_size`, or both (when the target `autoep_size` divides the expert count), including weights-only/module-only loads from the converted `fp32.pt` parameter files ### Logging diff --git a/docs/_tutorials/universal-checkpointing.md b/docs/_tutorials/universal-checkpointing.md index 9037f5dc7309..352549349a37 100644 --- a/docs/_tutorials/universal-checkpointing.md +++ b/docs/_tutorials/universal-checkpointing.md @@ -104,14 +104,14 @@ consolidates each expert parameter from its partition shards across the expert replica group, and writes the same `zero/` parameter layout as the other stages. ZeRO Stage 3 AutoEP also supports module-only loads (`load_module_only=True`) and optimizer-state-free loads -(`load_optimizer_states=False`) from the partition shards. After conversion to -Universal Checkpoint format, ZeRO Stage 3 AutoEP can load optimizer-including -checkpoints at a different data-parallel world size, a different `autoep_size`, -or both, as long as the target AutoEP topology is valid for the same model -parameter names and expert count. Weights-only/module-only universal loads for -ZeRO Stage 3 AutoEP are not supported yet and fail fast. `zero_to_fp32.py` -consolidation is not supported for partition-native AutoEP checkpoints (the -script raises `NotImplementedError`; use `ds_to_universal.py` instead). +(`load_optimizer_states=False`) from both partition shards and Universal +Checkpoint format. After conversion to Universal Checkpoint format, ZeRO Stage 3 +AutoEP can load optimizer-including or weights-only/module-only checkpoints at a +different data-parallel world size, a different `autoep_size`, or both, as long +as the target AutoEP topology is valid for the same model parameter names and +expert count. `zero_to_fp32.py` consolidation is not supported for +partition-native AutoEP checkpoints (the script raises `NotImplementedError`; +use `ds_to_universal.py` instead). During conversion, `ds_to_universal.py` reads `ds_autoep_layers` or the legacy `autoep_layers` key, consolidates each AutoEP layer's routed expert state (the @@ -125,10 +125,11 @@ such as `exp_avg.pt` and `exp_avg_sq.pt` next to the converted parameter. Regular AutoEP checkpoint load requires the target run to use the same `autoep_size` as the save run. To change `autoep_size` or data-parallel world size for the same AutoEP-detected model topology, convert the checkpoint to -Universal format and load the Universal checkpoint with optimizer state enabled. -For ZeRO Stage 3 AutoEP this path reslices routed expert parameters and their -Adam `fp32`, `exp_avg`, and `exp_avg_sq` states using the target runtime -topology. +Universal format and load the Universal checkpoint. For ZeRO Stage 3 AutoEP, +optimizer-including loads reslice routed expert parameters and their Adam +`fp32`, `exp_avg`, and `exp_avg_sq` states using the target runtime topology; +weights-only/module-only loads reslice routed expert parameters and standard +parameters from the universal `fp32.pt` files without requiring optimizer state. In the Universal Checkpoint load path, AutoEP routed experts are restored from the `zero/` parameter layout rather than from the regular diff --git a/docs/code-docs/source/autoep.rst b/docs/code-docs/source/autoep.rst index 38190fee66bf..b7b7fa293d82 100644 --- a/docs/code-docs/source/autoep.rst +++ b/docs/code-docs/source/autoep.rst @@ -58,8 +58,9 @@ optimizer-state-free loads (``load_optimizer_states=False``), and Universal Checkpoint conversion. Optimizer-including Universal Checkpoint loads can resume with a different data-parallel world size, a different ``autoep_size``, or both, when the target ``autoep_size`` divides the model's expert count. -Weights-only/module-only Universal Checkpoint loads are not supported for -Stage 3 AutoEP yet. +Weights-only/module-only Universal Checkpoint loads use the converted +``fp32.pt`` parameter files and support the same data-parallel and +``autoep_size`` topology changes. **Usage:** @@ -97,8 +98,7 @@ Stage 3 AutoEP yet. - Regular checkpoint save/load requires matching ``autoep_size``. To change ``autoep_size`` or data-parallel world size across runs for the same AutoEP-detected model topology, convert the checkpoint to Universal - Checkpoint format and load it with ``checkpoint.load_universal`` and - optimizer state enabled; see the + Checkpoint format and load it with ``checkpoint.load_universal``; see the `Universal Checkpointing tutorial `__ for the detailed flow and constraints. - DeepSeek-V2 and DeepSeek-V3 AutoEP do not support load-balance expert bias diff --git a/tests/unit/v1/moe/test_autoep_checkpoint.py b/tests/unit/v1/moe/test_autoep_checkpoint.py index 427c7277a22c..ea7dd7f72c37 100644 --- a/tests/unit/v1/moe/test_autoep_checkpoint.py +++ b/tests/unit/v1/moe/test_autoep_checkpoint.py @@ -137,6 +137,14 @@ def _router_params(engine): yield f"{module_prefix}router.{router_name}", param +def _shared_params(engine): + routed_expert_names = {param_name for param_name, _, _ in _expert_params(engine)} + router_names = {param_name for param_name, _ in _router_params(engine)} + for param_name, param in engine.module.named_parameters(): + if param_name not in routed_expert_names and param_name not in router_names: + yield param_name, param + + def _gather_zero_param(param): with deepspeed.zero.GatheredParameters([param]): return param.detach().clone() @@ -203,6 +211,13 @@ def _assert_router_params_match_universal(engine, universal_dir): torch.testing.assert_close(restored, expected, rtol=0, atol=0) +def _assert_shared_params_match_universal(engine, universal_dir): + for param_name, param in _shared_params(engine): + restored = _gather_zero_param(param).cpu() + expected = _load_universal_dense_state(universal_dir, param_name, "fp32").view_as(restored) + torch.testing.assert_close(restored, expected, rtol=0, atol=0) + + def _assert_expert_params_match_universal(engine, universal_dir): for param_name, module, param in _expert_params(engine): local_experts = _gather_zero_param(param) @@ -228,6 +243,12 @@ def _assert_expert_optimizer_states_match_universal(engine, universal_dir): dist.barrier() +def _assert_module_params_match_universal(engine, universal_dir): + _assert_expert_params_match_universal(engine, universal_dir) + _assert_router_params_match_universal(engine, universal_dir) + _assert_shared_params_match_universal(engine, universal_dir) + + def _assert_optimizer_step_restored(engine, universal_dir): expected_step = _load_universal_optimizer_step(universal_dir) steps = [] @@ -240,19 +261,32 @@ def _assert_optimizer_step_restored(engine, universal_dir): assert steps[0] == expected_step -def _assert_topology_load_matches_universal(tmpdir, *, target_ep_size, num_experts=4, tag=TOPOLOGY_TAG): +def _assert_forward_runs(engine): + with torch.no_grad(): + output = engine(torch.randn(1, 8, 64, device=engine.device)) + assert torch.isfinite(output.float()).all() + + +def _assert_topology_load_matches_universal(tmpdir, + *, + target_ep_size, + num_experts=4, + tag=TOPOLOGY_TAG, + load_kwargs=None, + check_optimizer_states=True): save_dir = str(tmpdir) universal_dir = os.path.join(save_dir, f"{tag}_universal") config = make_autoep_integration_config(zero_stage=3, ep_size=target_ep_size) config["checkpoint"] = {"load_universal": True} engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(num_experts=num_experts), config=config) - engine.load_checkpoint(save_dir, tag=f"{tag}_universal") + engine.load_checkpoint(save_dir, tag=f"{tag}_universal", **(load_kwargs or {})) - _assert_expert_params_match_universal(engine, universal_dir) - _assert_router_params_match_universal(engine, universal_dir) - _assert_expert_optimizer_states_match_universal(engine, universal_dir) - _assert_optimizer_step_restored(engine, universal_dir) + _assert_module_params_match_universal(engine, universal_dir) + if check_optimizer_states: + _assert_expert_optimizer_states_match_universal(engine, universal_dir) + _assert_optimizer_step_restored(engine, universal_dir) + _assert_forward_runs(engine) losses, _ = run_training_steps(engine, num_steps=1) assert torch.isfinite(torch.tensor(losses[0])) engine.destroy() @@ -405,7 +439,7 @@ def test_zero3_partition_native_universal_round_trip_same_topology(self, tmpdir) losses, _ = run_training_steps(reloaded_engine, num_steps=1) assert torch.isfinite(torch.tensor(losses[0])) - def test_zero3_universal_weights_only_loads_fail_fast(self, tmpdir): + def _assert_zero3_universal_weights_only_load(self, tmpdir, load_kwargs): seed_everything(6420) config = make_autoep_integration_config(zero_stage=3, ep_size=2) @@ -415,21 +449,27 @@ def test_zero3_universal_weights_only_loads_fail_fast(self, tmpdir): save_dir = str(tmpdir) tag = "autoep-zero3-universal-flags" engine.save_checkpoint(save_dir, tag=tag) - _convert_checkpoint_to_universal(save_dir, tag) + universal_dir = _convert_checkpoint_to_universal(save_dir, tag) universal_config = make_autoep_integration_config(zero_stage=3, ep_size=2) universal_config["checkpoint"] = {"load_universal": True} - no_optimizer_engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(), config=universal_config) - with pytest.raises(NotImplementedError, match="requires optimizer state"): - no_optimizer_engine.load_checkpoint(save_dir, tag=f"{tag}_universal", load_optimizer_states=False) - no_optimizer_engine.destroy() - - module_only_engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(), config=universal_config) - with pytest.raises(NotImplementedError, match="requires optimizer state"): - module_only_engine.load_checkpoint(save_dir, tag=f"{tag}_universal", load_module_only=True) - module_only_engine.destroy() + reloaded_engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(), config=universal_config) + reloaded_engine.load_checkpoint(save_dir, tag=f"{tag}_universal", **load_kwargs) + + _assert_module_params_match_universal(reloaded_engine, universal_dir) + _assert_forward_runs(reloaded_engine) + losses, _ = run_training_steps(reloaded_engine, num_steps=1) + assert torch.isfinite(torch.tensor(losses[0])) + + reloaded_engine.destroy() engine.destroy() + def test_zero3_universal_load_optimizer_states_false_same_topology(self, tmpdir): + self._assert_zero3_universal_weights_only_load(tmpdir, {"load_optimizer_states": False}) + + def test_zero3_universal_module_only_same_topology(self, tmpdir): + self._assert_zero3_universal_weights_only_load(tmpdir, {"load_module_only": True}) + class TestAutoEPZero3UniversalCheckpoint4GPU(DistributedTest): world_size = 4 @@ -563,6 +603,20 @@ def test_autoep_size_2to1_fixed_world_size(self, autoep_topology_baseline_ws4_ep def test_dp_world_size_4to8_and_autoep_size_2to4(self, autoep_topology_baseline_ws4_ep2, tmpdir): _assert_topology_load_matches_universal(tmpdir, target_ep_size=4) + @pytest.mark.world_size(2) + def test_module_only_dp_world_size_4to2_fixed_ep_size(self, autoep_topology_baseline_ws4_ep2, tmpdir): + _assert_topology_load_matches_universal(tmpdir, + target_ep_size=2, + load_kwargs={"load_module_only": True}, + check_optimizer_states=False) + + @pytest.mark.world_size(4) + def test_load_optimizer_states_false_autoep_size_2to4(self, autoep_topology_baseline_ws4_ep2, tmpdir): + _assert_topology_load_matches_universal(tmpdir, + target_ep_size=4, + load_kwargs={"load_optimizer_states": False}, + check_optimizer_states=False) + @pytest.mark.world_size(4) def test_universal_load_rejects_mismatched_target_expert_shape(self, autoep_topology_baseline_ws4_ep2, tmpdir): save_dir = str(tmpdir) From f2d422d436e0f8a7e99b517efbaaea35366ce317 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 12 Jun 2026 21:26:06 -0700 Subject: [PATCH 18/19] Preserve fp32 masters for AutoEP module-only loads Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/zero/stage3.py | 16 ++++-- tests/unit/v1/moe/test_autoep_checkpoint.py | 62 ++++++++++++++++++++- 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 10953189dde1..e0856570162c 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -3316,7 +3316,7 @@ def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, l if load_optimizer_states: self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder) else: - self.load_module_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder) + self.load_module_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, load_from_fp32_weights) def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir): """ Load optimizer and model states from the checkpoint directory. """ @@ -3385,7 +3385,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir): for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): partitioned_param.data = q.data - def load_module_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir): + def load_module_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, load_from_fp32_weights): """Load module parameter partitions from a ZeRO-3 universal checkpoint.""" checkpoint_dir = os.path.join(checkpoint_dir, "zero") for sub_group_id, fp16_group in enumerate(self.fp16_groups): @@ -3393,7 +3393,9 @@ def load_module_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir if fp16_param is None: raise RuntimeError("ZeRO-3 universal module-only checkpoint load requires available parameter " f"partitions for subgroup {sub_group_id}.") + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] module_param_partition = torch.zeros_like(fp16_param) + fp32_param_partition = torch.zeros_like(fp32_param) if load_from_fp32_weights else None offset = 0 for param in fp16_group: if param not in self.param_names: @@ -3405,12 +3407,16 @@ def load_module_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir numel = param_partition.numel() module_param_partition.narrow(0, offset, numel).copy_( param_partition.to(device=module_param_partition.device, dtype=module_param_partition.dtype)) + if load_from_fp32_weights: + fp32_param_partition.narrow(0, offset, numel).copy_( + param_partition.to(device=fp32_param_partition.device, dtype=fp32_param_partition.dtype)) offset += numel fp16_param.data.copy_(module_param_partition) + if load_from_fp32_weights: + fp32_param.data.copy_(fp32_param_partition) - # Keep fp32 master weights consistent for warm-start/fine-tuning cases - # that intentionally skipped optimizer state. - self.refresh_fp32_params() + if not load_from_fp32_weights: + self.refresh_fp32_params() for sub_group_id in range(len(self.fp16_partitioned_groups_flat)): updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], diff --git a/tests/unit/v1/moe/test_autoep_checkpoint.py b/tests/unit/v1/moe/test_autoep_checkpoint.py index ea7dd7f72c37..e3328b9c4cf1 100644 --- a/tests/unit/v1/moe/test_autoep_checkpoint.py +++ b/tests/unit/v1/moe/test_autoep_checkpoint.py @@ -20,7 +20,9 @@ MockMoETransformer, UNSUPPORTED_LOAD_BALANCE_VALUES, assert_load_balance_coeff_rejection_message, + engine_input_dtype, init_autoep_engine, + make_autoep_config, make_autoep_integration_config, run_training_steps, seed_everything, @@ -243,6 +245,29 @@ def _assert_expert_optimizer_states_match_universal(engine, universal_dir): dist.barrier() +def _assert_expert_fp32_master_params_match_universal(engine, universal_dir): + for param_name, module, param in _expert_params(engine): + local_state = _gather_optimizer_state_for_param(engine, param, "fp32") + restored = _collect_by_ep_rank(local_state, module.ep_rank, module.ep_size, engine.device) + if dist.get_rank() == 0: + expected = _load_universal_expert_state(universal_dir, param_name, "fp32") + torch.testing.assert_close(restored, expected, rtol=0, atol=0) + dist.barrier() + + +def _assert_dense_fp32_master_params_match_universal(engine, universal_dir, param_iter): + for param_name, param in param_iter: + restored = _gather_optimizer_state_for_param(engine, param, "fp32").cpu() + expected = _load_universal_dense_state(universal_dir, param_name, "fp32").view_as(restored) + torch.testing.assert_close(restored, expected, rtol=0, atol=0) + + +def _assert_fp32_master_params_match_universal(engine, universal_dir): + _assert_expert_fp32_master_params_match_universal(engine, universal_dir) + _assert_dense_fp32_master_params_match_universal(engine, universal_dir, _router_params(engine)) + _assert_dense_fp32_master_params_match_universal(engine, universal_dir, _shared_params(engine)) + + def _assert_module_params_match_universal(engine, universal_dir): _assert_expert_params_match_universal(engine, universal_dir) _assert_router_params_match_universal(engine, universal_dir) @@ -263,10 +288,21 @@ def _assert_optimizer_step_restored(engine, universal_dir): def _assert_forward_runs(engine): with torch.no_grad(): - output = engine(torch.randn(1, 8, 64, device=engine.device)) + output = engine(torch.randn(1, 8, 64, device=engine.device, dtype=engine_input_dtype(engine))) assert torch.isfinite(output.float()).all() +def _run_training_steps_with_engine_input_dtype(engine, num_steps=2, seq_len=8, hidden_dim=64): + losses = [] + for _ in range(num_steps): + x = torch.randn(1, seq_len, hidden_dim, device=engine.device, dtype=engine_input_dtype(engine)) + loss = engine(x).mean() + engine.backward(loss) + engine.step() + losses.append(loss.item()) + return losses + + def _assert_topology_load_matches_universal(tmpdir, *, target_ep_size, @@ -470,6 +506,30 @@ def test_zero3_universal_load_optimizer_states_false_same_topology(self, tmpdir) def test_zero3_universal_module_only_same_topology(self, tmpdir): self._assert_zero3_universal_weights_only_load(tmpdir, {"load_module_only": True}) + @pytest.mark.parametrize("load_kwargs", [{"load_optimizer_states": False}, {"load_module_only": True}]) + def test_zero3_universal_weights_only_preserves_fp32_master_weights(self, tmpdir, load_kwargs): + seed_everything(6421) + + config = make_autoep_config(zero_stage=3, ep_size=2) + engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(), config=config) + _run_training_steps_with_engine_input_dtype(engine, num_steps=2) + + save_dir = str(tmpdir) + tag = "autoep-zero3-universal-fp32-master" + engine.save_checkpoint(save_dir, tag=tag) + universal_dir = _convert_checkpoint_to_universal(save_dir, tag) + + universal_config = make_autoep_config(zero_stage=3, ep_size=2) + universal_config["checkpoint"] = {"load_universal": True} + reloaded_engine, _, _, _ = deepspeed.initialize(model=MockMoETransformer(), config=universal_config) + reloaded_engine.load_checkpoint(save_dir, tag=f"{tag}_universal", **load_kwargs) + + _assert_fp32_master_params_match_universal(reloaded_engine, universal_dir) + _assert_forward_runs(reloaded_engine) + + reloaded_engine.destroy() + engine.destroy() + class TestAutoEPZero3UniversalCheckpoint4GPU(DistributedTest): world_size = 4 From a90deeced76fee325f9b8728ff1f9ea664030b34 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 22 Jun 2026 22:37:30 -0700 Subject: [PATCH 19/19] Address AutoEP ZeRO-3 review comments Signed-off-by: Masahiro Tanaka --- deepspeed/checkpoint/autoep_zero3_metadata.py | 109 ++++++++++ deepspeed/checkpoint/ds_to_universal.py | 85 +------- deepspeed/runtime/engine.py | 201 ++++-------------- tests/unit/v1/moe/test_autoep_grad_parity.py | 53 ++++- 4 files changed, 212 insertions(+), 236 deletions(-) create mode 100644 deepspeed/checkpoint/autoep_zero3_metadata.py diff --git a/deepspeed/checkpoint/autoep_zero3_metadata.py b/deepspeed/checkpoint/autoep_zero3_metadata.py new file mode 100644 index 000000000000..f854d17c2738 --- /dev/null +++ b/deepspeed/checkpoint/autoep_zero3_metadata.py @@ -0,0 +1,109 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Shared validation for AutoEP ZeRO-3 checkpoint metadata.""" + +from deepspeed.checkpoint.constants import ( + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY, + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY, + AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT, +) + +AUTOEP_METADATA_REQUIRED_FIELDS = frozenset({ + 'moe_layer_id', + 'module_path', + 'num_experts', + 'num_local_experts', + 'ep_size', + 'expert_key_prefix', +}) + +AUTOEP_ZERO3_PARTITIONED_METADATA_FIELDS = frozenset({ + AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY, + 'ep_group_name', + 'ep_rank', + 'expert_data_parallel_rank', + 'expert_data_parallel_world_size', + 'global_expert_start', + 'global_expert_end', +}) + + +def is_autoep_zero3_partitioned_entry(entry): + return (isinstance(entry, dict) + and entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) == AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT) + + +def validate_autoep_zero3_partitioned_metadata(autoep_metadata, + require_partitioned=True, + expected_expert_prefixes=None, + version_context="This DeepSpeed build"): + if not isinstance(autoep_metadata, list): + raise RuntimeError(f"ds_autoep_layers metadata is malformed: expected list, got " + f"{type(autoep_metadata).__name__}") + + seen_layer_ids = set() + seen_prefixes = set() + partitioned_count = 0 + + for entry in autoep_metadata: + if not isinstance(entry, dict): + raise RuntimeError(f"ds_autoep_layers entry is malformed: expected dict, got " + f"{type(entry).__name__}") + missing = AUTOEP_METADATA_REQUIRED_FIELDS - entry.keys() + if missing: + raise RuntimeError(f"ds_autoep_layers entry is invalid: missing fields {sorted(missing)}") + + layer_id = entry['moe_layer_id'] + if layer_id in seen_layer_ids: + raise RuntimeError(f"ds_autoep_layers metadata has duplicate moe_layer_id: {layer_id}") + seen_layer_ids.add(layer_id) + + prefix = entry['expert_key_prefix'] + if prefix in seen_prefixes: + raise RuntimeError(f"ds_autoep_layers metadata has duplicate expert_key_prefix: {prefix}") + seen_prefixes.add(prefix) + + if not is_autoep_zero3_partitioned_entry(entry): + continue + + missing = AUTOEP_ZERO3_PARTITIONED_METADATA_FIELDS - entry.keys() + if missing: + raise RuntimeError(f"AutoEP ZeRO-3 checkpoint metadata is invalid: missing fields {sorted(missing)}") + version = entry[AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY] + if version != AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION: + raise RuntimeError("Unsupported AutoEP ZeRO-3 checkpoint format version: " + f"{version}. {version_context} supports version " + f"{AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION}.") + + num_experts = entry['num_experts'] + num_local_experts = entry['num_local_experts'] + ep_size = entry['ep_size'] + if num_local_experts * ep_size != num_experts: + raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata is inconsistent: " + f"num_local_experts={num_local_experts}, ep_size={ep_size}, " + f"num_experts={num_experts}") + + expected_start = entry['ep_rank'] * num_local_experts + expected_end = expected_start + num_local_experts + if entry['global_expert_start'] != expected_start or entry['global_expert_end'] != expected_end: + raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata has inconsistent global expert range: " + f"got [{entry['global_expert_start']}, {entry['global_expert_end']}), " + f"expected [{expected_start}, {expected_end})") + + if expected_expert_prefixes is not None: + module_path = entry['module_path'] + if module_path not in expected_expert_prefixes: + raise RuntimeError(f"AutoEP ZeRO-3 checkpoint metadata references missing module: {module_path}") + expected_prefix = expected_expert_prefixes[module_path] + if prefix != expected_prefix: + raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata has unexpected expert key prefix: " + f"got {prefix}, expected {expected_prefix}") + + partitioned_count += 1 + + if require_partitioned and partitioned_count == 0: + raise RuntimeError("AutoEP ZeRO-3 partition-native checkpoint metadata was expected but no " + "partitioned AutoEP layer entries were found") diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py index 6ba90a14fa67..d0573a85906d 100755 --- a/deepspeed/checkpoint/ds_to_universal.py +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -45,15 +45,15 @@ PARAMETER_WITH_SUB_PARAMS, AUTOEP_LAYERS_KEY, AUTOEP_LAYERS_KEY_LEGACY, - AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION, - AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY, - AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY, - AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT, EP_IS_EXPERT_PARAM, EP_NUM_EXPERTS, EXPERT_PARAMETER_PATTERNS, SubparamShape, ) +from deepspeed.checkpoint.autoep_zero3_metadata import ( + is_autoep_zero3_partitioned_entry, + validate_autoep_zero3_partitioned_metadata, +) def parse_arguments(): @@ -449,80 +449,13 @@ def _uses_zero3_partitioned_autoep_metadata(autoep_metadata): if not isinstance(autoep_metadata, list): return False _validate_zero3_partitioned_autoep_metadata(autoep_metadata, require_partitioned=False) - return any( - isinstance(entry, dict) - and entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) == AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT - for entry in autoep_metadata) + return any(is_autoep_zero3_partitioned_entry(entry) for entry in autoep_metadata) def _validate_zero3_partitioned_autoep_metadata(autoep_metadata, require_partitioned=True): - if not isinstance(autoep_metadata, list): - raise RuntimeError(f"ds_autoep_layers metadata is malformed: expected list, got " - f"{type(autoep_metadata).__name__}") - - required_fields = { - 'moe_layer_id', - 'module_path', - 'num_experts', - 'num_local_experts', - 'ep_size', - 'expert_key_prefix', - } - partitioned_fields = { - AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY, - 'ep_rank', - 'expert_data_parallel_rank', - 'expert_data_parallel_world_size', - 'global_expert_start', - 'global_expert_end', - } - seen_layer_ids = set() - seen_prefixes = set() - partitioned_count = 0 - - for entry in autoep_metadata: - if not isinstance(entry, dict): - raise RuntimeError(f"ds_autoep_layers entry is malformed: expected dict, got {type(entry).__name__}") - missing = required_fields - entry.keys() - if missing: - raise RuntimeError(f"ds_autoep_layers entry is invalid: missing fields {sorted(missing)}") - layer_id = entry['moe_layer_id'] - if layer_id in seen_layer_ids: - raise RuntimeError(f"ds_autoep_layers metadata has duplicate moe_layer_id: {layer_id}") - seen_layer_ids.add(layer_id) - prefix = entry['expert_key_prefix'] - if prefix in seen_prefixes: - raise RuntimeError(f"ds_autoep_layers metadata has duplicate expert_key_prefix: {prefix}") - seen_prefixes.add(prefix) - - checkpoint_format = entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) - if checkpoint_format is None: - continue - if checkpoint_format != AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT: - continue - - missing = partitioned_fields - entry.keys() - if missing: - raise RuntimeError(f"AutoEP ZeRO-3 checkpoint metadata is invalid: missing fields {sorted(missing)}") - version = entry[AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY] - if version != AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION: - raise RuntimeError("Unsupported AutoEP ZeRO-3 checkpoint format version: " - f"{version}. This converter supports version " - f"{AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION}.") - if entry['num_local_experts'] * entry['ep_size'] != entry['num_experts']: - raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata is inconsistent: " - f"num_local_experts={entry['num_local_experts']}, ep_size={entry['ep_size']}, " - f"num_experts={entry['num_experts']}") - expected_start = entry['ep_rank'] * entry['num_local_experts'] - expected_end = expected_start + entry['num_local_experts'] - if entry['global_expert_start'] != expected_start or entry['global_expert_end'] != expected_end: - raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata has inconsistent global expert range: " - f"got [{entry['global_expert_start']}, {entry['global_expert_end']}), " - f"expected [{expected_start}, {expected_end})") - partitioned_count += 1 - - if require_partitioned and partitioned_count == 0: - raise RuntimeError("Expected AutoEP ZeRO-3 partition-native metadata but found no partitioned entries") + validate_autoep_zero3_partitioned_metadata(autoep_metadata, + require_partitioned=require_partitioned, + version_context="This converter") def _autoep_expert_param_info(autoep_metadata): @@ -533,7 +466,7 @@ def _autoep_expert_param_info(autoep_metadata): for entry in autoep_metadata: if not isinstance(entry, dict): continue - if entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) != AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT: + if not is_autoep_zero3_partitioned_entry(entry): continue prefix = entry.get('expert_key_prefix') if not prefix: diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 9b23bc5a47d6..571cba1d2164 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -82,6 +82,10 @@ UNIVERSAL_CHECKPOINT_VERSION_KEY, UNIVERSAL_CHECKPOINT_VERSION_VALUE, ) +from deepspeed.checkpoint.autoep_zero3_metadata import ( + is_autoep_zero3_partitioned_entry, + validate_autoep_zero3_partitioned_metadata, +) from deepspeed.checkpoint.utils import clone_tensors_for_torch_save from deepspeed.checkpoint.ds_to_universal import dp_index_to_str from deepspeed.runtime.sparse_tensor import SparseTensor @@ -3901,112 +3905,28 @@ def _uses_autoep_zero3_partitioned_experts(autoep_layers): if not isinstance(autoep_layers, list): return False DeepSpeedEngine._validate_autoep_zero3_partitioned_metadata(autoep_layers, require_partitioned=False) - return any( - isinstance(entry, dict) - and entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) == AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT - for entry in autoep_layers) + return any(is_autoep_zero3_partitioned_entry(entry) for entry in autoep_layers) @staticmethod def _validate_autoep_zero3_partitioned_metadata(autoep_layers, model=None, require_partitioned=True): - if not isinstance(autoep_layers, list): - raise RuntimeError(f"ds_autoep_layers metadata is malformed: expected list, got " - f"{type(autoep_layers).__name__}") - try: from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer as _AutoEPMoELayer except ImportError: _AutoEPMoELayer = None - live_autoep_modules = {} + expected_expert_prefixes = None if _AutoEPMoELayer is not None and model is not None: - live_autoep_modules = { - module_name: module + expected_expert_prefixes = { + module_name: f"{module_name}.experts" if module_name else "experts" for module_name, module in model.named_modules() if isinstance(module, _AutoEPMoELayer) } + if not expected_expert_prefixes: + expected_expert_prefixes = None - required_fields = { - 'moe_layer_id', - 'module_path', - 'num_experts', - 'num_local_experts', - 'ep_size', - 'expert_key_prefix', - } - partitioned_fields = { - AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY, - 'ep_group_name', - 'ep_rank', - 'expert_data_parallel_rank', - 'expert_data_parallel_world_size', - 'global_expert_start', - 'global_expert_end', - } - seen_layer_ids = set() - seen_prefixes = set() - partitioned_count = 0 - - for entry in autoep_layers: - if not isinstance(entry, dict): - raise RuntimeError(f"ds_autoep_layers entry is malformed: expected dict, got " - f"{type(entry).__name__}") - missing = required_fields - entry.keys() - if missing: - raise RuntimeError(f"ds_autoep_layers entry is invalid: missing fields {sorted(missing)}") - - layer_id = entry['moe_layer_id'] - if layer_id in seen_layer_ids: - raise RuntimeError(f"ds_autoep_layers metadata has duplicate moe_layer_id: {layer_id}") - seen_layer_ids.add(layer_id) - - prefix = entry['expert_key_prefix'] - if prefix in seen_prefixes: - raise RuntimeError(f"ds_autoep_layers metadata has duplicate expert_key_prefix: {prefix}") - seen_prefixes.add(prefix) - - checkpoint_format = entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) - if checkpoint_format is None: - continue - if checkpoint_format != AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT: - continue - - missing = partitioned_fields - entry.keys() - if missing: - raise RuntimeError(f"AutoEP ZeRO-3 checkpoint metadata is invalid: missing fields {sorted(missing)}") - if entry[AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY] != AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION: - raise RuntimeError("Unsupported AutoEP ZeRO-3 checkpoint format version: " - f"{entry[AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY]}. " - f"This DeepSpeed build supports version " - f"{AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION}.") - - num_experts = entry['num_experts'] - num_local_experts = entry['num_local_experts'] - ep_size = entry['ep_size'] - if num_local_experts * ep_size != num_experts: - raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata is inconsistent: " - f"num_local_experts={num_local_experts}, ep_size={ep_size}, " - f"num_experts={num_experts}") - - expected_start = entry['ep_rank'] * num_local_experts - expected_end = expected_start + num_local_experts - if entry['global_expert_start'] != expected_start or entry['global_expert_end'] != expected_end: - raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata has inconsistent global expert range: " - f"got [{entry['global_expert_start']}, {entry['global_expert_end']}), " - f"expected [{expected_start}, {expected_end})") - - module_path = entry['module_path'] - if live_autoep_modules and module_path not in live_autoep_modules: - raise RuntimeError(f"AutoEP ZeRO-3 checkpoint metadata references missing module: {module_path}") - if live_autoep_modules: - expected_prefix = f"{module_path}.experts" if module_path else "experts" - if prefix != expected_prefix: - raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata has unexpected expert key prefix: " - f"got {prefix}, expected {expected_prefix}") - - partitioned_count += 1 - - if require_partitioned and partitioned_count == 0: - raise RuntimeError("AutoEP ZeRO-3 partition-native checkpoint metadata was expected but no " - "partitioned AutoEP layer entries were found") + validate_autoep_zero3_partitioned_metadata(autoep_layers, + require_partitioned=require_partitioned, + expected_expert_prefixes=expected_expert_prefixes, + version_context="This DeepSpeed build") @staticmethod def _autoep_expert_parameter_names(autoep_layers, model): @@ -4016,7 +3936,7 @@ def _autoep_expert_parameter_names(autoep_layers, model): for entry in autoep_layers: if not isinstance(entry, dict): continue - if entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY) != AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT: + if not is_autoep_zero3_partitioned_entry(entry): continue prefix = entry.get('expert_key_prefix') if prefix: @@ -4470,6 +4390,29 @@ def _get_non_moe_state_dict(self, full_state_dict): return full_state_dict + def _common_checkpoint_state(self, module_state_dict, zero_optimizer_state, save_frozen_param): + return dict(module=module_state_dict, + buffer_names=self._get_buffer_names(), + optimizer=self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None, + param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None, + frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func) + if save_frozen_param else None, + shared_params=self._get_shared_params() if self.optimizer and zero_optimizer_state else None, + frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func) + if save_frozen_param else None, + lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, + data_sampler=self.training_dataloader.data_sampler.state_dict() if + (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None, + random_ltd=self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None, + sparse_tensor_module_names=self.sparse_tensor_module_names, + skipped_steps=self.skipped_steps, + global_steps=self.global_steps, + global_samples=self.global_samples, + dp_world_size=self.seq_dp_world_size, + mp_world_size=self.mp_world_size, + ds_config=self.config, + ds_version=version) + def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False): save_path = self._get_ckpt_name(save_dir, tag) @@ -4690,49 +4633,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa universal_checkpoint_info[EXPERT_PARAMETER_PATTERNS] = [r'.*\.experts\.w[123]$'] universal_checkpoint_info['ds_autoep_layers'] = autoep_layer_info - state = { - 'module': - model_state_dict, - 'buffer_names': - self._get_buffer_names(), - 'optimizer': - self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None, - 'param_shapes': - zero_param_shapes, - 'frozen_param_shapes': - self._get_zero_frozen_param_attributes(self._get_param_shape_func) if save_frozen_param else None, - 'shared_params': - self._get_shared_params() if self.optimizer and zero_optimizer_state else None, - 'frozen_param_fragments': - self._get_zero_frozen_param_attributes(self._get_param_fragment_func) if save_frozen_param else None, - 'lr_scheduler': - self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, - 'data_sampler': - self.training_dataloader.data_sampler.state_dict() if - (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None, - 'random_ltd': - self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None, - 'sparse_tensor_module_names': - self.sparse_tensor_module_names, - 'skipped_steps': - self.skipped_steps, - 'global_steps': - self.global_steps, - 'global_samples': - self.global_samples, - 'dp_world_size': - self.seq_dp_world_size, - 'mp_world_size': - self.mp_world_size, - 'num_experts': - self.num_experts, - 'ds_autoep_layers': - autoep_layer_info if autoep_layer_info else None, - 'ds_config': - self.config, - 'ds_version': - version, - } + state = self._common_checkpoint_state(model_state_dict, zero_optimizer_state, save_frozen_param) + state['num_experts'] = self.num_experts + state['ds_autoep_layers'] = autoep_layer_info if autoep_layer_info else None if universal_checkpoint_info is not None: state[UNIVERSAL_CHECKPOINT_INFO] = universal_checkpoint_info # Check for reserved-key collisions with client_state @@ -4785,27 +4688,7 @@ def _save_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parame module = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters) self._curr_ckpt_path = None - state = dict(module=module, - buffer_names=self._get_buffer_names(), - optimizer=self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None, - param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None, - frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func) - if save_frozen_param else None, - shared_params=self._get_shared_params() if self.optimizer and zero_optimizer_state else None, - frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func) - if save_frozen_param else None, - lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, - data_sampler=self.training_dataloader.data_sampler.state_dict() if - (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None, - random_ltd=self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None, - sparse_tensor_module_names=self.sparse_tensor_module_names, - skipped_steps=self.skipped_steps, - global_steps=self.global_steps, - global_samples=self.global_samples, - dp_world_size=self.seq_dp_world_size, - mp_world_size=self.mp_world_size, - ds_config=self.config, - ds_version=version) + state = self._common_checkpoint_state(module, zero_optimizer_state, save_frozen_param) autotp_uc_info = getattr(self.module, UNIVERSAL_CHECKPOINT_INFO, None) if autotp_uc_info is not None: state[UNIVERSAL_CHECKPOINT_INFO] = autotp_uc_info diff --git a/tests/unit/v1/moe/test_autoep_grad_parity.py b/tests/unit/v1/moe/test_autoep_grad_parity.py index 13e5233fc770..38ecd83808ec 100644 --- a/tests/unit/v1/moe/test_autoep_grad_parity.py +++ b/tests/unit/v1/moe/test_autoep_grad_parity.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team -"""One ZeRO-2 AutoEP gradient parity path.""" +"""AutoEP gradient parity paths.""" import deepspeed import deepspeed.comm as dist @@ -59,6 +59,17 @@ def _make_autoep_zero2_config(ep_size): return config +def _make_autoep_zero3_config(ep_size): + config = _make_autoep_zero2_config(ep_size) + config["zero_optimization"] = { + "stage": 3, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 5e8, + } + return config + + def _make_local_batches(*, logical_dp_world_size, logical_dp_rank, grad_accum, seed, seq_len, micro_batch_size, hidden_size, device, dtype): batches = [] @@ -178,3 +189,43 @@ def test_zero2_autoep_matches_zero2_after_one_update(self): return _assert_grad_maps_close(autoep_expert, zero2_expert, lhs_name="AutoEP expert", rhs_name="ZeRO-2 expert") + + def test_zero3_autoep_expert_grads_match_zero2_autoep(self): + ep_size = 2 + seed = 2345 + + _seed_everything(seed) + reference_state = _make_model().state_dict() + + zero2_model = _make_model() + zero3_model = _make_model() + zero2_model.load_state_dict(reference_state) + zero3_model.load_state_dict(reference_state) + + zero2_engine, _, _, _ = deepspeed.initialize(model=zero2_model, config=_make_autoep_zero2_config(ep_size)) + zero3_engine, _, _, _ = deepspeed.initialize(model=zero3_model, config=_make_autoep_zero3_config(ep_size)) + + logical_rank = dist.get_rank() // ep_size + logical_world_size = self.world_size // ep_size + _run_until_boundary(zero2_engine, + logical_dp_world_size=logical_world_size, + logical_dp_rank=logical_rank, + grad_accum=2, + seed=seed) + _run_until_boundary(zero3_engine, + logical_dp_world_size=logical_world_size, + logical_dp_rank=logical_rank, + grad_accum=2, + seed=seed) + + zero2_expert = _collect_autoep_expert_grads(zero2_engine) + zero3_expert = _collect_autoep_expert_grads(zero3_engine) + + dist.barrier() + if dist.get_rank() != 0: + return + + _assert_grad_maps_close(zero3_expert, + zero2_expert, + lhs_name="ZeRO-3 AutoEP expert", + rhs_name="ZeRO-2 AutoEP expert")