diff --git a/.compatibility b/.compatibility index e1836506aae6..69d483524dcc 100644 --- a/.compatibility +++ b/.compatibility @@ -1,3 +1,3 @@ -2.2.2-12.1.0 2.3.0-12.1.0 2.4.0-12.4.1 +2.5.1-12.4.1 diff --git a/.cuda_ext.json b/.cuda_ext.json index 1e617755b01b..01a30a9c1204 100644 --- a/.cuda_ext.json +++ b/.cuda_ext.json @@ -1,11 +1,11 @@ { "build": [ { - "torch_command": "pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121", + "torch_command": "pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121", "cuda_image": "hpcaitech/cuda-conda:12.1" }, { - "torch_command": "pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124", + "torch_command": "pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124", "cuda_image": "hpcaitech/cuda-conda:12.4" } ] diff --git a/.gitignore b/.gitignore index 8bc74b4c8c2c..16f764c1b1ef 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,4 @@ coverage.xml # log, test files - ColossalChat applications/ColossalChat/logs applications/ColossalChat/tests/logs +applications/ColossalChat/wandb diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 1e85cccb3c5b..b582ee8b33c2 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -66,12 +66,13 @@ def setup(self) -> None: launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) plugin_config = dict( - tp_size=1, + tp_size=2, pp_size=1, precision="bf16", - zero_stage=1, + zero_stage=2, + parallel_output=False, ) - if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config: + if plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in 
self.plugin_config: plugin_config["microbatch_size"] = self.microbatch_size plugin_config.update(self.plugin_config) self.plugin = HybridParallelPlugin(**plugin_config) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 55dfd09ab244..cd942c53a677 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -114,20 +114,25 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: need_update = (step_idx + 1) % self.num_microbatches == 0 - ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer) + # ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer) + ctx = nullcontext() with ctx: policy_model_logits = self.policy_model( input_ids=data["input_ids"], attention_mask=data["attention_mask"], )["logits"] - action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action) + action_log_probs = calc_action_log_probs( + policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config + ) with torch.no_grad(): reference_model_logits = self.reference_model( input_ids=data["input_ids"], attention_mask=data["attention_mask"], )["logits"] - reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action) + reference_action_log_probs = calc_action_log_probs( + reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config + ) per_token_kl = ( torch.exp(reference_action_log_probs - action_log_probs) diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index bc0ae5c36673..58414b29fd47 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ 
-61,12 +61,22 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any] self.generate_config = generate_config.copy() self.generate_config.update(self.FORCE_GENERATE_CONFIG) self.tokenizer = tokenizer + self.num_generations = 8 @torch.no_grad() def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: + micro_batch_size = input_ids.size(0) input_ids = input_ids.to(get_current_device()) attention_mask = attention_mask.to(get_current_device()) - out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config) + gt_answer = None + if "gt_answer" in kwargs: + gt_answer = kwargs.pop("gt_answer") + if self.num_generations > 1: + input_ids = input_ids.repeat_interleave(self.num_generations, dim=0) + attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0) + out = self.model.generate( + input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer + ) input_len = input_ids.shape[-1] new_token_ids = out.sequences[:, input_len:] # get log probs @@ -76,10 +86,13 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1])) action_log_probs = torch.cat(action_log_probs, dim=1) # get action mask + response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device()) action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype) if self.tokenizer.eos_token_id is not None: for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id): action_mask[indices[0], indices[1] + 1 :] = 0 + response_idx[:, 0] = input_len + response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1 if attention_mask.size(0) != action_mask.size(0): assert action_mask.size(0) % attention_mask.size(0) == 0 @@ -91,7 +104,15 @@ def generate(self, input_ids: torch.Tensor, 
attention_mask: torch.Tensor, **kwar "attention_mask": attention_mask, "action_log_probs": action_log_probs, "action_mask": action_mask, + "response_idx": response_idx, } + + data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()} + + if gt_answer is not None: + # repeat gt_answer for each prompt. + data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1) + data = {k: v.to(get_current_device()) for k, v in data.items()} return data def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: @@ -154,7 +175,7 @@ class VLLMInferenceBackend(BaseInferenceBackend): ) FORCE_GENERATE_CONFIG = dict( logprobs=0, - n=4, + n=8, ) def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer): diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py index 222540b9270d..af5776731a25 100644 --- a/applications/ColossalChat/coati/distributed/loss.py +++ b/applications/ColossalChat/coati/distributed/loss.py @@ -26,15 +26,10 @@ def forward( ) -> torch.Tensor: skip = False if action_mask is None: - ratio_ = (log_probs - old_log_probs).exp() + ratio = (log_probs - log_probs.detach()).exp() else: - ratio_ = ((log_probs - old_log_probs) * action_mask).exp() + ratio = ((log_probs - log_probs.detach()) * action_mask).exp() - # note that if dropout is disabled (recommanded), ratio will always be 1. 
- if ratio_.mean() > self.skip_threshold: - skip = True - - ratio = ratio_.clamp(0.0, 10.0) surr1 = ratio * advantages surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages loss = -torch.min(surr1, surr2) + self.beta * per_token_kl @@ -44,4 +39,4 @@ def forward( else: loss = loss.mean(dim=1) loss = loss.mean() - return loss, skip, ratio_.max() + return loss, skip, ratio.max() diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 3e4a5277ad2e..a3ae22a7935c 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -154,7 +154,11 @@ def __init__( @torch.no_grad() def rollout(self, input_ids, attention_mask, **kwargs): - return self.model.generate(input_ids, attention_mask, **kwargs) + rollouts = self.model.generate(input_ids, attention_mask, **kwargs) + if self.producer_idx == 1: + print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)) + + return rollouts def load_state_dict(self, state_dict): self.model.load_state_dict(state_dict) diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 98b54815b5b4..919e4434faa6 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -2,6 +2,8 @@ import torch +from colossalai.shardformer.layer.loss import dist_log_prob + def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]: batches = [] @@ -66,18 +68,30 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T return per_label_logps.squeeze(-1) -def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: +def calc_action_log_probs( + logits: torch.Tensor, + sequences: torch.LongTensor, + num_actions: int, + 
shard_config, + vocab_size: int = None, +) -> torch.Tensor: """Calculate action log probs. Args: - output (torch.Tensor): Output tensor of Actor.forward.logits. + logits (torch.Tensor): Output tensor of Actor.forward.logits. sequences (torch.LongTensor): Input sequences. num_actions (int): Number of actions. + shard_config + vocab_size + Returns: torch.Tensor: Action log probs. """ - log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + # labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] + # logits: torch.Tensor, # [B, S, Vocab_size] + log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype) + log_probs = log_probs.squeeze(-1) return log_probs[:, -num_actions:] diff --git a/applications/ColossalChat/coati/experience_maker/naive.py b/applications/ColossalChat/coati/experience_maker/naive.py index c7ad4f316179..81f8fb80cc93 100755 --- a/applications/ColossalChat/coati/experience_maker/naive.py +++ b/applications/ColossalChat/coati/experience_maker/naive.py @@ -140,7 +140,7 @@ def make_experience( num_actions = 0 for inference_mini_batch_id in range(0, input_ids.size(0), self.inference_batch_size): - s, e = inference_mini_batch_id, (inference_mini_batch_id + 1) * self.inference_batch_size + s, e = inference_mini_batch_id, inference_mini_batch_id + self.inference_batch_size if input_ids[s:e].size(0) == 0: break sequences = generate(self.actor, input_ids[s:e], self.tokenizer, **generate_kwargs) diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py index cde13d41e06f..25a3d32eeb7a 100755 --- a/applications/ColossalChat/coati/trainer/dpo.py +++ b/applications/ColossalChat/coati/trainer/dpo.py @@ -380,8 +380,8 @@ def _criterion(outputs, inputs): self.accumulative_meter.get("accuracy"), global_step, ) - self.num_train_step += 1 self.accumulative_meter.reset() + self.num_train_step += 1 if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % 
self.save_interval == 0: # save checkpoint diff --git a/applications/ColossalChat/coati/trainer/grpo.py b/applications/ColossalChat/coati/trainer/grpo.py index 08710b19600c..6fae5c1fec9d 100755 --- a/applications/ColossalChat/coati/trainer/grpo.py +++ b/applications/ColossalChat/coati/trainer/grpo.py @@ -231,7 +231,6 @@ def _training_step(self, experience: Experience): experience: sequences: [batch_size, prompt_length + response_length] --- ............ """ - self.num_train_step += 1 self.actor.train() num_actions = experience.action_log_probs.size(1) # policy loss @@ -294,7 +293,7 @@ def _training_step(self, experience: Experience): self.temperature_annealing_scheduler.step_forward() # preparing logging model output and corresponding rewards. - if self.num_train_step % 10 == 1: + if self.num_train_step % 10 == 0: response_text = self.experience_maker.tokenizer.batch_decode( experience.sequences, skip_special_tokens=True ) @@ -327,6 +326,7 @@ def _training_step(self, experience: Experience): self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), global_step) self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), global_step) self.accumulative_meter.reset() + self.num_train_step += 1 def _learn(self, update_step: int): """ diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py index 2d7e2fa85c25..6dd1ed407165 100755 --- a/applications/ColossalChat/coati/trainer/kto.py +++ b/applications/ColossalChat/coati/trainer/kto.py @@ -256,7 +256,7 @@ def _train(self, epoch: int): self.coordinator.print_on_master( f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}" ) - self.num_train_step += 1 + self.num_train_step += 1 step_bar.close() diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py index 0224c8f347ae..1de9f527c96f 100644 --- a/applications/ColossalChat/coati/trainer/orpo.py +++ 
b/applications/ColossalChat/coati/trainer/orpo.py @@ -233,7 +233,7 @@ def _train(self, epoch: int): self.coordinator.print_on_master( f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}" ) - self.num_train_step += 1 + self.num_train_step += 1 step_bar.close() diff --git a/applications/ColossalChat/coati/trainer/ppo.py b/applications/ColossalChat/coati/trainer/ppo.py index 331425174c1d..54630a6b7782 100755 --- a/applications/ColossalChat/coati/trainer/ppo.py +++ b/applications/ColossalChat/coati/trainer/ppo.py @@ -220,7 +220,6 @@ def _training_step(self, experience: Experience): experience: sequences: [batch_size, prompt_length + response_length] --- ............ """ - self.num_train_step += 1 self.actor.train() self.critic.train() num_actions = experience.action_log_probs.size(1) @@ -294,7 +293,7 @@ def _training_step(self, experience: Experience): self.critic_scheduler.step() # preparing logging model output and corresponding rewards. - if self.num_train_step % 10 == 1: + if self.num_train_step % 10 == 0: response_text = self.experience_maker.tokenizer.batch_decode( experience.sequences, skip_special_tokens=True ) @@ -336,6 +335,7 @@ def _training_step(self, experience: Experience): self.writer.add_scalar("value", self.accumulative_meter.get("value"), self.num_train_step) self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), self.num_train_step) self.accumulative_meter.reset() + self.num_train_step += 1 def _learn(self, update_step: int): """ diff --git a/applications/ColossalChat/coati/trainer/rm.py b/applications/ColossalChat/coati/trainer/rm.py index 991167a9125f..f52b22e8e5a6 100755 --- a/applications/ColossalChat/coati/trainer/rm.py +++ b/applications/ColossalChat/coati/trainer/rm.py @@ -193,7 +193,7 @@ def _train(self, epoch): self.coordinator.print_on_master( f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}" ) - self.num_train_step += 1 + 
self.num_train_step += 1 step_bar.close() def _eval(self, epoch): diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py index fe7f4978bc09..b121369d0721 100755 --- a/applications/ColossalChat/coati/trainer/sft.py +++ b/applications/ColossalChat/coati/trainer/sft.py @@ -152,9 +152,9 @@ def _train(self, epoch: int): if self.writer: self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step) self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], global_step) - self.num_train_step += 1 self.accumulative_meter.reset() step_bar.update() + self.num_train_step += 1 # Save checkpoint if ( diff --git a/applications/ColossalChat/coati/utils/compare_tool.py b/applications/ColossalChat/coati/utils/compare_tool.py new file mode 100644 index 000000000000..ac72abadb22e --- /dev/null +++ b/applications/ColossalChat/coati/utils/compare_tool.py @@ -0,0 +1,1098 @@ +######################## +# Reference from https://github.com/MooreThreads/torch_musa/blob/main/torch_musa/utils/compare_tool.py +# A verion for gpu and npu +# BSD 3-Clause License +# Copyright (c) 2023 , Moore Threads Technology Co., Ltd. +# Copyright (c) 2022, Facebook Inc. and the respective contributors +# All rights reserved. +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. 
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ------------------------------------------------------------------------------------ +# This product bundles various third-party components under other open source licenses. +# This section summarizes those components and their licenses. See licenses/ +# for text of these licenses. 
+ +# License in PyToch(https://github.com/pytorch/pytorch/blob/main/LICENSE) +# ----------------- +# tools/setup_helper +# torch_musa/csrc/ +# torch_musa/core +# Apache Software Foundation License 2.0 +# -------------------------------------- +# tools/lint +# BSD 2-clause License +# -------------------- +# docs +# Apache Software Foundation License 2.0 +# -------------------------------------- +# examples/cpp +######################## +"compare tool with cpu" +# pylint: disable=broad-exception-caught,broad-exception-raised,redefined-builtin,unused-argument +import os +import pickle +import sys +from datetime import datetime +from functools import partial + +import torch +from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily + + +class ModuleInfo(object): + """ + A class to store information about a module in a neural network, including its name, + relationship to other modules (parent and children), and whether the module is a leaf + or is being executed in a forward or backward pass. + """ + + def __init__(self, name, father, is_leaf=False, is_forward=True) -> None: + self.name = name # Name of the module + self.father = father # Parent module in the hierarchy + self.children = [] # List of child modules + self.is_leaf = is_leaf # Flag indicating if the module is a leaf module + self.is_forward = is_forward # Flag indicating if the module is in the forward pass + + def name_with_prefix(self): + """ + Constructs the module's full name with hierarchical prefix based on parent names. + + Returns: + str: The full hierarchical name of the module. + """ + if self.father is None or self.father.name_with_prefix() == "": + prefix = "" + else: + prefix = self.father.name_with_prefix() + "/" + return prefix + self.name + + def full_name(self): + """ + Generates the module's full name with an additional suffix indicating + whether the current state is forward or backward pass. 
+ + Returns: + str: The full name of the module including its forward/backward state. + """ + is_forward = self.is_forward if not self.is_leaf else self.father.is_forward + suffix = "(forward)" if is_forward else "(backward)" + return self.name_with_prefix() + suffix + + +# Initialize root module info as the base of the module hierarchy +root_module_info = ModuleInfo(name="", father=None) +current_module_info = root_module_info + + +def pre_forward_hook(module, input): + """ + Hook to be executed before a module's forward pass. It updates the module hierarchy + by adding the current module as a child of the current module in the hierarchy. + + Parameters: + - module: The module where the hook is registered. + - input: The input to the forward method of the module. + """ + global current_module_info + module_info = ModuleInfo(module.__class__.__name__, father=current_module_info, is_forward=True) + current_module_info.children.append(module_info) + current_module_info = module_info + + +def post_forward_hook(module, input, output): + """ + Hook to be executed after a module's forward pass. It steps back in the module hierarchy + to the parent module. + + Parameters: + - module: The module where the hook is registered. + - input: The input to the forward method of the module. + - output: The output from the forward method of the module. + """ + global current_module_info + current_module_info = current_module_info.father + + +def pre_backward_hook(module, grad_output): + """ + Hook to be executed before a module's backward pass. Similar to the pre_forward_hook, + it adds the module to the hierarchy with an indication that it's part of the backward pass. + + Parameters: + - module: The module where the hook is registered. + - grad_output: The gradients at the output of the module. 
+ """ + global current_module_info + module_info = ModuleInfo(module.__class__.__name__, current_module_info, is_forward=False) + current_module_info.children.append(module_info) + current_module_info = module_info + + +def post_backward_hook(module, grad_input, grad_output): + """ + Hook to be executed after a module's backward pass. + It steps back to the parent module in the hierarchy. + + Parameters: + - module: The module where the hook is registered. + - grad_input: The gradients at the input of the module. + - grad_output: The gradients at the output of the module. + """ + global current_module_info + current_module_info = current_module_info.father + + +def register_hooks(module): + """ + Registers the forward and backward hooks on the module and all its submodules. + + Parameters: + - module: The root module to register hooks on. + """ + module.register_forward_pre_hook(pre_forward_hook) + module.register_forward_hook(post_forward_hook) + module.register_full_backward_pre_hook(pre_backward_hook) + module.register_full_backward_hook(post_backward_hook) + + +def open_module_tracker(module): + """ + Initializes the module tracking by applying the register_hooks function to the module + and all its submodules. + + Parameters: + - module: The root module to start tracking on. + """ + module.apply(register_hooks) + + +def recursive_apply(func): + """ + Applies a function recursively to all tensors in a nested structure of + tensors, lists, tuples, and dictionaries. + + Parameters: + - func (function): A function to apply to every tensor found in the input structure. + + Returns: + - A function that takes an input structure and applies 'func' + to every tensor within that structure. 
+ """ + + def recursive_apply_fn(inputs): + if isinstance(inputs, (list, tuple)): + # Recursively apply to each element in lists or tuples + inputs_dst = [None] * len(inputs) + for i, x in enumerate(inputs): + inputs_dst[i] = recursive_apply_fn(x) + return tuple(inputs_dst) if isinstance(inputs, tuple) else inputs_dst + if isinstance(inputs, dict): + # Recursively apply to each value in dictionaries + return {k: recursive_apply_fn(v) for k, v in inputs.items()} + if isinstance(inputs, torch.Tensor): + # Apply the function to tensors + return func(inputs.detach().clone()) + # Return non-tensor objects unchanged + return inputs + + return recursive_apply_fn + + +def convert_to_dtype(inputs, dtype): + """ + Converts all tensors in a nested structure to a specified data type. + + Parameters: + - inputs: The input structure containing tensors. + - dtype: The target data type. + + Returns: + - The input structure with all tensors converted to the specified data type. + """ + return recursive_apply(lambda x: x.to(dtype=dtype))(inputs) + + +def convert_to_cpu(inputs): + """ + Converts all tensors in a nested structure to CPU memory. + + Parameters: + - inputs: The input structure containing tensors. + + Returns: + - The input structure with all tensors moved to CPU memory. + """ + return recursive_apply(lambda x: x.cpu())(inputs) + + +def convert_to_npu(inputs): + """ + Converts all tensors in a nested structure to a specified device, + in this case, a fictional "npu" device. + + Parameters: + - inputs: The input structure containing tensors. + + Returns: + - The input structure with all tensors moved to the "npu" device. + """ + return recursive_apply(lambda x: x.to("npu"))(inputs) + + +def convert_to_gpu(inputs): + """ + Converts all tensors in a nested structure to a specified device, + in this case, a fictional "npu" device. + + Parameters: + - inputs: The input structure containing tensors. 
+ + Returns: + - The input structure with all tensors moved to the "npu" device. + """ + return recursive_apply(lambda x: x.to("cuda"))(inputs) + + +def compare_tensors(tensor1, tensor2, atol, rtol): + """ + Compares two tensors element-wise to check if they are approximately + equal within the given tolerances, + and returns a mask indicating where they differ. + + Parameters: + - tensor1, tensor2: The tensors to compare. + - atol (float): Absolute tolerance. + - rtol (float): Relative tolerance. + + Returns: + - A boolean tensor mask where True indicates positions + where tensor1 and tensor2 are not close. + """ + # Initial comparisons for NaN and Inf values + nan_mask1, nan_mask2 = torch.isnan(tensor1), torch.isnan(tensor2) + inf_mask1, inf_mask2 = torch.isinf(tensor1), torch.isinf(tensor2) + + # Identify differences in NaN and Inf + nan_diff = nan_mask1 != nan_mask2 + inf_diff = inf_mask1 != inf_mask2 + + # Calculate the difference and tolerance + basic_diff = torch.abs(tensor1 - tensor2) + tolerance = atol + rtol * torch.abs(tensor2) + + # Check normal (non-NaN, non-Inf) value differences + normal_not_close = (basic_diff > tolerance) & ~(nan_mask1 | inf_mask1) + + # Aggregate differences + not_close = nan_diff | inf_diff | normal_not_close + + return not_close + + +def format_tensor(tensor): + """ + Formats a tensor for printing, providing information about + NaN and Inf values, shape, and other properties. + + Parameters: + - tensor: The tensor to format. + + Returns: + - A string representing the formatted tensor information. 
+ """ + nan_num = torch.isnan(tensor).sum().item() + inf_num = torch.isinf(tensor).sum().item() + head = "[WARNING] " if nan_num or inf_num else "" + warnings = f"nan_num={nan_num}, inf_num={inf_num}" if nan_num or inf_num else "" + return ( + f"{head}Tensor " + ) + + +def print_tensors_diff(tensor1, tensor2, atol, rtol): + """ + Prints the differences between two tensors, indicating + where they do not match within the given tolerances. + + Parameters: + - tensor1, tensor2: The tensors to compare. + - atol (float): Absolute tolerance. + - rtol (float): Relative tolerance. + + Returns: + - A string detailing the indices and values where the tensors differ. + """ + not_close = compare_tensors(tensor1.to(tensor2.device).to(tensor2.dtype), tensor2, atol, rtol) + indices = torch.nonzero(not_close) + indices_np = indices.cpu().numpy() + diff_str = "" + # If the indices are too large, only process the front part + if len(indices_np) > 20: + diff_str += f"\nToo many indices (total {len(indices_np)}) to print \n\n...\n\n" + indices_np = indices_np[:20] + idx_tuples = [tuple(idx) for idx in indices_np] + elements_out1 = [tensor1[idx].item() for idx in idx_tuples] + elements_out2 = [tensor2[idx].item() for idx in idx_tuples] + + for idx_tuple, elem1, elem2 in zip(idx_tuples, elements_out1, elements_out2): + diff_str += ( + f"Element at index {idx_tuple} is not close:" + f"{elem1}({tensor1.device}) vs " + f"{elem2}({tensor2.device})\n" + ) + diff_str += "\n...\n\n" + diff_str += f"{format_tensor(tensor1)}\n{tensor1}\n" + diff_str += f"{format_tensor(tensor2)}\n{tensor2}\n" + + return diff_str + + +def get_op_name(op_func): + """ + Extracts a simplified operation name from a function, trimming common prefixes and suffixes. + + Parameters: + - op_func: The operation function. + + Returns: + - A string representing the simplified name of the operation. 
+ """ + full_op_name = f"{op_func.__module__}.{op_func.__name__}" + full_op_name = full_op_name.replace("torch._ops", "torch.ops") + full_op_name = full_op_name.replace(".default", "") + return full_op_name + + +def get_full_op_name(op_name): + """ + Generates the full name of an operation including its module hierarchy + and whether it's a forward or backward operation. + + Parameters: + - op_name: The name of the operation. + + Returns: + - The full name of the operation, considering its module hierarchy + and operation type (forward or backward). + """ + module_info = ModuleInfo(name=op_name, father=current_module_info, is_leaf=True) + full_op_name = module_info.full_name() + current_module_info.children.append(module_info) + return full_op_name + + +def recursive_compare(out1, out2, atol, rtol, depth=0): + """ + Recursively compares two outputs (tensors or collections of tensors) + to check if they are approximately equal + within specified absolute (atol) and relative (rtol) tolerances. + + Parameters: + - out1, out2: The outputs to compare. These can be tensors, + lists/tuples of tensors, or nested structures thereof. + - atol (float): Absolute tolerance. + - rtol (float): Relative tolerance. + - depth (int): Recursion depth, used internally to format + the output string with indentation for readability. + + Returns: + - (bool, str): A tuple containing a boolean indicating whether + the outputs are approximately equal and a string + detailing differences if any. 
+ """ + indent = " " * depth # Indentation based on recursion depth + tensors_diff_str = "" + # Direct comparison for tensor objects + if isinstance(out1, torch.Tensor) and isinstance(out2, torch.Tensor): + # Check if tensors are approximately equal considering atol and rtol + if not torch.allclose( + out1.to(out2.device).to(out2.dtype), + out2, + atol=atol, + rtol=rtol, + equal_nan=True, + ): + # Record difference + tensors_diff_str += indent + "Tensor values are not close\n" + diff_str = print_tensors_diff(out1, out2, atol, rtol) + # Indent each line of the diff_str + indented_diff_str = "\n".join(indent + line for line in diff_str.split("\n")) + tensors_diff_str += indented_diff_str + return False, tensors_diff_str + return True, tensors_diff_str + + # Recursive comparison for list or tuple containers + if isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)): + all_results = [] + for i, (value1, value2) in enumerate(zip(out1, out2)): + # Recurse into elements + result, tensors_diff_str_child = recursive_compare(value1, value2, atol, rtol, depth + 1) + all_results.append(result) + if not result: + tensors_diff_str += indent + f"Output {i} is not close:\n" + tensors_diff_str += tensors_diff_str_child + return all(all_results), tensors_diff_str + + # Check for NaN equivalence separately since NaN != NaN + def are_both_nan(a, b): + try: + tensor_a, tensor_b = torch.tensor(a), torch.tensor(b) + except TypeError: + return False + return torch.isnan(tensor_a.detach()).all() and torch.isnan(tensor_b.detach()).all() + + # Fallback comparison for non-tensor types + if out1 != out2 and not are_both_nan(out1, out2): + return False, indent + f"Values are not equal: {out1} vs {out2}\n" + return True, tensors_diff_str + + +def recursive_print(args, top_level=True): + """ + Recursively prints the structure and contents of the argument which can be a tensor, + list, tuple, dictionary, or nested combinations thereof. 
+ + Parameters: + - args: The input to print, can be of any type but primarily aimed at tensors and collections. + - top_level (bool): Indicates if the current call is at the top level (for pretty printing). + Top level entries are printed with an index or key, followed by a newline. + + Returns: + - str: A string representation of the input, formatted for readability. + """ + out_str = "" + if isinstance(args, torch.Tensor): + out_str += format_tensor(args) # Format tensor for printing + elif isinstance(args, (list, tuple)): + # Recursively process list or tuple elements + for i, x in enumerate(args): + out_str += f"{i}: {recursive_print(x, False)}, \n" if top_level else f"{recursive_print(x, False)}, " + # Enclose in brackets or parentheses based on type + if not top_level: + out_str = f"[{out_str}]" if isinstance(args, list) else f"({out_str})" + elif isinstance(args, dict): + # Recursively process dictionary key-value pairs + for k, v in args.items(): + out_str += f"{k}: {recursive_print(v, False)}, \n" + if not top_level: + out_str = f"{{{out_str}}}" # Enclose in braces + else: + # For non-collection types, simply convert to string + out_str += str(args) + + # Print the formatted string if at the top level + if top_level: + print(out_str) + + return out_str + + +def compare_for_single_op(inputs_data_save_path, op_func, atol, rtol): + """ + Compare the output of a single operation with saved data for correctness. + + Parameters: + inputs_data_save_path (str): Path to the file containing saved input data. + op_func (function): The operation function to test. + atol (float): Absolute tolerance for comparison. + rtol (float): Relative tolerance for comparison. + + Returns: + tuple: The output of the operation and a boolean indicating if it passes the comparison. 
+ """ + # Load saved input data + with open(inputs_data_save_path, "rb") as f: + inputs_data = pickle.load(f) + args_cpu = inputs_data["args"] + kwargs_cpu = inputs_data["kwargs"] + + # Convert inputs to a specific device/format, if necessary + # args_npu = convert_to_npu(args_cpu) + # kwargs_npu = convert_to_npu(kwargs_cpu) + + args_npu = convert_to_gpu(args_cpu) + kwargs_npu = convert_to_gpu(kwargs_cpu) + + # Compare the operation output with the expected results + out, correct = compare_for_single_func(op_func, args_npu, kwargs_npu, atol, rtol, verbose=True) + return correct, args_npu, kwargs_npu, out + + +def compare_for_single_func(func, args, kwargs, atol, rtol, func_name=None, verbose=True): + """ + Compares the output of a function against expected results with given tolerances. + + Parameters: + func (function): The function to test. + args (tuple): Arguments for the function. + kwargs (dict): Keyword arguments for the function. + atol (float): Absolute tolerance. + rtol (float): Relative tolerance. + func_name (str, optional): The name of the function for logging. Defaults to None. + verbose (bool, optional): Enables detailed logging. Defaults to False. + + Returns: + tuple: The output of the function and a boolean indicating if it passes the comparison. 
+ """ + if func_name is None: + func_name = get_op_name(func) + print(f"{func_name} starts to run ...") + + # Convert arguments for CPU execution + args_cpu = convert_to_cpu(args) + kwargs_cpu = convert_to_cpu(kwargs) + + # Execute the function + out = func(*args, **kwargs) + correct = False + + try: + try: + # Attempt to run the function with CPU converted arguments + out_cpu = func(*args_cpu, **kwargs_cpu) + except RuntimeError as excp: + # Handle runtime errors, possibly due to datatype issues + print(excp) + print("Convert to float32 ...") + args_cpu = convert_to_dtype(args_cpu, torch.float32) + kwargs_cpu = convert_to_dtype(kwargs_cpu, torch.float32) + out_cpu = func(*args_cpu, **kwargs_cpu) + + # Compare the outputs + correct, tensors_diff_str = recursive_compare(out, out_cpu, atol, rtol) + if correct: + print(f"{func_name} succeeds to pass CompareWithCPU test") + else: + print("\n============================") + print(f"[ERROR] {func_name} fails to pass CompareWithCPU test") + if verbose: + # Log inputs and outputs for detailed debugging + print("....... input .........") + recursive_print(args) + recursive_print(kwargs) + print("...... output ........") + recursive_print(out) + print("\n...... compare with cpu .......") + print(tensors_diff_str) + print("============================") + except Exception as excp: + # Catch all other exceptions + print(excp) + print(f"[WARNING] {func_name} has not been tested!") + + return out, correct + + +def nan_inf_track_for_single_op(inputs_data_save_path, op_func): + """ + Track NaN or Inf values in the output of a single operation using saved data. + + Parameters: + inputs_data_save_path (str): Path to the file containing saved input data. + op_func (function): The operation function to test. + + Returns: + tuple: The output of the operation and a boolean indicating if NaN or Inf is detected. 
+ """ + # Load saved input data + with open(inputs_data_save_path, "rb") as f: + inputs_data = pickle.load(f) + args_cpu = inputs_data["args"] + kwargs_cpu = inputs_data["kwargs"] + + # Convert inputs for specific device/format, if necessary + # args_npu = convert_to_npu(args_cpu) + # kwargs_npu = convert_to_npu(kwargs_cpu) + + args_npu = convert_to_gpu(args_cpu) + kwargs_npu = convert_to_gpu(kwargs_cpu) + + # Track NaN/Inf in the function output + out, has_nan_or_inf = nan_inf_track_for_single_func(op_func, args_npu, kwargs_npu) + return has_nan_or_inf, args_npu, kwargs_npu, out + + +def nan_inf_track_for_single_func(func, args, kwargs, func_name=None): + """ + Detects NaN or Inf values in the output of a function. + + Parameters: + func (function): The function to test. + args (tuple): Arguments for the function. + kwargs (dict): Keyword arguments for the function. + func_name (str, optional): The name of the function for logging. Defaults to None. + + Returns: + tuple: The output of the function and a boolean indicating if NaN or Inf is detected. + """ + if func_name is None: + func_name = get_op_name(func) + print_str = "" + print("\n============================") + print(func_name) + print("....... input .........") + print_str += recursive_print(args) + print_str += recursive_print(kwargs) + out = func(*args, **kwargs) + print("...... output ........") + print_str += recursive_print(out) + has_nan_or_inf = "[WARNING]" in print_str + + return out, has_nan_or_inf + + +def save_data_for_op(out, args, kwargs, save_dir, op_name, file_suffix=""): + """ + Saves input and output data for an operation to files for later comparison or analysis. + + Parameters: + out (Tensor): The output of the operation. + args (tuple): Arguments used for the operation. + kwargs (dict): Keyword arguments used for the operation. + save_dir (str): The directory where data will be saved. + op_name (str): The name of the operation. + file_suffix (str, optional): A suffix for the filename. 
Defaults to ''. + + Returns: + None + """ + inputs_pkl_save_path = os.path.join(save_dir, f"{op_name}_inputs{file_suffix}.pkl") + outputs_pkl_save_path = os.path.join(save_dir, f"{op_name}_outputs{file_suffix}.pkl") + + # Prepare data for saving + inputs_data = {"args": convert_to_cpu(args), "kwargs": convert_to_cpu(kwargs)} + outputs_data = convert_to_cpu(out) + + # Save data to files + with open(inputs_pkl_save_path, "wb") as f: + pickle.dump(inputs_data, f) + with open(outputs_pkl_save_path, "wb") as f: + pickle.dump(outputs_data, f) + + print(f"Input data is saved to {inputs_pkl_save_path},\n" f"output data is saved to {outputs_pkl_save_path}") + + +class LogToFile: + """ + A context manager for redirecting stdout to a file. + + This class provides a mechanism for capturing the standard output + to a specified file. It's useful for logging purposes, where the output + of a block of code needs to be saved. If the file has not been written to + before, it opens in write mode and logs the current timestamp. Otherwise, + it appends to the existing file. + + Attributes: + _has_cleared_files (dict): + Tracks whether files have been cleared to avoid duplicate headers. + filepath (str or None): Path to the log file. If None, stdout is not redirected. + original_stdout (io.TextIOWrapper): Reference to the original stdout. + + """ + + _has_cleared_files = {} # Used to track whether each file has been cleared + + def __init__(self, filepath=None): + """ + Initializes the context manager with the path to the log file. + + Args: + filepath (str, optional): The path to the file where stdout will be redirected. + """ + self.filepath = filepath + self.original_stdout = sys.stdout + + def __enter__(self): + """ + Enters the runtime context related to this object. + + The stdout is redirected to the specified file. If the file has not been + written to before, it is cleared and initialized with a timestamp. + + Returns: + LogToFile: The runtime context object. 
+ """ + if self.filepath: + # Check if this file has already been cleared + if self.filepath not in self._has_cleared_files: + # If not, open in "w" mode to clear it and mark as cleared + self.file = open(self.filepath, "w", encoding="utf-8") + self.file.write(datetime.now().strftime("%Y-%m-%d, %H:%M:%S") + "\n") + self._has_cleared_files[self.filepath] = True + else: + # If already cleared, open in "a" mode to append content + self.file = open(self.filepath, "a", encoding="utf-8") + sys.stdout = self.file + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Exits the runtime context and restores the original stdout. + + The log file is closed if it was opened, and stdout is restored + to its original state. + + Args: + exc_type: Exception type. + exc_val: Exception value. + exc_tb: Traceback object. + """ + sys.stdout = self.original_stdout + if self.filepath: + self.file.close() + + +# https://github.com/pytorch/pytorch/issues/94403 +class TorchFuncMockNoDispatch: + """ + Wraps a method to call it without the custom + pytorch dispatcher + """ + + def __init__(self, pt_impl): + self.pt_impl = pt_impl + + def __get__(self, obj, c): + return partial(self, obj) + + def __call__(self, obj, *args, **kwargs): + with _pop_mode_temporarily(): + return self.pt_impl(obj, *args, **kwargs) + + +class CompareWithCPU(TorchDispatchMode): + """ + A class for comparing the outputs of tensor operations against + their CPU results to ensure correctness. + This is useful for debugging and verifying the consistency + of operations across different devices. + + Attributes: + enabled (bool): Flag to enable/disable comparison. + atol (float): Absolute tolerance for comparison. + rtol (float): Relative tolerance for comparison. + target_list (list[str]): + Specific operations to compare; if not empty, only these are considered. + white_list (list[str]): + Operations to ignore during comparison; considered if target_list is empty. 
+ dump_error_data (bool): If True, saves args of the first failing op and exits. + verbose (bool): If True, prints detailed info about the args of the ops being compared. + enable_ranks (list[int]): + MPI ranks that are allowed to perform comparisons; None means all ranks. + should_log_to_file (bool): If True, logs comparison results to a file. + output_dir (str): Directory to save logs and error data. + start_step (int): Step number to start comparisons. + end_step (int): Step number to end comparisons. + """ + + TENSOR_FUNCS_NO_DISPATCH = [ + # Can't convert Stream argument to Python object + "record_stream" + ] + + def __enter__(self) -> None: + self._pt_impls = {} + for k in self.TENSOR_FUNCS_NO_DISPATCH: + impl = getattr(torch.Tensor, k) + self._pt_impls[k] = impl + setattr(torch.Tensor, k, TorchFuncMockNoDispatch(impl)) + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + for k in self.TENSOR_FUNCS_NO_DISPATCH: + setattr(torch.Tensor, k, self._pt_impls[k]) + return super().__exit__(exc_type, exc_val, exc_tb) + + def __init__( + self, + enabled=True, + atol=0.001, + rtol=0.001, + target_list=None, + white_list=None, + dump_error_data=False, + verbose=True, + enable_ranks=None, + should_log_to_file=False, + output_dir="", + start_step=None, + end_step=None, + ) -> None: + super().__init__() + self.enabled = enabled + self.atol = atol + self.rtol = rtol + # Default whitelist has ops often non-deterministic or not key for comparison + self.white_list = [ + "torch.ops._record_function_enter_new", + "torch.ops.detach", + "torch.ops.aten.detach", + "torch.ops.uniform_", + "torch.ops.set_.source_Storage", + "torch.ops.set_.source_Storage_storage_offset", + "torch.ops.new_empty", + "torch.ops.aten.random_", + "torch.ops.isinf", + "torch.ops.aten.isinf", + "torch.ops.aten.isnan", + "torch.ops.isnan", + "torch.ops.normal_", + "torch.ops.barrier", + "torch.ops.aten.randperm.generator", + "torch.ops._to_copy", + "torch.ops.aten.clone", 
+ "torch.ops.aten.t", + "torch.ops.aten.empty.memory_format", + "torch.ops.aten.empty_like", + "torch.ops.profiler._record_function_enter_new", + "torch.ops.profiler._record_function_exit._RecordFunction", + "torch.ops.aten._unsafe_view", + "torch.ops.aten.view", + "torch.ops.aten._to_copy", + "torch.ops.aten.copy_", + ] + if white_list is not None: + self.white_list += white_list + self.target_list = target_list if target_list is not None else [] + self.dump_error_data = dump_error_data + self.verbose = verbose + self.enable_ranks = enable_ranks + self.should_log_to_file = should_log_to_file + self.output_dir = output_dir + self.step_cnt = 0 + self.start_step = start_step + self.end_step = end_step + self.is_active = True # Initially active, can be toggled based on step counts + self.update_active_state() + self.global_rank = int(os.environ.get("RANK", "-1")) # Fetch MPI rank if available + self.file_suffix = f"_rank{self.global_rank}" if self.global_rank >= 0 else "" + self.log_file_path = os.path.join(self.output_dir, f"compare_result{self.file_suffix}.txt") + + def update_active_state(self): + """ + Updates the active state based on current step count and the start/end steps defined. + """ + is_after_start = self.start_step is None or self.start_step <= self.step_cnt + is_before_end = self.end_step is None or self.end_step > self.step_cnt + self.is_active = is_after_start and is_before_end + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + """ + The core method that intercepts tensor operations. It compares the output of each operation + executed with its CPU counterpart, based on + the specified tolerances, white list, and target list. 
+ """ + with torch._C.DisableTorchFunction(): + # Skip comparison if disabled or not active + if not self.enabled or not self.is_active: + return func(*args, **kwargs) + + op_name = get_op_name(func) # Utility function to extract operation name + full_op_name = get_full_op_name(op_name) # May include namespace + + # LogToFile context manager for optional file logging + with LogToFile(self.log_file_path if self.should_log_to_file else None): + # Skip comparison for certain conditions (ranks, target list, white list) + if self.enable_ranks is not None and self.global_rank not in self.enable_ranks: + return func(*args, **kwargs) + if len(self.target_list) > 0 and op_name not in self.target_list: + print(f"{full_op_name} is not in target_list, pass") + return func(*args, **kwargs) + if op_name in self.white_list: + print(f"{full_op_name} is in white_list, pass") + return func(*args, **kwargs) + + # Perform the actual comparison + out, correct = compare_for_single_func( + func, + args, + kwargs, + atol=self.atol, + rtol=self.rtol, + func_name=full_op_name, + verbose=self.verbose, + ) + + # Handle comparison failure + if self.dump_error_data and not correct: + save_data_for_op(out, args, kwargs, self.output_dir, op_name, self.file_suffix) + raise Exception("CompareWithCPU Failed!") + + return out + + def step(self): + """ + Increments the step count and updates the active state. This method should be called + at the beginning or end of each step (iteration) to properly manage the comparison scope. + """ + if not self.enabled: + return + self.step_cnt += 1 + self.update_active_state() + # Optional logging for active steps + if self.is_active: + with LogToFile(self.log_file_path if self.should_log_to_file else None): + print(f"------------------------- step = {self.step_cnt} ----------------------\n") + + +class NanInfTracker(TorchDispatchMode): + """ + A class for tracking NaN (Not a Number) and Inf (Infinity) values in tensor operations. 
+ This helps in identifying operations that produce these values, which are often indicative + of numerical instabilities in computations. + + Attributes: + enabled (bool): + Flag to enable/disable NaN/Inf tracking. + target_list (list[str]): + Specific operations to track; if not empty, only these are considered. + white_list (list[str]): + Operations to ignore during tracking; considered if target_list is empty. + enable_ranks (list[int]): + MPI ranks that are allowed to perform tracking; None means all ranks. + should_log_to_file (bool): + If True, logs tracking results to a file. + output_dir (str): + Directory to save logs and error data. + dump_error_data (bool): + If True, saves args of the first failing op and exits. + start_step (int): + Step number to start tracking. + end_step (int): + Step number to end tracking. + """ + + TENSOR_FUNCS_NO_DISPATCH = [ + # Can't convert Stream argument to Python object + "record_stream" + ] + + def __enter__(self) -> None: + self._pt_impls = {} + for k in self.TENSOR_FUNCS_NO_DISPATCH: + impl = getattr(torch.Tensor, k) + self._pt_impls[k] = impl + setattr(torch.Tensor, k, TorchFuncMockNoDispatch(impl)) + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + for k in self.TENSOR_FUNCS_NO_DISPATCH: + setattr(torch.Tensor, k, self._pt_impls[k]) + return super().__exit__(exc_type, exc_val, exc_tb) + + def __init__( + self, + enabled=True, + target_list=None, + white_list=None, + enable_ranks=None, + should_log_to_file=False, + output_dir="", + dump_error_data=False, + start_step=None, + end_step=None, + ) -> None: + super().__init__() + self.enabled = enabled + self.white_list = [ + "torch.ops._record_function_enter_new", + "torch.ops.detach", + "torch.ops.aten.detach", + "torch.ops.set_.source_Storage", + "torch.ops.set_.source_Storage_storage_offset", + "torch.ops.isinf", + "torch.ops.aten.isinf", + "torch.ops.aten.isnan", + "torch.ops.isnan", + "torch.ops.barrier", + 
"torch.ops.aten.randperm.generator", + "torch.ops.aten.empty.memory_format", + "torch.ops.profiler._record_function_enter_new", + "torch.ops.profiler._record_function_exit._RecordFunction", + ] + if white_list is not None: + self.white_list += white_list + self.target_list = target_list if target_list is not None else [] + self.enable_ranks = enable_ranks + self.should_log_to_file = should_log_to_file + self.output_dir = output_dir + self.dump_error_data = dump_error_data + self.step_cnt = 0 + self.start_step = start_step + self.end_step = end_step + self.is_active = True # Initially active, can be toggled based on step counts + self.update_active_state() + self.global_rank = int(os.environ.get("RANK", "-1")) # Fetch MPI rank if available + self.file_suffix = f"_rank{self.global_rank}" if self.global_rank >= 0 else "" + self.log_file_path = os.path.join(self.output_dir, f"nan_inf_report{self.file_suffix}.txt") + + def update_active_state(self): + """ + Updates the active state based on current step count and the start/end steps defined. + """ + is_after_start = self.start_step is None or self.start_step <= self.step_cnt + is_before_end = self.end_step is None or self.end_step > self.step_cnt + self.is_active = is_after_start and is_before_end + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + """ + The core method that intercepts tensor operations. It checks the output of each operation + for NaN or Inf values based on the specified white list and target list. 
+ """ + with torch._C.DisableTorchFunction(): + if not self.enabled or not self.is_active: + return func(*args, **kwargs) + + op_name = get_op_name(func) # Utility function to extract operation name + full_op_name = get_full_op_name(op_name) # May include namespace + + with LogToFile(self.log_file_path if self.should_log_to_file else None): + if self.enable_ranks is not None and self.global_rank not in self.enable_ranks: + return func(*args, **kwargs) + if len(self.target_list) > 0 and op_name not in self.target_list: + print(f"{full_op_name} is not in target_list, pass") + return func(*args, **kwargs) + if op_name in self.white_list: + print(f"{full_op_name} is in white_list, pass") + return func(*args, **kwargs) + + out, has_nan_or_inf = nan_inf_track_for_single_func(func, args, kwargs, full_op_name) + + if self.dump_error_data and has_nan_or_inf: + save_data_for_op(out, args, kwargs, self.output_dir, op_name, self.file_suffix) + raise Exception("Nan or Inf Detected!") + + return out + + def step(self): + """ + Increments the step count and updates the active state. This method should be called + at the beginning or end of each step (iteration) to properly manage the tracking scope. 
+ """ + if not self.enabled: + return + self.step_cnt += 1 + self.update_active_state() + if self.is_active: + with LogToFile(self.log_file_path if self.should_log_to_file else None): + print(f"------------------------- step = {self.step_cnt} ----------------------\n") diff --git a/applications/ColossalChat/examples/training_scripts/lora_finetune.py b/applications/ColossalChat/examples/training_scripts/lora_finetune.py index 851ad6a2d9e3..4045556d7ece 100644 --- a/applications/ColossalChat/examples/training_scripts/lora_finetune.py +++ b/applications/ColossalChat/examples/training_scripts/lora_finetune.py @@ -257,7 +257,7 @@ def is_master(): ) torch.set_default_dtype(torch.float) - booster.load_model(model, args.pretrained) + booster.load_model(model, args.pretrained, low_cpu_mem_mode=False, num_threads=8) coordinator.print_on_master( f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 3d4b8a575cad..1de8b649d5d1 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -11,11 +11,11 @@ parser.add_argument("-t", "--num-trainers", type=int, default=2) parser.add_argument("-i", "--num-inferencer", type=int, default=2) parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64) - parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=16) - parser.add_argument("-tbs", "--train-batch-size", type=int, default=16) - parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2) + parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8) + parser.add_argument("-tbs", "--train-batch-size", type=int, default=32) + parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1) parser.add_argument("-b", "--backend", type=str, default="transformers") - parser.add_argument("-a", "--algo", type=str, 
default="GRPO", choices=["Simple, GPRO"]) + parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"]) args = parser.parse_args() ray.init(address="local", namespace="ray-example") @@ -24,29 +24,31 @@ train_model_config = dict(path=args.model) generate_config = dict( top_k=50, - top_p=0.8, + top_p=0.9, + temperature=1.0, ) if args.backend == "transformers": inference_model_config.update( dict( - attn_implementation="flash_attention_2", + use_flash_attention_2=True, torch_dtype=torch.bfloat16, ) ) train_model_config.update( dict( - attn_implementation="flash_attention_2", + use_flash_attention_2=True, torch_dtype=torch.bfloat16, use_cache=False, ) ) generate_config.update( dict( - max_length=512, + max_length=1024 + 512, do_sample=True, max_new_tokens=None, early_stopping=False, + stop_strings=[""], ) ) elif args.backend == "vllm": @@ -82,12 +84,12 @@ num_producers=args.num_inferencer, num_proc_per_producer=1, num_consumer_procs=args.num_trainers, - num_episodes=1, + num_episodes=10, inference_batch_size=args.inference_batch_size, inference_microbatch_size=args.inference_microbatch_size, train_batch_size=args.train_batch_size, train_microbatch_size=args.train_microbatch_size, - dataset_config={"path": args.dataset, "max_length": 256}, + dataset_config={"path": args.dataset, "max_length": 300}, dataloaders_config={}, inference_model_config=inference_model_config, generate_config=generate_config, diff --git a/applications/ColossalChat/tests/test_model/test_comm.py b/applications/ColossalChat/tests/test_model/test_comm.py new file mode 100644 index 000000000000..ce211802fb3a --- /dev/null +++ b/applications/ColossalChat/tests/test_model/test_comm.py @@ -0,0 +1,87 @@ +import ray +import ray.util.collective as cc +import torch +from coati.distributed.comm import ray_broadcast_object, ray_broadcast_tensor_dict + +from colossalai.testing import parameterize + + +@ray.remote(num_gpus=1) +class Worker: + def __init__(self, rank, world_size): + 
self.rank = rank + self.world_size = world_size + self.group_name = "default" + cc.init_collective_group(world_size, rank, backend="nccl", group_name=self.group_name) + + def run_ray_broadcast_object(self, obj, src, device): + # ray_broadcast_object + received_obj = ray_broadcast_object(obj, src, device, group_name=self.group_name) + return received_obj + + def run_ray_broadcast_tensor_dict(self, tensor_dict, src, device): + # ray_broadcast_tensor_dict + received_dict = ray_broadcast_tensor_dict(tensor_dict, src, device, group_name=self.group_name) + return received_dict + + def destroy_worker(self): + cc.destroy_collective_group(self.group_name) + + +@parameterize( + "test_config", + [ + { + "precision": torch.bfloat16, + "device": "npu", + "num_devices": 4, + }, + ], +) +def test_comm(test_config): + ray.init(num_gpus=4) + src = 0 + device = test_config["device"] + # create 4 workers + workers = [Worker.remote(i, test_config["num_devices"]) for i in range(test_config["num_devices"])] + + ############# + # 1. test ray_broadcast_object + ############# + # init broadcast_object data + test_obj = {"data": torch.tensor([1, 2, 3]), "message": "hello"} + + # run run_ray_broadcast_object + results = [worker.run_ray_broadcast_object.remote(test_obj, src, device) for worker in workers] + + # get result + results = ray.get(results) + + for i, result in enumerate(results): + print(f"ray_broadcast_object Rank {i} received object: {result}") + + ############# + # 2. 
test ray_broadcast_tensor_dict + ############# + test_tensor_dict = { + "tensor1": torch.tensor([1, 2, 3], device=device), + "tensor2": torch.tensor([[4, 5], [6, 7]], device=device), + } + + # run ray_broadcast_tensor_dict + results = [worker.run_ray_broadcast_tensor_dict.remote(test_tensor_dict, src, device) for worker in workers] + + # get result + results = ray.get(results) + + for i, result in enumerate(results): + print(f"run_ray_broadcast_tensor_dict Rank {i} received object: {result}") + + # destroy workers + for worker in workers: + worker.destroy_worker.remote() + ray.shutdown() + + +if __name__ == "__main__": + test_comm() diff --git a/applications/ColossalChat/tests/test_model/test_hybrid.py b/applications/ColossalChat/tests/test_model/test_hybrid.py new file mode 100644 index 000000000000..433456f7d345 --- /dev/null +++ b/applications/ColossalChat/tests/test_model/test_hybrid.py @@ -0,0 +1,172 @@ +import torch +import torch.distributed as dist +from coati.dataset.loader import RawConversationDataset +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import AutoTokenizer, Qwen2ForCausalLM + +import colossalai +from colossalai.accelerator import get_accelerator +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin, Plugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam + +# Training parameters +BATCH_SIZE = 8 +NUM_EPOCHS = 3 +LEARNING_RATE = 2e-5 +GRADIENT_ACCUMULATION_STEPS = 1 +DATA_PATH = "/home/duanjunwen/datasets/math_dataset.jsonl" +Device = torch.device("npu" if torch.npu.is_available() else "cpu") + + +class RandomDataset(Dataset): + def __init__(self, num_samples, sequence_length, vocab_size=10000): + self.num_samples = num_samples + self.sequence_length = sequence_length + self.vocab_size = vocab_size + self.input_idx = torch.randint(0, vocab_size, (num_samples, sequence_length)) + self.attention_mask = torch.randint(0, 2, (num_samples, 
sequence_length), dtype=torch.long) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return {"input_ids": self.input_idx[idx], "attention_mask": self.attention_mask[idx]} + + +def load_model_and_tokenizer(): + attn_impl = "eager" if get_accelerator().name == "npu" else "flash_attention_2" + tokenizer = AutoTokenizer.from_pretrained( + "/home/duanjunwen/models/Qwen/Qwen2.5-3B", + trust_remote_code=True, + attn_implementation=attn_impl, + ) + model = Qwen2ForCausalLM.from_pretrained("/home/duanjunwen/models/Qwen/Qwen2.5-3B", trust_remote_code=True) + return tokenizer, model + + +def all_reduce_mean(loss: torch.Tensor, plugin: Plugin) -> torch.Tensor: + loss = loss.data + group = getattr(plugin, "dp_group", None) + dist.all_reduce(loss, group=group) + return loss / dist.get_world_size(group) + + +# def train(model, dataloader, booster, optimizer): +# model.train() + +# for epoch in range(NUM_EPOCHS): +# if booster.plugin.pp_size > 1: +# data_iter = iter(dataloader) +# step_bar = tqdm( +# range(len(dataloader)), +# desc="Step", +# disable=not is_master(), +# ) +# else: +# total_loss = 0 +# for step, batch in enumerate(dataloader): +# input_ids = batch["input_ids"].to(device=model.module.device) +# attention_mask = batch["attention_mask"].to(device=model.module.device) +# outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids) +# loss = outputs.loss +# print(f"loss {loss} outputs {outputs}") +# loss = loss / GRADIENT_ACCUMULATION_STEPS +# booster.backward(loss, optimizer) + +# if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0: +# optimizer.step() +# optimizer.zero_grad() + +# total_loss += loss.item() + +# print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}") + + +def test_hybrid_qwen(): + colossalai.launch_from_torch() + get_accelerator() + coordinator = DistCoordinator() + tokenizer, model = load_model_and_tokenizer() + # dataset = RandomDataset(num_samples=100, sequence_length=2304) + dataset 
= RawConversationDataset(tokenizer, DATA_PATH, 1024) + # dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) + + optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE) + plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="bf16", zero_stage=2) + # plugin = HybridParallelPlugin(tp_size=2, pp_size=2, precision="bf16", zero_stage=1, num_microbatches=4, enable_flash_attention=True) + + dataloader = plugin.prepare_dataloader( + dataset=dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + ) + + booster = Booster(plugin=plugin) + + model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, None, dataloader) + + def is_master(): + if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1: + return coordinator.rank == coordinator.world_size - 1 + return coordinator.is_master() + + ##### + # train + ##### + model.train() + + for epoch in range(NUM_EPOCHS): + if booster.plugin.pp_size > 1: + data_iter = iter(dataloader) + step_bar = tqdm( + range(len(dataloader)), + desc="Step", + disable=not is_master(), + ) + for step in step_bar: + print(f"data_iter {data_iter}") + outputs = booster.execute_pipeline( + data_iter, + model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=optimizer, + return_loss=True, + ) + loss = outputs["loss"] + if booster.plugin.stage_manager.is_last_stage(): + global_loss = all_reduce_mean(loss, plugin) + + optimizer.step() + + if booster.plugin.stage_manager.is_last_stage(): + grad_norm = optimizer.get_grad_norm() + step_bar.set_postfix({"loss": global_loss.item(), "grad_norm": grad_norm}) + + optimizer.step() + optimizer.zero_grad() + else: + total_loss = 0 + for step, batch in enumerate(dataloader): + input_ids = batch["input_ids"].to(device=model.module.device) + attention_mask = batch["attention_mask"].to(device=model.module.device) + outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids) + loss = outputs.loss + print(f"loss {loss} 
outputs {outputs}") + loss = loss / GRADIENT_ACCUMULATION_STEPS + booster.backward(loss, optimizer) + + if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0: + optimizer.step() + optimizer.zero_grad() + + total_loss += loss.item() + + print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}") + + +if __name__ == "__main__": + test_hybrid_qwen() diff --git a/applications/ColossalChat/tests/test_model/test_loss.py b/applications/ColossalChat/tests/test_model/test_loss.py new file mode 100644 index 000000000000..e7202d393b29 --- /dev/null +++ b/applications/ColossalChat/tests/test_model/test_loss.py @@ -0,0 +1,77 @@ +import copy + +import torch +from coati.distributed.loss import PolicyLoss +from torch.testing import assert_close + +from colossalai.testing import parameterize +from colossalai.utils import set_seed + + +@parameterize( + "test_config", + [ + { + "precision": torch.bfloat16, + "device": "npu", + }, + ], +) +def run_policy_loss_fn(test_config): + dtype = test_config["precision"] + device = test_config["device"] + set_seed(42) + policy_loss_fn = PolicyLoss() + + ############ + # init npu tensor + ############ + action_log_probs = torch.rand(8, 2048, dtype=dtype, device=device) # float [8, 2048] + old_action_log_probs = torch.rand(8, 2048, dtype=dtype, device=device) # float [8, 2048] + advantages = torch.rand(8, dtype=dtype, device=device) # float [8] + per_token_kl = torch.rand(8, 2048, dtype=dtype, device=device) # float [8, 2048] + action_mask = torch.randint( + low=0, high=2, size=(8, 2048), dtype=torch.int32, device=device + ) # torch.int32 [8, 2048] in range(0,1) + + loss, skip_update, _ = policy_loss_fn( + action_log_probs, + old_action_log_probs, + advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1), + per_token_kl, + action_mask, + ) + + ############ + # init cpu tensor + ############ + action_log_probs_cpu = copy.deepcopy(action_log_probs.cpu()) + old_action_log_probs_cpu = 
copy.deepcopy(old_action_log_probs.cpu()) + advantages_cpu = copy.deepcopy(advantages.cpu()) + per_token_kl_cpu = copy.deepcopy(per_token_kl.cpu()) + action_mask_cpu = copy.deepcopy(action_mask.cpu()) + + loss_cpu, skip_update_cpu, _ = policy_loss_fn( + action_log_probs_cpu, + old_action_log_probs_cpu, + advantages_cpu.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1), + per_token_kl_cpu, + action_mask_cpu, + ) + + # assert close + assert_close( + loss.to("cpu"), + loss_cpu, + rtol=5e-4, + atol=5e-4, + # msg=f"NPU/CPU {test_config['precision']} not close" + ) + + +def test_loss_func(): + run_policy_loss_fn() + + +if __name__ == "__main__": + test_loss_func() diff --git a/applications/ColossalChat/tests/test_model/test_reward.py b/applications/ColossalChat/tests/test_model/test_reward.py new file mode 100644 index 000000000000..7d607cc19265 --- /dev/null +++ b/applications/ColossalChat/tests/test_model/test_reward.py @@ -0,0 +1,49 @@ +import torch +from coati.distributed.reward.reward_fn import math_reward_fn +from coati.distributed.reward.verifiable_reward import VerifiableReward +from transformers import AutoTokenizer + +from colossalai.testing import parameterize +from colossalai.utils import set_seed + + +@parameterize( + "test_config", + [ + {"device": "npu"}, + ], +) +def run_math_reward_fn(test_config): + device = test_config["device"] + set_seed(42) + # init tensor + input_ids = torch.randint(low=0, high=151644, size=(8, 2304), dtype=torch.int64, device=device) # [8, 2304] + gt_answer = torch.randint(low=0, high=151644, size=(8, 128), dtype=torch.int64, device=device) # [8, 128] + response_idx = torch.randint(low=256, high=2303, size=(8, 2), dtype=torch.int64, device=device) # [8, 2] + + # load tokenizer + # Qwen/Qwen2.5-3B + tokenizer = AutoTokenizer.from_pretrained("/home/share/data/model/Qwen2.5-3B") + + response_format_tags = { + "think_start": {"text": "", "num_occur": 1}, + "think_end": {"text": "", "num_occur": 1}, + 
"answer_start": {"text": "", "num_occur": 1}, + "answer_end": {"text": "", "num_occur": 1}, + } + reward_model = VerifiableReward(reward_fns=[math_reward_fn], tokenizer=tokenizer, tags=response_format_tags) + reward_model(input_ids, gt_answer, response_idx) + + +# not in use +def run_gsm8k_reward_fn(): + pass + + +def test_reward_func(): + run_math_reward_fn() + # run_gsm8k_reward_fn() + + +if __name__ == "__main__": + test_reward_func() diff --git a/applications/ColossalChat/tests/test_model/test_utils.py b/applications/ColossalChat/tests/test_model/test_utils.py new file mode 100644 index 000000000000..126ca81abe97 --- /dev/null +++ b/applications/ColossalChat/tests/test_model/test_utils.py @@ -0,0 +1,87 @@ +import torch +from coati.distributed.utils import log_probs_from_logits, masked_mean +from torch.testing import assert_close + +from colossalai.testing import parameterize +from colossalai.utils import set_seed + + +@parameterize( + "test_config", + [ + {"precision": torch.bfloat16, "device": "npu"}, + ], +) +def run_log_probs_from_logits(test_config): + torch.set_default_dtype(test_config["precision"]) + set_seed(42) + + # generate input + logits_cpu = torch.randn(2, 10, 50257) # (batch, seq_len, vocab_size) + labels_cpu = torch.randint(0, 50257, (2, 10)) + + # to npu + logits_cpu = logits_cpu + labels_cpu = labels_cpu + logits_gpu = logits_cpu.clone().to(device=test_config["device"]) + labels_gpu = labels_cpu.clone().to(device=test_config["device"]) + + # fwd + output_cpu = log_probs_from_logits(logits_cpu, labels_cpu) + output_gpu = log_probs_from_logits(logits_gpu, labels_gpu).cpu() + + # assert close + assert_close( + output_gpu, + output_cpu, + rtol=5e-4, + atol=5e-4, + # msg=f"NPU/CPU {test_config['precision']} not close" + ) + + +@parameterize( + "test_config", + [ + {"precision": torch.bfloat16, "device": "npu"}, + {"precision": torch.float32, "device": "npu"}, + ], +) +def run_calc_action_log_probs(test_config): + # same with 
run_log_probs_from_logits + pass + + +@parameterize( + "test_config", + [ + {"precision": torch.bfloat16, "device": "npu"}, + ], +) +def run_masked_mean(test_config): + torch.set_default_dtype(test_config["precision"]) + set_seed(42) + + # init tensor and mask + tensor = torch.randn(1, 10, 128) # batch_size, seq_length, hidden_size + mask = torch.rand(1, 10, 128) > 0.3 # init mask + + tensor_gpu = tensor.to(device=test_config["device"]) + mask_gpu = mask.to(device=test_config["device"]) + + # fwd + cpu_output = masked_mean(tensor, mask, dim=1) + gpu_output = masked_mean(tensor_gpu, mask_gpu, dim=1).cpu() + + # assert close + torch.testing.assert_close(cpu_output, gpu_output, atol=1e-2, rtol=1e-2) + + +def test_util_func(): + run_calc_action_log_probs() + run_log_probs_from_logits() + run_masked_mean() + + +if __name__ == "__main__": + test_util_func() diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 4b1224c68ffd..a81f9b05d7d7 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -85,11 +85,11 @@ def save_unsharded_model( if use_async: from colossalai.utils.safetensors import save - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict) for k, v in state_dict.items(): - self.pinned_state_dicts[id(model)][k].copy_(v) - state_dict[k] = self.pinned_state_dicts[id(model)][k] + self.pinned_state_dicts[hash(model)][k].copy_(v) + state_dict[k] = self.pinned_state_dicts[hash(model)][k] writer = save(checkpoint, state_dict) self.async_writers.append(writer) else: @@ -172,9 +172,9 @@ def save_sharded_model( Path(checkpoint_path).mkdir(parents=True, exist_ok=True) if use_async and self.coordinator.is_master(): - if id(model) not in self.pinned_state_dicts: - 
self.pinned_state_dicts[id(model)] = {} - pinned_state_dicts = self.pinned_state_dicts[id(model)] + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = {} + pinned_state_dicts = self.pinned_state_dicts[hash(model)] else: pinned_state_dicts = None state_dict_shard = model.state_dict_shard( diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1684fd702e70..1e0f7be240f6 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -26,6 +26,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper +from colossalai.interface.model import PeftUnwrapMixin from colossalai.interface.optimizer import DistributedOptim from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed @@ -225,7 +226,7 @@ def unwrap(self, unwrap_peft: bool = True): if isinstance(model, DDP): model = model.module if unwrap_peft and isinstance(model, PeftModel): - model = model.get_base_model() + model = PeftUnwrapMixin(model) return model def _force_wait_all_gather(self): diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index e74b1a9598b9..9cb5adf01972 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -12,6 +12,7 @@ from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.interface.model import PeftUnwrapMixin from colossalai.logging import get_dist_logger from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.utils import 
get_current_device @@ -201,7 +202,7 @@ def __init__(self, module: nn.Module, *args, **kwargs) -> None: def unwrap(self, unwrap_peft: bool = True) -> nn.Module: model = self.module.module if unwrap_peft and isinstance(model, PeftModel): - model = model.get_base_model() + model = PeftUnwrapMixin(model) return model diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index d713203fe905..6e652e549b5e 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -103,11 +103,11 @@ def save_unsharded_model( if use_async: from colossalai.utils.safetensors import save - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state) + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(full_model_state) for k, v in full_model_state.items(): - self.pinned_state_dicts[id(model)][k].copy_(v) - full_model_state[k] = self.pinned_state_dicts[id(model)][k] + self.pinned_state_dicts[hash(model)][k].copy_(v) + full_model_state[k] = self.pinned_state_dicts[hash(model)][k] writer = save(checkpoint, full_model_state) self.async_writers.append(writer) else: @@ -186,9 +186,9 @@ def save_sharded_model( state_dict = model.unwrap().state_dict() if use_async and self.coordinator.is_master(): - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = {} - pinned_state_dicts = self.pinned_state_dicts[id(model)] + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = {} + pinned_state_dicts = self.pinned_state_dicts[hash(model)] else: pinned_state_dicts = None state_dict_shard = utils.shard_model_checkpoint( diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 3e600c94dfc5..5dfb09248b53 100644 --- 
a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -60,9 +60,9 @@ def save_unsharded_model( if use_async: from colossalai.utils.safetensors import move_and_save - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) - writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)]) + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict) + writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[hash(model)]) self.async_writers.append(writer) else: # save the checkpoint @@ -234,7 +234,7 @@ def save_sharded_model( index_file = CheckpointIndexFile(checkpoint_path) if use_async: - pinned_state_dict = self.pinned_state_dicts.get(id(model), None) + pinned_state_dict = self.pinned_state_dicts.get(hash(model), None) total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards( sharded_state_dict=state_dict_shard, checkpoint=checkpoint_path, @@ -243,7 +243,7 @@ def save_sharded_model( is_master=True, pinned_state_dict=pinned_state_dict, ) - self.pinned_state_dicts[id(model)] = new_pinned_state_dict + self.pinned_state_dicts[hash(model)] = new_pinned_state_dict self.async_writers.extend(writers) else: # Save shards of optimizer states. diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 5de32e66655c..9d972635214d 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -249,9 +249,9 @@ def save_sharded_model( # Only devices with tp_rank == 0 are responsible for model saving. 
control_saving = self.tp_rank == 0 and self.sp_rank == 0 if control_saving and use_async: - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = {} - pinned_state_dicts = self.pinned_state_dicts[id(model)] + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = {} + pinned_state_dicts = self.pinned_state_dicts[hash(model)] else: pinned_state_dicts = None state_dict_shard = HybridParallelCheckpointIO._model_sharder( @@ -789,11 +789,11 @@ def save_unsharded_model( if use_async: from colossalai.utils.safetensors import save - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict) for name, param in state_dict.items(): - self.pinned_state_dicts[id(model)][name].copy_(param) - state_dict[name] = self.pinned_state_dicts[id(model)][name] + self.pinned_state_dicts[hash(model)][name].copy_(param) + state_dict[name] = self.pinned_state_dicts[hash(model)][name] writer = save(path=checkpoint, state_dict=state_dict) self.async_writers.append(writer) else: @@ -811,11 +811,11 @@ def save_unsharded_model( if use_async: from colossalai.utils.safetensors import save - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict) + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(complete_state_dict) for name, param in complete_state_dict.items(): - self.pinned_state_dicts[id(model)][name].copy_(param) - complete_state_dict[name] = self.pinned_state_dicts[id(model)][name] + self.pinned_state_dicts[hash(model)][name].copy_(param) + complete_state_dict[name] = self.pinned_state_dicts[hash(model)][name] writer = save(path=checkpoint, state_dict=complete_state_dict) self.async_writers.append(writer) else: 
diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py index 586c7863f4bf..85e36f7c6336 100644 --- a/colossalai/checkpoint_io/moe_checkpoint.py +++ b/colossalai/checkpoint_io/moe_checkpoint.py @@ -701,15 +701,18 @@ def pre_save_model(self, model: nn.Module) -> dict: all_param = None # gather param from every ep rank # dist.all_gather(all_param, param, group=ep_group) - dist.gather(param, all_param, group=ep_group) + dist.gather(param, all_param, dst=dist.get_global_rank(ep_group, 0), group=ep_group) if ep_rank == 0: all_param = torch.cat(all_param, dim=0) state_dict[name] = all_param.cpu() if self.pp_size > 1: if self.dp_rank == 0: - out = [None for _ in range(self.pp_size)] - dist.gather_object(state_dict, out, group=self.pp_group) + if self.pp_rank == 0: + out = [None for _ in range(self.pp_size)] + else: + out = None + dist.gather_object(state_dict, out, dst=dist.get_global_rank(self.pp_group, 0), group=self.pp_group) if self.pp_rank == 0: new_state_dict = {} for o in out: diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 2d826bd15f52..4b36dbe002bb 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -20,6 +20,7 @@ from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten from colossalai.accelerator import get_accelerator +from colossalai.interface.model import PeftUnwrapMixin from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, is_distributed_tensor, @@ -554,6 +555,8 @@ def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = T from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model except ImportError: return + if isinstance(model, PeftUnwrapMixin): + model = model.base_model if not isinstance(model, PreTrainedModel): return @@ -692,6 +695,9 @@ def load_state_dict_into_model( state_dict (dict): a dict containing parameters and persistent buffers. 
""" + if isinstance(model, PeftUnwrapMixin): + state_dict = model.patch_state_dict(state_dict) + model = model.base_model if not isinstance(state_dict, Mapping): raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py index d112c27230b0..8dbd15c63cf9 100644 --- a/colossalai/interface/model.py +++ b/colossalai/interface/model.py @@ -1,5 +1,102 @@ +import re +from typing import Dict, Set + +import torch import torch.nn as nn -from peft import PeftModel +from peft import PeftModel, PeftType + + +def extract_lora_layers(model: PeftModel, names: Set[str], adapter_name: str = "default"): + config = model.peft_config[adapter_name] + if config.peft_type != PeftType.LORA: + raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.") + # to_return = lora_state_dict(model, bias=model.peft_config.bias) + # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py` + # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP + bias = config.bias + if bias == "none": + to_return = {k for k in names if "lora_" in k} + elif bias == "all": + to_return = {k for k in names if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = set() + for k in names: + if "lora_" in k: + to_return.add(k) + bias_name = k.split("lora_")[0] + "bias" + if bias_name in names: + to_return.add(bias_name) + else: + raise NotImplementedError + to_return = {k for k in to_return if (("lora_" in k and adapter_name in k) or ("bias" in k))} + if config.use_dora: + # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a + # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since + # we want the state_dict format not to change, we remove the "weight" part. 
+ new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight" + + def renamed_dora_weights(k): + if k.endswith(new_dora_suffix): + k = k[:-7] # remove ".weight" + return k + + to_return = {renamed_dora_weights(k) for k in to_return} + + to_return = {re.sub(f"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in to_return} + return to_return + + +class PeftUnwrapMixin: + def __init__(self, peft_model: PeftModel): + self.base_model = peft_model.get_base_model() + # peft does not affect buffers + self.lora_layers = extract_lora_layers(peft_model, set(n for n, p in self.base_model.named_parameters())) + potential_lora_weights = set() + for n in self.lora_layers: + potential_lora_weights.add(f"{n}.weight") + potential_lora_weights.add(f"{n}.bias") + self.lora_param_to_origin_param = {n: n.replace("base_layer.", "") for n in potential_lora_weights} + self.origin_param_to_lora_param = {v: k for k, v in self.lora_param_to_origin_param.items()} + + def named_parameters(self): + for n, p in self.base_model.named_parameters(): + if n in self.lora_param_to_origin_param: + n = self.lora_param_to_origin_param[n] + yield n, p + + def named_buffers(self): + return self.base_model.named_buffers() + + @property + def _modules(self): + return self.base_model._modules + + @property + def _non_persistent_buffers_set(self): + return self.base_model._non_persistent_buffers_set + + def patch_state_dict(self, state_dict: Dict[str, torch.Tensor]): + new_state_dict = {} + for k, v in state_dict.items(): + if k in self.origin_param_to_lora_param: + k = self.origin_param_to_lora_param[k] + new_state_dict[k] = v + return new_state_dict + + def state_dict(self): + state_dict = {} + for k, v in self.base_model.state_dict().items(): + if k in self.lora_param_to_origin_param: + k = self.lora_param_to_origin_param[k] + state_dict[k] = v + return state_dict + + def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False): + state_dict = 
self.patch_state_dict(state_dict) + self.base_model.load_state_dict(state_dict, strict=strict, assign=assign) + + def __hash__(self): + return hash(self.base_model) class ModelWrapper(nn.Module): @@ -23,7 +120,7 @@ def unwrap(self, unwrap_peft: bool = True): else: model = self.module if unwrap_peft and isinstance(model, PeftModel): - model = model.get_base_model() + model = PeftUnwrapMixin(model) return model def forward(self, *args, **kwargs): diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 0bd1b60923e9..a1b80bf56b63 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -3,7 +3,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D -from .loss import cross_entropy_1d, dist_cross_entropy +from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule from .qkv_fused_linear import ( @@ -28,6 +28,8 @@ "DropoutForReplicatedInput", "cross_entropy_1d", "dist_cross_entropy", + "dist_log_prob_1d", + "dist_log_prob", "BaseLayerNorm", "LayerNorm", "RMSNorm", diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 0e2241af9fc9..51419a38a0ed 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -3,13 +3,21 @@ from torch.autograd import Function from torch.distributed import ProcessGroup from torch.nn import CrossEntropyLoss +from torch.nn.functional import log_softmax from colossalai.shardformer.layer._operation import reduce_forward from colossalai.shardformer.shard import ShardConfig from .utils import is_share_sp_tp -__all__ = 
["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"] +__all__ = [ + "DistCrossEntropy", + "cross_entropy_1d", + "dist_cross_entropy", + "DistLogProb", + "dist_log_prob_1d", + "dist_log_prob", +] _IGNORE_IDX = -100 @@ -137,6 +145,98 @@ def backward(ctx, grad_output): return grad_logits, None, None, None, None, None, None +class DistLogProb(Function): + r""" + Overwrite the forward and backward function to calculate the log prob before gather + + Args: + Function (:class:`torch.autograd.Function`): default + """ + + @staticmethod + def forward( + ctx, + vocab_logits: torch.Tensor, + target: torch.Tensor, + process_group: ProcessGroup, + vocab_size: int, + dtype=torch.float32, + ): + + ################## + # Step1:Find the global maximum value of logits + ################## + logits_max = torch.max(vocab_logits, dim=-1)[0] + handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True) + + ################## + # Step2:Find the local mask. local mask will be use to select log_probs value in Step 4. 
+ # For accleration, we overlap Step 2 and Step 3 + ################## + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + if vocab_size is None: + partition_vocab_size = vocab_logits.size()[-1] + global_vocab_size = partition_vocab_size * world_size + else: + global_vocab_size = vocab_size + partition_vocab_size = global_vocab_size // world_size + # down and up threshold for local logits + delta = (global_vocab_size + world_size - 1) // world_size + down_threshold = rank * delta + up_threshold = down_threshold + delta + if up_threshold > global_vocab_size: + up_threshold = global_vocab_size + # mask + mask = (target < down_threshold) | (target >= up_threshold) + masked_target = target.clone() - down_threshold + masked_target[mask] = 0 + masked_target_1d = masked_target.view(-1).contiguous() + handle.wait() + + ################## + # Step3:Calculate global summation exp logits + ################## + vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) + exp_logits = torch.exp(vocab_logits) + sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) # local summation exp logits + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) + + ################## + # Step4:Calculate local prob. 
We first cal log_softmax, then select log probs via local mask + ################## + log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1)) # cal log_softmax + log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1)) + log_probs[mask.unsqueeze(-1)] = 0 # set masked val to zero + dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group) + + ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits) + ctx.dtype = dtype + return log_probs + + @staticmethod + def backward(ctx, grad_output): + exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors + ################## + # Step1:Find the global sofmax value + ################## + softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1) + + ################## + # Step2:Update softmax value based on local target index + ################## + partion_vocab_size = softmax_logits.shape[-1] + softmax_logits_2d = softmax_logits.view(-1, partion_vocab_size) + update = 1.0 - mask.view(-1).float().to(ctx.dtype) + softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update + + ################## + # Step3:Calculate grad_output, which is the gradient of the loss function with respect to the output of logsoftmax + ################## + grad_logits = -softmax_logits.mul_(grad_output) + return grad_logits, None, None, None, None, None, None + + def cross_entropy_1d( vocab_logits: torch.Tensor, labels: torch.Tensor, @@ -149,6 +249,16 @@ def cross_entropy_1d( return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode) +def dist_log_prob_1d( + vocab_logits: torch.Tensor, + labels: torch.Tensor, + process_group: ProcessGroup = None, + vocab_size: int = None, + dtype: torch.dtype = None, +) -> torch.Tensor: + return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype) + + def dist_cross_entropy( labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] logits: 
torch.Tensor, # [B, S, Vocab_size] @@ -243,3 +353,41 @@ def dist_cross_entropy( loss, num_nonzero = loss[0], loss[1].detach() loss = (loss / num_nonzero).squeeze() return loss + + +def dist_log_prob( + labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] + logits: torch.Tensor, # [B, S, Vocab_size] + shard_config: ShardConfig, + vocab_size: int, + dtype: torch.dtype, + seq_dim: int = 1, +) -> torch.Tensor: + """ + Helper to compute log prob for most shardformer models supporting PP, TP. + """ + # Split labels if not gather output + parallel_output = shard_config.parallel_output + is_tp = shard_config.enable_tensor_parallelism + + # TODO:support sp + labels = labels[..., 1:] + logits = logits[..., :-1, :] + labels = labels.contiguous() + logits = logits.contiguous() + assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" + + # Flatten the tokens + if is_tp and parallel_output: + log_prob = dist_log_prob_1d( + logits, + labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=vocab_size, + dtype=dtype, + ) + else: + log_prob = log_softmax(logits) + log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1)) + + return log_prob diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 569fc4a459c5..71e3557fe214 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -832,7 +832,6 @@ def forward( loss = None if labels is not None: loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype) - if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 84d2b2fdbd99..fd14029a3a36 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -11,8 +11,10 @@ 
Linear1D_Row, LinearWithGradAccum, PaddingEmbedding, + PaddingLMHead, RMSNorm, VocabParallelEmbedding1D, + VocabParallelLMHead1D, ) from ..modeling.qwen2 import ( @@ -429,8 +431,12 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), + target_module=VocabParallelLMHead1D, + kwargs=dict( + gather_output=not self.shard_config.parallel_output, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ) ], method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, @@ -444,9 +450,18 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=LinearWithGradAccum, + target_module=PaddingLMHead, kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), - ) + ), + SubModuleReplacementDescription( + suffix="lm_head", + target_module=VocabParallelLMHead1D, + kwargs={ + "gather_output": not self.shard_config.parallel_output, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + }, + ), ], method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, ) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index f357c45fde64..688c47cc2221 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch>=2.2.0,<=2.4.1 +torch>=2.2.0,<=2.5.1 safetensors einops pydantic diff --git a/tests/test_cluster/test_device_mesh_manager.py b/tests/test_cluster/test_device_mesh_manager.py index 5d140064ba94..c4a92a138d7c 100644 --- a/tests/test_cluster/test_device_mesh_manager.py +++ b/tests/test_cluster/test_device_mesh_manager.py @@ -1,7 +1,7 @@ from 
colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import spawn +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_device_mesh_manager(rank, world_size, port): @@ -24,6 +24,7 @@ def check_device_mesh_manager(rank, world_size, port): assert device_mesh_with_shape._logical_mesh_id.tolist() == [[0, 1], [2, 3]] +@rerun_if_address_is_in_use() def test_device_mesh_manager(): spawn(check_device_mesh_manager, 4) diff --git a/tests/test_shardformer/test_layer/test_dist_log_prob.py b/tests/test_shardformer/test_layer/test_dist_log_prob.py new file mode 100644 index 000000000000..05a6a5d4766f --- /dev/null +++ b/tests/test_shardformer/test_layer/test_dist_log_prob.py @@ -0,0 +1,52 @@ +import pytest +import torch +from coati.distributed.utils import log_probs_from_logits + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer import dist_log_prob_1d +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict( + parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")), +) + + +def check_dist_log_prob(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl") + + # prepare data + pred = torch.randn(2, 4, 8, requires_grad=True).cuda() + labels = torch.randint(8, (2, 4)).cuda() + + logprob = log_probs_from_logits(pred, labels) + + pred.retain_grad() + logprob.mean().backward() + + dist_pred = pred.clone().chunk(world_size, -1)[rank].detach() + dist_pred.requires_grad = True + dist_logprob = dist_log_prob_1d(dist_pred, labels) + + dist_pred.retain_grad() + dist_logprob.squeeze(-1).mean().backward() + + assert torch.allclose( + logprob, dist_logprob.squeeze(-1), atol=1e-5 + ), f"dist cross entropy logprob is not equal to orgin 
logprob\n{logprob}\n{dist_logprob.squeeze(-1)}" + + pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach() + assert torch.allclose( + pred_grad_partial, dist_pred.grad + ), f"dist grad is not equal to orgin grad\n{pred.grad}\n{dist_pred.grad}" + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_dist_log_prob(): + spawn(check_dist_log_prob, 2) + + +if __name__ == "__main__": + test_dist_log_prob() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 6cdf5bf41c68..40b4e368d384 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -51,7 +51,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: - atol, rtol = 5e-2, 5e-2 + atol, rtol = 9e-2, 0 if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: row_layer_grads = get_grad_tensors_for_check( t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0 diff --git a/version.txt b/version.txt index cb498ab2c89f..76914ddc02f8 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.4.8 +0.4.9