From c8e4e0d5acf2fe746b0b9451c5433aa5e6f382f9 Mon Sep 17 00:00:00 2001
From: andsonder
Date: Thu, 14 Aug 2025 21:49:13 +0800
Subject: [PATCH] update

---
 paddlenlp/trainer/trainer.py             | 114 ++++++++++++++++++++---
 paddlenlp/transformers/llama/modeling.py |  15 ++-
 paddlenlp/transformers/model_utils.py    |  13 +++
 3 files changed, 127 insertions(+), 15 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index b48679dbf26a..185da8082c61 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -166,6 +166,8 @@
 )
 from .training_args import TrainingArguments
 from .unified_checkpoint import UnifiedCheckpointHandler
+from .unified_checkpoint.utils import generate_base_static_name
+from paddle.distributed.checkpoint.sharded_tensor import ShardedTensor, build_sharded_state_dict
 from .utils import reshard as reshard_util
 from .utils.async_save import AsyncSaver
 
@@ -197,7 +199,6 @@
 if is_datasets_available():
     import datasets
 
-
 try:
     from paddle.distributed.fleet.utils import mix_precision_utils
 except:
@@ -862,6 +863,7 @@ def train(
 
         train_dataloader = self.get_train_dataloader()
 
+
         total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.dataset_world_size
         len_dataloader = None
         if has_length(train_dataloader):
@@ -2212,7 +2214,7 @@ def _wrap_model(self, model, training=True):
         in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1
         in_sep_parallel_mode = self.args.sep_parallel_degree > 1
         in_cp_parallel_mode = self.args.context_parallel_degree > 1
-
+
         # Multi-gpu training
         if self.args.world_size > 1 and (not self.args.use_hybrid_parallel):
             # MOE use DDP to broadcaset parameters.
@@ -2234,7 +2236,6 @@ def _wrap_model(self, model, training=True):
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
                 assert self.optimizer is not None, "optimizer is empty!"
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
-
         # Pipeline mode
         if in_pipeline_parallel_mode:
             if self.args.amp_master_grad:
@@ -2279,12 +2280,11 @@ def get_expected_keys(inputs, keys):
                     "Using default prepare pipeline inputs func, only support input_ids and labels as inputs."
                 )
                 model._prepare_pipeline_inputs_func = _prepare_pipeline_inputs_func
-
+
             assert self.optimizer is not None, "Pipeline mode need decorate optimizer, pelease init optimizer."
             if self.args.amp_master_grad:
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
             if (
                 hasattr(self.args, "enable_sharding_comm_overlap")
                 and self.args.enable_sharding_comm_overlap
                 and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model.register_sharding_comm_overlap_hook(self.optimizer)
-
         # No pipeline mode, sharding only
         if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
             # Sharded DDP!
@@ -2306,7 +2305,6 @@ def get_expected_keys(inputs, keys):
                 model = paddle.distributed.fleet.meta_parallel.TensorParallel(
                     model, hcg, strategy=fleet.fleet._user_defined_strategy
                 )
-
             if ShardingOption.SHARD_OP in self.args.sharding:
                 if self.args.amp_master_grad:
                     mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)  # return value has no use
@@ -2325,7 +2323,6 @@ def get_expected_keys(inputs, keys):
                     level = "p_g_os"
 
                 from paddle.distributed.sharding import group_sharded_parallel
-
                 # add dp_group and exclude_layer params
                 # https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/distributed/sharding/group_sharded_parallel_cn.html#group-sharded-parallel
                 extra_kwargs = {}
@@ -2348,6 +2345,7 @@ def get_expected_keys(inputs, keys):
                     offload=cpu_offload,
                     **extra_kwargs,
                 )
+
                 if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad:
                     assert hasattr(optimizer, "use_main_grad"), (
                         "Current installed paddle doesn't support sharding stage 2 with main grad, "
@@ -2373,7 +2371,6 @@ def get_expected_keys(inputs, keys):
             if self.args.amp_master_grad:
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
         # stage1 has v1 and v2 version
         if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
             if "split_param" in self.args.sharding_parallel_config:
@@ -2657,7 +2654,7 @@ def _filter_moe_no_sync_optimizer_params(self):
                 filter_optimzier_state_dict[op_k] = op_v
         return filter_optimzier_state_dict
 
-    def _ordered_save(self, state_dict, save_path, signal_path=None):
+    def _ordered_save(self, state_dict, save_path):
         group_size = self.args.ordered_save_group_size
         hcg = fleet.get_hybrid_communicate_group()
         if hcg.get_sharding_parallel_world_size() > 1 or hcg.get_model_parallel_world_size() <= 1:
@@ -2677,10 +2674,6 @@ def _ordered_save(self, state_dict, save_path):
                 paddle.save(state_dict, save_path)
             dist.barrier(mp_group)
 
-        if signal_path is not None:
-            with open(signal_path, mode="w+") as f:
-                f.write("1")
-
     def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
         if self.args.enable_zero_cost_checkpoint:
@@ -2749,6 +2742,99 @@
         # only save model state dict, ignore optimizer and scheduler
         if not self.args.ignore_save_lr_and_optim:
+
+
+            model_sharded_state_dict = self.model.sharded_state_dict()
+            print(type(self.optimizer))
+            print(type(self.optimizer._inner_opt))
+            print(type(self.optimizer._inner_opt._inner_opt))
+            opt_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+
+
+
+            sharded_state_dict = {}
+            sharded_state_dict.update(model_sharded_state_dict)
+            sharded_state_dict.update(opt_sharded_state_dict)
+
+
+
+            # for k,v in sharded_state_dict.items():
+            #     print(k, " ====> ",v,)
+
+            # pre_save_md5 = {}
+            # for k,v in sharded_state_dict.items():
+            #     print(k, " ====> ",v.local_tensor.dtype)
+            #     pre_save_md5[k] = v.local_tensor._md5sum()
+
+
+            from paddle.distributed.flex_checkpoint import ShardedTensor
+            for k,v in sharded_state_dict.items():
+                if not isinstance(v, ShardedTensor):
+                    print(k," is not ShardedTensor!")
+
+            # dist.save_state_dict(
+            #     state_dict=sharded_state_dict,
+            #     path = "./tmp"
+            # )
+            print("====> end save")
+
+            for k,v in sharded_state_dict.items():
+                paddle.assign(paddle.zeros(v.local_tensor.shape,v.local_tensor.dtype),v.local_tensor)
+
+            # after_save_md5 = {}
+            # for k,v in sharded_state_dict.items():
+            #     print(k, " ====> ", v.local_tensor._md5sum())
+            #     after_save_md5[k] = v.local_tensor._md5sum()
+
+
+            # for k in after_save_md5.keys():
+            #     assert pre_save_md5[k] != after_save_md5[k] , f"{k} not be assign zeros"
+
+
+            aoa_config = {
+                "aoa_statements" : ['llama.layers.0.self_attn.qkv_proj.weight -> llama.layers.0.self_attn.qkv_proj.weight ,fused_qkv , num_heads=64, num_key_value_groups = 8 \n']
+            }
+            print("====> begin load")
+            dist.load_state_dict(
+                state_dict=sharded_state_dict,
+                path = "./tmp",
+                aoa_config=aoa_config,
+            )
+            # print("====> end load")
+            # after_load_md5 = {}
+            # for k,v in sharded_state_dict.items():
+            #     print(k, " ====> ", v.local_tensor._md5sum())
+            #     after_load_md5[k] = v.local_tensor._md5sum()
+            #     if k == "lm_head.weight":
+            #         print(v.local_tensor)
+
+            # print("====> m5 check!!!")
+            # for k in after_save_md5.keys():
+            #     if "lm_head" in k:
+            #         continue
+            #     print(k ," ===>. ", pre_save_md5[k], " ===> ",after_load_md5[k])
+            #     assert pre_save_md5[k] == after_load_md5[k] , f"{k} error load!"
+
+
+            # print("===> begin load!")
+            # dist.load_state_dict(
+            #     state_dict=sharded_state_dict,
+            #     path = "./tmp"
+            # )
+            # print("===> end load!")
+            # print("==================> ")
+            # print("===> begin save!")
+            # dist.save_state_dict(
+            #     state_dict=sharded_state_dict,
+            #     path = "./tmp"
+            # )
+            # print("===> end save!")
+
+
+            import sys
+            sys.exit()
+
             optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
             saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")
 
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 84c9bcc7ff4e..426bc6052cdc 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -30,6 +30,7 @@
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.recompute.recompute import recompute
+from paddle.distributed.flex_checkpoint import build_sharded_state_dict
 
 from paddlenlp.transformers.refined_recompute import (
     RRColumnParallelLinear,
@@ -1367,7 +1368,8 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
 
     @classmethod
     def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
-
+
+        # raise NotImplementedError
         from paddlenlp.transformers.conversion_utils import split_or_merge_func
 
         fn = split_or_merge_func(
@@ -1427,6 +1429,7 @@ def get_tensor_parallel_split_mappings(num_layers):
 
     @classmethod
     def _get_fuse_or_split_param_mappings(cls, config: LlamaConfig, is_fuse=False):
+        raise NotImplementedError
         # return parameter fuse utils
         from paddlenlp.transformers.conversion_utils import split_or_fuse_func
 
@@ -1995,6 +1998,16 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         )
         return logits
 
+    def sharded_state_dict(
+        self,
+        structured_name_prefix: str = "",
+    ):
+        axis = 0 if self.transpose_y else 1
+        state_dict = self.state_dict(structured_name_prefix="")
+        return build_sharded_state_dict(
+            state_dict, {"weight": axis}, structured_name_prefix
+        )
+
 
 class LlamaForCausalLM(LlamaPretrainedModel):
     enable_to_static_method = True
diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index b5756896c65a..a22016d2d982 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -3167,6 +3167,19 @@ def state_dict(self, *args, **kwargs):
 
         return state_dict
 
+    def sharded_state_dict(self, *args, **kwargs):
+        sharded_state_dict = super().sharded_state_dict(*args, **kwargs)
+        if self._single_to_pp_mapping is None:
+            self._set_pipeline_name_mapping()
+        assert len(self._single_to_pp_mapping) > 0, "The pipeline stage must have parameters!"
+
+        for k in list(sharded_state_dict.keys()):
+            v = sharded_state_dict.pop(k)
+            v.tensor_key = self._pp_to_single_mapping[k]  # remap the pipeline-local key back to its single-card name
+            sharded_state_dict[self._pp_to_single_mapping[k]] = v
+
+        return sharded_state_dict
+
     def set_state_dict(self, state_dict, *args, **kwargs):
         if self._single_to_pp_mapping is None:
             self._set_pipeline_name_mapping()
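
Note for reviewers: a minimal sketch of the flex-checkpoint round-trip that the _save_checkpoint hunk above is exercising. It only uses calls that already appear in this patch (model/optimizer sharded_state_dict(), dist.save_state_dict, dist.load_state_dict, aoa_config), and it assumes a `model` and `optimizer` that implement sharded_state_dict() as added here; the "./tmp" directory and the AoA statement are the hard-coded debug values from the diff, not a final interface.

    import paddle.distributed as dist

    ckpt_dir = "./tmp"  # hard-coded in the patch for the experiment

    # Merge model and optimizer shards into one flat mapping of ShardedTensors.
    model_sharded_state_dict = model.sharded_state_dict()
    opt_sharded_state_dict = optimizer.sharded_state_dict(model_sharded_state_dict)
    sharded_state_dict = {**model_sharded_state_dict, **opt_sharded_state_dict}

    # Each rank writes only its local shards; the checkpoint metadata records the global layout.
    dist.save_state_dict(state_dict=sharded_state_dict, path=ckpt_dir)

    # Reload in place. aoa_config can rewrite/fuse keys at load time, e.g. the fused
    # qkv_proj statement used in the diff (copied verbatim from the patch).
    aoa_config = {
        "aoa_statements": ['llama.layers.0.self_attn.qkv_proj.weight -> llama.layers.0.self_attn.qkv_proj.weight ,fused_qkv , num_heads=64, num_key_value_groups = 8 \n']
    }
    dist.load_state_dict(state_dict=sharded_state_dict, path=ckpt_dir, aoa_config=aoa_config)

In the hunk itself the local tensors are zeroed between save and load (with the md5 comparisons left commented out) purely to verify that load_state_dict really restores the data; the save call is still commented out and the method exits via sys.exit(), so that section is debug scaffolding rather than the final save path.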
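Two smaller points about the other hunks, as I read them: LlamaLMHead.sharded_state_dict derives the split axis from transpose_y because the head weight is stored as [vocab_size, hidden_size] when transpose_y is set (tied embeddings) and as [hidden_size, vocab_size] otherwise, so the vocab dimension that tensor parallelism splits moves between axis 0 and axis 1. The PipelineLayer.sharded_state_dict override in model_utils.py renames every pipeline-local parameter key back to its single-card structured name via _pp_to_single_mapping, so the saved checkpoint layout does not depend on how the model was partitioned across pipeline stages.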