114 changes: 100 additions & 14 deletions paddlenlp/trainer/trainer.py
@@ -166,6 +166,8 @@
)
from .training_args import TrainingArguments
from .unified_checkpoint import UnifiedCheckpointHandler
from .unified_checkpoint.utils import generate_base_static_name
from paddle.distributed.checkpoint.sharded_tensor import ShardedTensor, build_sharded_state_dict
from .utils import reshard as reshard_util
from .utils.async_save import AsyncSaver

@@ -197,7 +199,6 @@
if is_datasets_available():
import datasets


try:
from paddle.distributed.fleet.utils import mix_precision_utils
except:
@@ -862,6 +863,7 @@ def train(

train_dataloader = self.get_train_dataloader()


total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.dataset_world_size
len_dataloader = None
if has_length(train_dataloader):
@@ -2212,7 +2214,7 @@ def _wrap_model(self, model, training=True):
in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1
in_sep_parallel_mode = self.args.sep_parallel_degree > 1
in_cp_parallel_mode = self.args.context_parallel_degree > 1

# Multi-gpu training
if self.args.world_size > 1 and (not self.args.use_hybrid_parallel):
# MOE uses DDP to broadcast parameters.
@@ -2234,7 +2236,6 @@
mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
assert self.optimizer is not None, "optimizer is empty!"
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)

# Pipeline mode
if in_pipeline_parallel_mode:
if self.args.amp_master_grad:
@@ -2279,20 +2280,18 @@ def get_expected_keys(inputs, keys):
"Using default prepare pipeline inputs func, only support input_ids and labels as inputs."
)
model._prepare_pipeline_inputs_func = _prepare_pipeline_inputs_func

assert self.optimizer is not None, "Pipeline mode needs a decorated optimizer, please init the optimizer."
if self.args.amp_master_grad:
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

if (
hasattr(self.args, "enable_sharding_comm_overlap")
and self.args.enable_sharding_comm_overlap
and self.args.unified_checkpoint
and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
):
model.register_sharding_comm_overlap_hook(self.optimizer)

# No pipeline mode, sharding only
if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
# Sharded DDP!
@@ -2306,7 +2305,6 @@ def get_expected_keys(inputs, keys):
model = paddle.distributed.fleet.meta_parallel.TensorParallel(
model, hcg, strategy=fleet.fleet._user_defined_strategy
)

if ShardingOption.SHARD_OP in self.args.sharding:
if self.args.amp_master_grad:
mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use
@@ -2325,7 +2323,6 @@ def get_expected_keys(inputs, keys):
level = "p_g_os"

from paddle.distributed.sharding import group_sharded_parallel

# add dp_group and exclude_layer params
# https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/distributed/sharding/group_sharded_parallel_cn.html#group-sharded-parallel
extra_kwargs = {}
@@ -2348,6 +2345,7 @@ def get_expected_keys(inputs, keys):
offload=cpu_offload,
**extra_kwargs,
)

if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad:
assert hasattr(optimizer, "use_main_grad"), (
"Current installed paddle doesn't support sharding stage 2 with main grad, "
@@ -2373,7 +2371,6 @@
if self.args.amp_master_grad:
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

# stage1 has v1 and v2 version
if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
if "split_param" in self.args.sharding_parallel_config:
@@ -2657,7 +2654,7 @@ def _filter_moe_no_sync_optimizer_params(self):
filter_optimzier_state_dict[op_k] = op_v
return filter_optimzier_state_dict

def _ordered_save(self, state_dict, save_path, signal_path=None):
def _ordered_save(self, state_dict, save_path):
group_size = self.args.ordered_save_group_size
hcg = fleet.get_hybrid_communicate_group()
if hcg.get_sharding_parallel_world_size() > 1 or hcg.get_model_parallel_world_size() <= 1:
@@ -2677,10 +2674,6 @@ def _ordered_save(self, state_dict, save_path, signal_path=None):
paddle.save(state_dict, save_path)
dist.barrier(mp_group)

if signal_path is not None:
with open(signal_path, mode="w+") as f:
f.write("1")

def _save_checkpoint(self, model, metrics=None):
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
if self.args.enable_zero_cost_checkpoint:
@@ -2749,6 +2742,99 @@ def _save_checkpoint(self, model, metrics=None):

# only save model state dict, ignore optimizer and scheduler
if not self.args.ignore_save_lr_and_optim:


# Build the combined sharded state dict: model weights plus the
# optimizer states keyed off them.
model_sharded_state_dict = self.model.sharded_state_dict()
opt_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)

sharded_state_dict = {}
sharded_state_dict.update(model_sharded_state_dict)
sharded_state_dict.update(opt_sharded_state_dict)

# Sanity check: everything handed to save/load must be a ShardedTensor.
from paddle.distributed.flex_checkpoint import ShardedTensor

for k, v in sharded_state_dict.items():
    if not isinstance(v, ShardedTensor):
        print(k, " is not ShardedTensor!")

# The save half of the round trip; commented out here because this
# debug run reloads a checkpoint written by a previous run.
# dist.save_state_dict(
#     state_dict=sharded_state_dict,
#     path="./tmp",
# )

# Debug round trip: zero every local shard, then reload from "./tmp"
# to prove the load path restores real data. (md5 sums of each local
# tensor, taken before zeroing and again after loading, can be compared
# to verify the round trip bit-for-bit.)
for k, v in sharded_state_dict.items():
    paddle.assign(paddle.zeros(v.local_tensor.shape, v.local_tensor.dtype), v.local_tensor)

# The AOA statement declares qkv_proj.weight as a fused qkv parameter
# (64 heads, 8 key/value groups) so load can reassemble it.
aoa_config = {
    "aoa_statements": [
        "llama.layers.0.self_attn.qkv_proj.weight -> llama.layers.0.self_attn.qkv_proj.weight, fused_qkv, num_heads=64, num_key_value_groups=8\n"
    ]
}
dist.load_state_dict(
    state_dict=sharded_state_dict,
    path="./tmp",
    aoa_config=aoa_config,
)

# Temporary debug exit: stop before the regular optimizer save below runs.
import sys

sys.exit()

optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")

15 changes: 14 additions & 1 deletion paddlenlp/transformers/llama/modeling.py
@@ -30,6 +30,7 @@
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.recompute.recompute import recompute
from paddle.distributed.flex_checkpoint import build_sharded_state_dict

from paddlenlp.transformers.refined_recompute import (
RRColumnParallelLinear,
@@ -1367,7 +1368,8 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:

@classmethod
def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):


# raise NotImplementedError
from paddlenlp.transformers.conversion_utils import split_or_merge_func

fn = split_or_merge_func(
@@ -1427,6 +1429,7 @@ def get_tensor_parallel_split_mappings(num_layers):

@classmethod
def _get_fuse_or_split_param_mappings(cls, config: LlamaConfig, is_fuse=False):
raise NotImplementedError
# return parameter fuse utils
from paddlenlp.transformers.conversion_utils import split_or_fuse_func

@@ -1995,6 +1998,16 @@ def forward(self, hidden_states, tensor_parallel_output=None):
)
return logits

def sharded_state_dict(
    self,
    structured_name_prefix: str = "",
):
    # The lm_head weight is split along its vocab dimension: axis 0 when
    # transpose_y is set (weight is [vocab, hidden]), axis 1 otherwise.
    axis = 0 if self.transpose_y else 1
    state_dict = self.state_dict(structured_name_prefix="")
    return build_sharded_state_dict(
        state_dict, {"weight": axis}, structured_name_prefix
    )
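
The same pattern extends to any layer whose parameters shard along a known axis. A hypothetical sketch, assuming only the `build_sharded_state_dict(state_dict, {param_name: split_axis}, prefix)` signature used above:

import paddle
from paddle.distributed.flex_checkpoint import build_sharded_state_dict

class ColumnShardedLinear(paddle.nn.Layer):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = self.create_parameter([in_features, out_features])
        self.bias = self.create_parameter([out_features], is_bias=True)

    def forward(self, x):
        return paddle.matmul(x, self.weight) + self.bias

    def sharded_state_dict(self, structured_name_prefix: str = ""):
        # The weight shards on its output axis (1); the 1-D bias on axis 0.
        state_dict = self.state_dict(structured_name_prefix="")
        return build_sharded_state_dict(
            state_dict, {"weight": 1, "bias": 0}, structured_name_prefix
        )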


class LlamaForCausalLM(LlamaPretrainedModel):
enable_to_static_method = True
13 changes: 13 additions & 0 deletions paddlenlp/transformers/model_utils.py
@@ -3167,6 +3167,19 @@ def state_dict(self, *args, **kwargs):

return state_dict

def sharded_state_dict(self, *args, **kwargs):
    sharded_state_dict = super().sharded_state_dict(*args, **kwargs)
    if self._single_to_pp_mapping is None:
        self._set_pipeline_name_mapping()
    assert len(self._single_to_pp_mapping) > 0, "The pipeline stage must have parameters!"

    # Re-key each entry from its pipeline-local name to the single-model
    # name, so checkpoints do not depend on the pipeline layout.
    for k in list(sharded_state_dict.keys()):
        v = sharded_state_dict.pop(k)
        v.tensor_key = self._pp_to_single_mapping[k]
        sharded_state_dict[self._pp_to_single_mapping[k]] = v

    return sharded_state_dict
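
The re-keying matters because a pipeline stage names parameters by stage-local sublayer index, while checkpoints should use one layout-independent namespace. A rough illustration with hypothetical names:

# Hypothetical stage-0 mapping from pipeline-local to single-model keys.
pp_to_single_mapping = {
    "_layers.0.embed_tokens.weight": "llama.embed_tokens.weight",
    "_layers.1.self_attn.qkv_proj.weight": "llama.layers.0.self_attn.qkv_proj.weight",
}

def remap(sharded_state_dict):
    # Same move as above: pop each entry, rewrite its tensor_key, and
    # re-insert it under the single-model name.
    for k in list(sharded_state_dict.keys()):
        v = sharded_state_dict.pop(k)
        v.tensor_key = pp_to_single_mapping[k]
        sharded_state_dict[pp_to_single_mapping[k]] = v
    return sharded_state_dict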

def set_state_dict(self, state_dict, *args, **kwargs):
if self._single_to_pp_mapping is None:
self._set_pipeline_name_mapping()