diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index 9d1f16fbe..da443c4f6 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -72,10 +72,12 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: ) -def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None: +def safe_barrier( + group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None, device: torch.device | None = None +) -> None: if group: hashed = hash(value) % 2**32 - out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout) + out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout, device=device) if out != hashed * group.size(): raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})") @@ -86,9 +88,10 @@ def allreduce_scalar( group: torch.distributed.ProcessGroup | None = None, op=ReduceOp.SUM, timeout: float | None = None, + device: torch.device | None = None, ) -> float | int: if group: - value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device()) + value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device() if device is None else device) with set_timeout(group, timeout): torch.distributed.all_reduce(value, op=op, group=group) return value.item() diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 5ddf2ff98..4d31324fe 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -327,7 +327,9 @@ def _preprocess_data( self, context: BatchContext, data_iterator: typing.Iterator, preprocessed: bool ) -> typing.Generator[None, None, None]: batch_config = context.schedule.batch_config - grad_output = (1 if self._optimizer is None else self._optimizer.grad_scale) / batch_config.num_inputs + grad_output = ( + self._optimizer.grad_scale / batch_config.num_inputs if context.schedule.phase.is_training else None + ) for micro_batch in range(batch_config.sequential_micro_batches): micro_batch_data = next(data_iterator) if not preprocessed: diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 4cfc3b61d..050c700c9 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -93,16 +93,17 @@ def _set_activation_fn_map() -> None: MAX_DROPLESS_BLOCK_SIZE_ROW = 128 -class CrossEntropyImpl(str, enum.Enum): +class EntropyLossImplementation(enum.StrEnum): auto = "auto" torch = "torch" fused = "fused" triton = "triton" -class DistillationLossImpl(str, enum.Enum): - reverse_kl = "reverse_kl" +class EntropyLossType(enum.StrEnum): cross_entropy = "cross_entropy" + forward_kl = "forward_kl" + reverse_kl = "reverse_kl" class TargetFormat(enum.StrEnum): diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py deleted file mode 100644 index a12516b5d..000000000 --- a/fast_llm/functional/cross_entropy.py +++ /dev/null @@ -1,359 +0,0 @@ -import torch - -from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat -from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward -from fast_llm.utils import Assert - - -def _torch_cross_entropy_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - logits_scale_factor: float, - target_format: TargetFormat, - 
teacher_softmax_temperature: float = 1.0, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - A wrapper for the pytorch implementation of cross-entropy. - The cross-entropy kernels themselves are well-optimized, but the need for explicit casting - and separate forward and backward kernels lead to poor performance. - TODO: loss masking only works for with labels format and if the masking index is set to -100. - """ - # Torch compile doesn't understand this. - with torch.set_grad_enabled(grad_output is not None): - logits_ = logits.float().detach().requires_grad_(grad_output is not None) - if target_format == TargetFormat.logits: - if logits_scale_factor != 1.0: - target = target * logits_scale_factor - if teacher_softmax_temperature != 1.0: - target = target / teacher_softmax_temperature - target = torch.softmax(target, dim=-1) - if loss_mask is None: - loss = torch.nn.functional.cross_entropy( - logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target - ) - else: - loss = ( - torch.nn.functional.cross_entropy( - logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none" - ) - * loss_mask - ).mean() - if grad_output is None: - grad = None - else: - loss.backward(torch.full_like(loss, grad_output)) - grad = logits_.grad.detach().to(logits.dtype) - return loss.detach_(), grad - - -@torch.compile -def _fused_softmax_base( - logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - logits = logits.float() - if logits_scale_factor != 1.0: - logits *= logits_scale_factor - logits_max = torch.max(logits, dim=dim, keepdim=True)[0] - if group is not None: - all_reduce(logits_max, op=ReduceOp.MAX, group=group) - logits_norm = (logits - logits_max).float() - exp_logits = logits_norm.exp() - sum_exp_logits = exp_logits.sum(dim=dim, keepdim=True) - if group is not None: - all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group) - return logits_norm, exp_logits, sum_exp_logits - - -@torch.compile -def _fused_softmax( - logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup = None, dim: int = -1 -) -> torch.Tensor: - _, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group, dim) - return exp_logits / sum_exp_logits - - -# @torch.compile -def _fused_cross_entropy_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - logits_scale_factor: float, - target_format: TargetFormat, - group: ProcessGroup | None = None, - teacher_softmax_temperature: float = 1.0, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - A fused implementation of cross-entropy with torch compile. - It is an improvement over the pytorch implementation because of the fused casting, both in speed and memory, - but still suboptimal because it needs multiple kernels. - """ - # Do the forward and backward passes all at once, and fused with dtype conversion. - # Way faster and more memory-efficient than the pytorch version. - - logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) - - if target_format == TargetFormat.logits: - target = _fused_softmax(target, logits_scale_factor / teacher_softmax_temperature, group) - - if target_format == TargetFormat.labels: - target = target.unsqueeze(-1) - loss_mask = target >= 0 - if group is None: - # Keep values within range for scatter and gather ops to work. 
- target = target * loss_mask - target_mask = None - else: - # Mask the target (fused) - # TODO: Could mask earlier on cpu or overlap with reduce? - vocab_start_index = logits.size(-1) * group.rank() - target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) - target = (target - vocab_start_index) * target_mask - else: - # Target should be tensor-parallel already, no further manipulation needed. - target_mask = None - if loss_mask is not None: - loss_mask = loss_mask.unsqueeze(-1) - - if grad_output is None: - grad = None - else: - # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. - if target_format == TargetFormat.labels: - grad_base = exp_logits.scatter_add( - 1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) - ) - else: - grad_base = exp_logits - sum_exp_logits * target - - grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits) - if logits_scale_factor != 1.0: - grad *= logits_scale_factor - if loss_mask is not None: - grad *= loss_mask - grad = grad.to(logits.dtype) - - # loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) - if target_format == TargetFormat.labels: - predicted_logits = logits_norm.gather(1, target) - if group is not None: - predicted_logits = target_mask * predicted_logits - - all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) - else: - predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) - if group is not None and target_format != TargetFormat.labels: - # this is needed because on each rank we calculate log Z - sum_i t_i * z_i, where z_i is logit. - # Then we average on line 160: 1/K sum_ranks (log Z - sum_i t_i * z_i) - # = log Z - 1/K sum_ranks (sum_i t_i * z_i), where is the global predicted_logits, so without multiplying it by K 1/K there does not cancel out. - predicted_logits = predicted_logits * group.size() - - per_sample_loss = sum_exp_logits.log() - predicted_logits - if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask - - loss = per_sample_loss.mean() - if target_format != TargetFormat.labels and group is not None: - all_reduce(loss, op=ReduceOp.AVG, group=group) - - return loss, grad - - -_CROSS_ENTROPY_IMPLEMENTATIONS = { - CrossEntropyImpl.torch: _torch_cross_entropy_forward_backward, - CrossEntropyImpl.fused: _fused_cross_entropy_forward_backward, - CrossEntropyImpl.triton: triton_cross_entropy_forward_backward, -} - - -def cross_entropy_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - group: ProcessGroup | None = None, - implementation: CrossEntropyImpl = CrossEntropyImpl.fused, - logits_scale_factor: float = 1.0, - teacher_softmax_temperature: float = 1.0, - target_format: TargetFormat = TargetFormat.labels, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Select the appropriate implementation of cross-entropy. - The triton implementation from the triton submodule is the fastest and recommended one. - It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, - which is faster and has a relatively small memory overhead. 
- """ - if target_format == TargetFormat.labels: - Assert.eq(target.shape, logits.shape[:-1]) - Assert.eq(target.dtype, torch.int64) - assert loss_mask is None - else: - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - if group: - Assert.eq(implementation, CrossEntropyImpl.fused) - return _fused_cross_entropy_forward_backward( - logits, - target, - loss_mask, - grad_output, - logits_scale_factor, - target_format, - group, - teacher_softmax_temperature, - ) - else: - return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, - target, - loss_mask, - grad_output, - logits_scale_factor, - target_format, - teacher_softmax_temperature=teacher_softmax_temperature, - ) - - -def distributed_log_softmax( - logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 -): - logits_norm, _, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group=group, dim=dim) - - return logits_norm - sum_exp_logits.log() # log_softmax - - -@torch.compile -def _reverse_kl_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - target_format: TargetFormat, - group: ProcessGroup | None = None, - logits_scale_factor: float = 1.0, - teacher_softmax_temperature: float = 1.0, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Reverse KL using PyTorch's native kl_div function. - This is used for TP version where we split accross vocab dimantion. KL is additive over partitions of the vocab. - - Takes: - logits: [BxS, V] or [B, S, V] - target: [BxS, V] or [B, S, V] (logits format) - loss_mask: [BxS] or [B, S] or None - ... - """ - Assert.eq( - teacher_softmax_temperature, - 1, - msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", - ) - Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - - teacher_log_probs = distributed_log_softmax(target.float(), group=group) - log_ratio = distributed_log_softmax(logits, group=group) - - student_probs = log_ratio.exp() - log_ratio = log_ratio - teacher_log_probs # In-place: log_ratio = student_log_probs - teacher_log_probs - del teacher_log_probs - # Compute loss terms: student_probs * log_ratio, then sum over vocab - # This is equivalent to kl_div(..., log_target=True) but more memory efficient - loss_terms = (student_probs * log_ratio).sum(dim=-1) - - if loss_mask is not None: - # loss mask is the same on all ranks for TP over vocab. - valid = loss_mask.to(loss_terms.dtype) - loss_terms = loss_terms * valid - valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) - loss = loss_terms.sum() # sums over batch and seq. len. 
- - if group is not None: - all_reduce(loss, op=ReduceOp.SUM, group=group) - loss /= valid_tokens - - if grad_output is not None: - # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]) - # where E_q[log(q/p)] is the expected log ratio under the student distribution - expected = torch.sum(student_probs * log_ratio, dim=-1, keepdim=True) - if group is not None: - all_reduce(expected, op=ReduceOp.SUM, group=group) - log_ratio = log_ratio - expected - log_ratio = log_ratio * student_probs - del student_probs # Free after use - - if loss_mask is not None: - log_ratio = log_ratio * loss_mask.to(logits.dtype).unsqueeze(-1) - - log_ratio = log_ratio * (grad_output / valid_tokens) - grad = log_ratio.to(logits.dtype) - else: - grad = None - - return loss.detach_(), grad - - -def reverse_kl_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - group: ProcessGroup | None = None, - logits_scale_factor: float = 1.0, - teacher_softmax_temperature: float = 1.0, - target_format: TargetFormat = TargetFormat.labels, - sequence_parallel_logits: bool = False, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). - This is mode-seeking (vs. mode-covering for forward KL) and useful for: - - Encouraging the model to focus on the modes of the target distribution - - Avoiding probability mass on low-probability regions of the target - - Distillation scenarios where you want sharp, focused predictions - - Key differences from standard cross-entropy: - - Standard CE: KL(p||q) = mode-covering (spreads mass broadly) - - Reverse KL: KL(q||p) = mode-seeking (focuses on target modes) - - Takes: - logits: [BxS, V] or [B, S, V], where V is local vocab size - target: [BxS, V] or [B, S, V] (logits format) - loss_mask: [BxS] or [B, S] or None - ... - - Returns: - loss: Reverse KL divergence loss - grad: Gradients w.r.t. logits - """ - - if sequence_parallel_logits: - # TODO: see hybrid dev branch where it is implemented - raise NotImplementedError("Sequence-parallel reverse KL is not implemented yet, set vocab_parallel true") - - Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - - # TODO: implement fused? 
- distillation_loss, distillation_grad = _reverse_kl_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=grad_output, - logits_scale_factor=logits_scale_factor, - target_format=target_format, - teacher_softmax_temperature=teacher_softmax_temperature, - group=group, - ) - return distillation_loss, distillation_grad diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py new file mode 100644 index 000000000..0c0fe9fa3 --- /dev/null +++ b/fast_llm/functional/entropy_loss.py @@ -0,0 +1,345 @@ +import torch + +from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce +from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat +from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward +from fast_llm.utils import Assert + + +def _torch_entropy_loss_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + entropy_loss_type: EntropyLossType, + temperature: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + A wrapper for the pytorch implementation of cross-entropy. + The cross-entropy kernels themselves are well-optimized, but the need for explicit casting + and separate forward and backward kernels lead to poor performance. + TODO: loss masking only works for with labels format and if the masking index is set to -100. + """ + # Torch compile doesn't understand this. + with torch.set_grad_enabled(grad_output is not None): + logits_ = logits.float().detach().requires_grad_(grad_output is not None) + logits_scaled = logits_ if logits_scale_factor == 1.0 else logits_ * logits_scale_factor + if target_format == TargetFormat.logits: + target_scale = logits_scale_factor / temperature + target = target if target_scale == 1.0 else target * target_scale + else: + Assert.eq(temperature, 1.0) + + if entropy_loss_type == EntropyLossType.cross_entropy: + if target_format == TargetFormat.logits: + target = torch.softmax(target, dim=-1) + loss = torch.nn.functional.cross_entropy( + logits_scaled, target, reduction="mean" if loss_mask is None else "none" + ) + else: + predicted_log_probability = torch.nn.functional.log_softmax(logits_scaled, dim=-1) + if target_format == TargetFormat.logits: + target_log_probability = torch.nn.functional.log_softmax(target, dim=-1) + elif target_format == TargetFormat.probabilities: + target_log_probability = target.log() + else: + target_log_probability = ( + torch.nn.functional.one_hot(target, num_classes=logits_scaled.size(-1)).add(1.0e-10).log() + ) + if entropy_loss_type == EntropyLossType.forward_kl: + loss = torch.nn.functional.kl_div( + predicted_log_probability, + target_log_probability, + reduction="batchmean" if loss_mask is None else "none", + log_target=True, + ) + elif entropy_loss_type == EntropyLossType.reverse_kl: + loss = torch.nn.functional.kl_div( + target_log_probability, + predicted_log_probability, + reduction="batchmean" if loss_mask is None else "none", + log_target=True, + ) + else: + raise NotImplementedError(entropy_loss_type) + if loss_mask is not None: + loss = loss.sum(dim=-1) + + if loss_mask is not None: + loss = (loss * loss_mask).mean() + + if grad_output is None: + grad = None + else: + loss.backward(torch.full_like(loss, grad_output)) + grad = logits_.grad.detach().to(logits.dtype) + return loss.detach_(), grad + + +@torch.compile +def 
_fused_softmax_base( + logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + logits = logits.float() + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + logits_max = torch.max(logits, dim=dim, keepdim=True)[0] + if group is not None: + all_reduce(logits_max, op=ReduceOp.MAX, group=group) + logits_norm = (logits - logits_max).float() + exp_logits = logits_norm.exp() + sum_exp_logits = exp_logits.sum(dim=dim, keepdim=True) + if group is not None: + all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group) + return logits_norm, exp_logits, sum_exp_logits + + +def _fused_reverse_kl_base( + logits: torch.Tensor, + target: torch.Tensor, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + group: ProcessGroup | None = None, + temperature: float = 1.0, +): + logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) + predicted_log_probability = logits_norm - sum_exp_logits.log() + predicted_probability = exp_logits / sum_exp_logits + + if target_format == TargetFormat.logits: + target_logits_norm, _, sum_exp_target_logits = _fused_softmax_base( + target, logits_scale_factor / temperature, group + ) + target_log_probability = target_logits_norm - sum_exp_target_logits.log() + else: + target_log_probability = torch.log(target) + + # Compute loss terms: student_probs * log_ratio, then sum over vocab + # This is equivalent to kl_div(..., log_target=True) but more memory efficient + log_ratio = predicted_log_probability - target_log_probability + per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1) + if group is not None: + all_reduce(per_sample_loss, op=ReduceOp.SUM, group=group) + + if grad_output is None: + grad = None + else: + # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]) + # where E_q[log(q/p)] is the expected log ratio under the student distribution + grad = (log_ratio - per_sample_loss.unsqueeze(-1)) * predicted_probability * grad_output + + return per_sample_loss, grad + + +def _fused_cross_entropy_base( + logits: torch.Tensor, + target: torch.Tensor, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + group: ProcessGroup | None = None, + temperature: float = 1.0, + return_kl_loss: bool = False, +): + logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) + + if target_format == TargetFormat.logits: + target_logits_norm, exp_logits_targets, sum_exp_target_logits = _fused_softmax_base( + target, logits_scale_factor / temperature, group + ) + target = exp_logits_targets / sum_exp_target_logits + + # CE loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) + # KL loss = mean(log(sum_exp_logits) - sum(probabilities * (logits - log_probabilities)) + if return_kl_loss: + if target_format == TargetFormat.logits: + target_log_probability = target_logits_norm - sum_exp_target_logits.log() + else: + target_log_probability = torch.log(target) + logits_norm = logits_norm - target_log_probability + predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) + if group is not None: + # We need to sum the over the tensor-parallel group, + # but this is handled in the final averaging provided we multiply by the group size. 
+ all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) + + per_sample_loss = sum_exp_logits.log() - predicted_logits + + if grad_output is None: + grad = None + else: + # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. + grad = (exp_logits - sum_exp_logits * target) * (grad_output / sum_exp_logits) + + return per_sample_loss, grad + + +def _fused_cross_entropy_base_from_labels( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor, + grad_output: float | None, + logits_scale_factor: float, + group: ProcessGroup | None = None, +): + logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) + + target = target.unsqueeze(-1) + + if group is None: + # Keep values within range for scatter and gather ops to work. + target = target * loss_mask.unsqueeze(-1) + target_mask = None + else: + # Mask the target (fused) + # TODO: Could mask earlier on cpu or overlap with reduce? + vocab_start_index = logits.size(-1) * group.rank() + target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) + target = (target - vocab_start_index) * target_mask + + # CE loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) + # KL loss is the same because P * log(P) == 0. + predicted_logits = logits_norm.gather(1, target) + if group is not None: + predicted_logits = target_mask * predicted_logits + all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) + per_sample_loss = sum_exp_logits.log() - predicted_logits + + if grad_output is None: + grad = None + else: + # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. + grad = exp_logits.scatter_add( + 1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) + ) * (grad_output / sum_exp_logits) + + return per_sample_loss, grad + + +@torch.compile +def _fused_entropy_loss_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + entropy_loss_type: EntropyLossType, + group: ProcessGroup | None = None, + temperature: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + A fused implementation of cross-entropy with torch compile. + It is an improvement over the pytorch implementation because of the fused casting, both in speed and memory, + but still suboptimal because it needs multiple kernels. 
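+
+    For reference, writing q = softmax(logits) for the predicted (student) distribution and p for the
+    target distribution (softmax of the target logits, or a one-hot distribution in the labels case),
+    the per-sample losses computed below reduce to:
+        cross_entropy: -sum_i p_i * log(q_i)
+        forward_kl:     sum_i p_i * (log(p_i) - log(q_i))
+        reverse_kl:     sum_i q_i * (log(q_i) - log(p_i))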
+ """ + grad_output = None if grad_output is None else grad_output / logits.size(0) * logits_scale_factor + if target_format == TargetFormat.labels: + assert entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl) + if loss_mask is None: + loss_mask = target >= 0 + per_sample_loss, grad = _fused_cross_entropy_base_from_labels( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + group, + ) + elif entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl): + per_sample_loss, grad = _fused_cross_entropy_base( + logits, + target, + grad_output, + logits_scale_factor, + target_format, + group, + temperature, + return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl, + ) + elif entropy_loss_type == EntropyLossType.reverse_kl: + per_sample_loss, grad = _fused_reverse_kl_base( + logits, + target, + grad_output, + logits_scale_factor, + target_format, + group, + temperature, + ) + else: + raise NotImplementedError(entropy_loss_type) + + if loss_mask is not None: + per_sample_loss = per_sample_loss * loss_mask.unsqueeze(-1) + loss = per_sample_loss.mean() + + if grad is not None: + if loss_mask is not None: + grad = grad * loss_mask.unsqueeze(-1) + grad = grad.to(logits.dtype) + + return loss, grad + + +_CROSS_ENTROPY_IMPLEMENTATIONS = { + EntropyLossImplementation.torch: _torch_entropy_loss_forward_backward, + EntropyLossImplementation.fused: _fused_entropy_loss_forward_backward, + EntropyLossImplementation.triton: triton_cross_entropy_forward_backward, +} + + +def entropy_loss_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + group: ProcessGroup | None = None, + implementation: EntropyLossImplementation = EntropyLossImplementation.fused, + logits_scale_factor: float = 1.0, + temperature: float = 1.0, + target_format: TargetFormat = TargetFormat.labels, + entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Select the appropriate implementation of cross-entropy. + The triton implementation from the triton submodule is the fastest and recommended one. + It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, + which is faster and has a relatively small memory overhead. 
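+
+    A minimal usage sketch for the non-parallel labels case (shapes and values below are illustrative only):
+
+        import torch
+        from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat
+        from fast_llm.functional.entropy_loss import entropy_loss_forward_backward
+
+        logits = torch.randn(8, 1024)            # [tokens, vocab]
+        labels = torch.randint(0, 1024, (8,))    # int64 token ids; negative values would be masked out
+        loss, grad = entropy_loss_forward_backward(
+            logits,
+            labels,
+            loss_mask=None,
+            grad_output=1.0,
+            implementation=EntropyLossImplementation.fused,
+            target_format=TargetFormat.labels,
+            entropy_loss_type=EntropyLossType.cross_entropy,
+        )
+        # `loss` is a scalar; `grad` matches `logits` in shape and dtype and is already scaled by `grad_output`.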
+ """ + if target_format == TargetFormat.labels: + Assert.eq(target.shape, logits.shape[:-1]) + Assert.eq(target.dtype, torch.int64) + assert loss_mask is None + else: + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + if group: + Assert.eq(implementation, EntropyLossImplementation.fused) + return _fused_entropy_loss_forward_backward( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + target_format, + entropy_loss_type, + group, + temperature, + ) + else: + return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + target_format, + entropy_loss_type, + temperature=temperature, + ) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 295cdb74d..709d0c52d 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -1,7 +1,8 @@ import torch -from fast_llm.functional.config import TargetFormat, TritonConfig +from fast_llm.functional.config import EntropyLossType, TargetFormat, TritonConfig from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.utils import Assert @triton_jit() @@ -125,7 +126,8 @@ def triton_cross_entropy_forward_backward( grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, - teacher_softmax_temperature: float = 1.0, + entropy_loss_type: EntropyLossType, + temperature: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -134,6 +136,7 @@ def triton_cross_entropy_forward_backward( TODO: Better handling of `grad_output = None` """ assert TritonConfig.TRITON_ENABLED + Assert.eq(entropy_loss_type, EntropyLossType.cross_entropy) # TODO: Improve assumptions. 
assert logits.is_contiguous() assert target.is_contiguous() @@ -163,7 +166,7 @@ def triton_cross_entropy_forward_backward( assert loss_mask.is_contiguous() triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( logits, - target / teacher_softmax_temperature, + target / temperature, loss_mask, grad_logits, losses, diff --git a/fast_llm/layers/common/auxiliary_loss.py b/fast_llm/layers/common/auxiliary_loss.py index 44c2d2088..97e04de16 100644 --- a/fast_llm/layers/common/auxiliary_loss.py +++ b/fast_llm/layers/common/auxiliary_loss.py @@ -3,9 +3,9 @@ class AuxiliaryLoss(torch.autograd.Function): @staticmethod - def forward(ctx, scores: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa + def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa ctx.grad = torch.full_like(aux_loss, grad) - return scores + return input_ @staticmethod def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: # noqa @@ -13,13 +13,16 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: @torch.compile -def calculate_z_loss(logits: torch.Tensor, logits_scale_factor: float = 1.0) -> torch.Tensor: - if logits_scale_factor != 1.0: - logits *= logits_scale_factor - return torch.mean(torch.logsumexp(logits, dim=-1) ** 2) +def z_loss( + logits: torch.Tensor, logits_scale_factor: float = 1.0, loss_mask: "torch.Tensor | None" = None +) -> torch.Tensor: + out = torch.logsumexp(logits if logits_scale_factor == 1.0 else logits * logits_scale_factor, dim=-1) ** 2 + if loss_mask is not None: + out = out * loss_mask + return torch.mean(out) -def z_loss( +def auxiliary_z_loss( logits: torch.Tensor, z_loss_factor: float, training: bool, @@ -27,12 +30,41 @@ def z_loss( losses: dict | None = None, loss_name: str | None = None, logits_scale_factor: float = 1.0, + loss_mask: "torch.Tensor | None" = None, ) -> torch.Tensor: if losses is not None or (training and grad_scale is not None): - loss = calculate_z_loss(logits, logits_scale_factor=logits_scale_factor) + loss = z_loss(logits, logits_scale_factor, loss_mask) if losses is not None and loss_name is not None: losses[loss_name].append(loss.detach()) if training and grad_scale is not None: logits = AuxiliaryLoss.apply(logits, loss, z_loss_factor * grad_scale) return logits + + +def z_loss_forward_backward( + logits: torch.Tensor, + grad_output: float | None = None, + loss_mask: "torch.Tensor | None" = None, + logits_scale_factor: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Compute z-loss and its gradient. + + Z-loss = mean(logsumexp(logits, dim=-1) ** 2) + + Returns: + loss: The z-loss value (unscaled) + grad: The gradient w.r.t. 
logits (scaled by grad_scale), or None if grad_scale is None + """ + + with torch.set_grad_enabled(grad_output is not None): + logits_ = logits.detach().requires_grad_(grad_output is not None) + loss = z_loss(logits_, logits_scale_factor, loss_mask) + if grad_output is None: + grad = None + else: + loss.backward(torch.full_like(loss, grad_output)) + grad = logits_.grad.detach().to(logits.dtype) + + return loss, grad diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index f5abd1f6d..3ae47c0a7 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -14,7 +14,6 @@ from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig -from fast_llm.layers.language_model.head import _format_name from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -287,7 +286,7 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_definitions.append( LossDef( name=self._activation_distillation_loss_name, - formatted_name=_format_name(self._activation_distillation_loss_name), + formatted_name=self._activation_distillation_loss_name, count=count, ) ) diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 5cc351dac..fd3647389 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -13,7 +13,7 @@ from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss +from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, auxiliary_z_loss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType from fast_llm.layers.decoder.mlp.mlp import MLPBase @@ -102,7 +102,7 @@ def _forward( # Apply z_loss if applicable if self._config.z_loss_coefficient > 0.0: - logits = z_loss( + logits = auxiliary_z_loss( logits, self._config.z_loss_coefficient, self.training, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 53dac2892..b09a354b5 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,11 +5,15 @@ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig +from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.language_model.loss.config import ( + LanguageModelLabelEntropyLossConfig, + LanguageModelLossConfig, + LanguageModelLossKwargs, +) from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -19,21 +23,21 @@ from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction -class 
LanguageModelKwargs(BlockKwargs): +class LanguageModelKwargs(LanguageModelLossKwargs): token_ids = "token_ids" position_ids = "position_ids" token_map = "token_map" sample_map = "sample_map" embedding_map = "embedding_map" # TODO: These are generic - labels = "labels" phase = "phase" - chosen_spans = "chosen_spans" - rejected_spans = "rejected_spans" loss_mask = "loss_mask" mask_inputs = "mask_inputs" +LM_HEAD_LOSS_NAME = "lm_head_loss" + + @config_class() class LanguageModelEmbeddingsConfig(BlockConfig): _abstract = False @@ -135,44 +139,22 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Configuration for the final normalization layer.", hint=FieldHint.architecture, ) + losses: dict[str, LanguageModelLossConfig] = Field( + default_factory=dict, + desc="A dictionary of loss names and their configurations.", + hint=FieldHint.core, + ) # TODO: Cleanup output_weight: ParameterConfig = Field( desc="Configuration for the LM output layer (weight). Ignored for tied embeddings", hint=FieldHint.architecture, ) - cross_entropy_implementation: CrossEntropyImpl = Field( - default=CrossEntropyImpl.auto, - desc="Implementation for the cross-entropy computation.", - hint=FieldHint.performance, - ) - distillation_loss_implementation: DistillationLossImpl = Field( - default=DistillationLossImpl.cross_entropy, - desc="Implementation for the distillation cross-entropy computation.", - hint=FieldHint.performance, - ) - cross_entropy_splits: int | None = Field( - default=None, + cross_entropy_splits: int = Field( + default=1, desc="Split the logit and cross-entropy computation into this many fragment, to reduce memory usage.", hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - logit_z_loss: float = Field( - default=0.0, - desc="Regularize the logits with Z-loss.", - doc="We recommend 1e-4 for stability, as used for training PaLM.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - language_model_loss_factor: float = Field( - default=None, - desc="Factor to scale the language modeling loss by when using distillation.", - hint=FieldHint.feature, - ) - distillation_loss_factor: float = Field( - default=1.0, - desc="Factor to scale the distillation loss by when using distillation.", - hint=FieldHint.feature, - ) logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -181,29 +163,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - teacher_softmax_temperature: float = Field( - default=1.0, - desc="Divides distillation target logits by this factor.", - doc="Divides distillation target logits by this factor.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - dpo_reference_model: str | None = Field( - default=None, - desc="Name of the reference model to use for dpo.", - hint=FieldHint.feature, - ) - dpo_beta: float | None = Field( - default=1.0, - desc="Beta value for DPO loss.", - hint=FieldHint.feature, - ) - distillation_model: str | None = Field( - default=None, - desc="Name of the reference model to use for knowledge distillation." 
- "If provided, replace the loss with a distillation loss.", - hint=FieldHint.feature, - ) def get_layer( self, @@ -237,22 +196,16 @@ def layer_class(self) -> "type[LanguageModelHead]": def _validate(self) -> None: with self._set_implicit_default(): - if self.language_model_loss_factor is None: - if self.distillation_model is None: - self.language_model_loss_factor = 1.0 - else: - self.language_model_loss_factor = 0.0 + if not self.losses: + if "losses" not in self._explicit_fields: + self.losses = {"lm_loss": LanguageModelLabelEntropyLossConfig()} super()._validate() - assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both + assert LM_HEAD_LOSS_NAME not in self.losses @property def max_prediction_distance(self) -> int: return 1 - @property - def enable_dpo(self) -> bool: - return self.dpo_reference_model is not None - @config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"}) class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig): diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 9f3b6506f..736d8faf0 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -1,5 +1,4 @@ import abc -import functools import logging import typing @@ -9,26 +8,25 @@ from fast_llm.core.ops import gather_op, split_op from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward -from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockDimNames -from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss +from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( + LM_HEAD_LOSS_NAME, LanguageModelEmbeddingsConfig, LanguageModelHeadBaseConfig, LanguageModelHeadConfig, LanguageModelKwargs, ) from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert, div, get_unique +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -69,11 +67,6 @@ def __init__( lr_scale=lr_scale, peft=peft, ) - if prediction_distance > 0 and ( - self._config.distillation_model is not None or self._config.dpo_reference_model is not None - ): - raise NotImplementedError("Multi-token prediction not supported with distillation or dpo.") - Assert.in_range(prediction_distance, 0, prediction_heads) self._prediction_distance = prediction_distance self._prediction_heads = prediction_heads @@ -84,19 +77,9 @@ def __init__( self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) self._sequence_parallel_logits = self._sequence_parallel and not self._vocab_parallel - if self._config.cross_entropy_splits is not None 
and self._sequence_parallel: + if self._config.cross_entropy_splits > 1 and self._sequence_parallel: assert not self._vocab_parallel - if not self._config.enable_dpo: - self._cross_entropy_impl = self._config.cross_entropy_implementation - if self._cross_entropy_impl == CrossEntropyImpl.auto: - if self._vocab_parallel: - self._cross_entropy_impl = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED and torch.cuda.is_available(): - self._cross_entropy_impl = CrossEntropyImpl.triton - else: - self._cross_entropy_impl = CrossEntropyImpl.fused - self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) self.final_norm = self._config.normalization.get_layer( @@ -113,6 +96,18 @@ def __init__( peft=self._peft, ) + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # TODO: Add marginal compute? (loss) + return ( + 2 + * (config.forward + 2 * config.backward) + * (input_.global_shape if config.global_ else input_).numel() + * (self._vocab_dim.global_size if config.global_ else self._vocab_dim.size) + ) + + def get_output_weights(self) -> list[torch.Tensor]: + return [self.output_weights] + def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: @@ -137,8 +132,6 @@ def forward( # TODO: Drop autograd entirely. # TODO: Skip cross-entropy backward if not needed. language_model_loss = self._forward(input_, kwargs, losses) - if losses is not None and language_model_loss is not None: - losses[self._loss_name].append(language_model_loss.detach()) # TODO: Return the model output when needed. if self._is_last_head: # Last head should return the loss for backward. @@ -150,186 +143,114 @@ def forward( # MTP: Return shared_hidden to be used by the next head. return shared_hidden - def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - # TODO: Add marginal compute? (loss) - return ( - 2 - * (config.forward + 2 * config.backward) - * (input_.global_shape if config.global_ else input_).numel() - * (self._vocab_dim.global_size if config.global_ else self._vocab_dim.size) - ) - def _forward_backward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor | None]: - targets = self._get_targets(kwargs) - input_ = input_.detach().requires_grad_(do_grad := targets is not None and self.training) + input_ = input_.detach().requires_grad_(self.training) with torch.enable_grad(): ln_output = self.final_norm(input_) - # Transormers expect normalized outputs for the last transformer layer, + # Transformers expect normalized outputs for the last transformer layer, # so we add the norm output to the hidden states. 
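+            # The loss and the logits' gradient are computed below on the detached, flattened norm output;
+            # the gradient is then pushed back through the norm with an explicit backward call, so the
+            # logits/loss path itself runs without autograd.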
self._debug(ln_output, "final_norm", kwargs.get(LanguageModelKwargs.hidden_dims), kwargs) - - grad_output = kwargs[LanguageModelKwargs.grad_output] / ( - self._parallel_dim.size if self._sequence_parallel_logits else 1 - ) - - output_weights = self.output_weights - loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( - ln_output.detach(), targets, output_weights, grad_output, kwargs, losses - ) - - if do_grad: - ln_output.backward(ln_output_grad) - return loss, input_.grad - else: + loss, ln_output_grad = self._logits_loss_forward_backward(ln_output.detach().flatten(0, -2), kwargs, losses) + if ln_output_grad is None: return loss, None - - def _get_targets( - self, kwargs: dict - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None: - # Loss mask for distillation. (Labels are already masked.) - if self._config.enable_dpo: - dpo_target = kwargs.get(LanguageModelKwargs.labels) - lm_target = None - distillation_target = None - loss_mask = None else: - dpo_target = None - if self._config.distillation_model is None: - distillation_target, loss_mask = None, None - else: - # Target is reference model logits. - distillation_target = kwargs[f"{self._config.distillation_model}_logits"].flatten(0, -2) - loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) - if loss_mask is not None: - loss_mask = loss_mask.flatten() - - if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: - lm_target = kwargs.get(LanguageModelKwargs.labels) - if lm_target is not None: - # MTP: Shift the labels - lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads - ) - if LanguageModelKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) - lm_target_slice = slice( - self._prediction_distance, self._prediction_distance + lm_target_sequence_length - ) - lm_target = ( - lm_target[lm_target_slice] - if kwargs[LanguageModelKwargs.sequence_first] - else lm_target[:, lm_target_slice] - ).flatten() - else: - lm_target = None - - targets = (dpo_target, lm_target, distillation_target, loss_mask) - if self._sequence_parallel_logits: - targets = [None if target is None else split_op(target, self._parallel_dim.group, 0) for target in targets] - if not any(target is not None for target in targets): - # Simplify so we don't have to check every time. 
- targets = None - return targets - - def get_output_weights(self) -> list[torch.Tensor]: - return [self.output_weights] + ln_output.backward(ln_output_grad.view_as(ln_output)) + return loss, input_.grad - def _logits_cross_entropy_forward_backward_split( + def _logits_loss_forward_backward( self, input_: torch.Tensor, - targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None, - weight: torch.Tensor, - grad_output: float, kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if self._config.cross_entropy_splits is None or targets is None: - loss, logit_input_grad = self._logits_cross_entropy_forward_backward( - input_, targets, weight, grad_output, kwargs, losses + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) + if loss_mask is not None: + loss_mask = loss_mask.flatten() + if self._sequence_parallel_logits: + loss_mask = split_op(loss_mask, self._parallel_dim.group, 0) + + if not self.training: + logits, _ = self._logits_loss_forward_backward_partial(input_, loss_mask, kwargs, return_logits=True) + # TODO: Make a proper way of returning the model output. + logits = logits.detach() + if kwargs.get("global_logits"): + if self._vocab_parallel: + logits = gather_op(logits, self._parallel_dim.group, 2) + elif self._sequence_parallel_logits: + logits = gather_op( + logits, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 + ) + kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = ( + logits.detach() ) - if targets is None: - # TODO: Make a proper way of returning the model output. - loss = loss.detach() - if kwargs.get("global_logits"): - if self._vocab_parallel: - loss = gather_op(loss, self._parallel_dim.group, 2) - elif self._sequence_parallel_logits: - loss = gather_op( - loss, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 - ) - kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss - return None, None + return None, None + elif self._config.cross_entropy_splits == 1: + losses_, input_grad = self._logits_loss_forward_backward_partial(input_, loss_mask, kwargs) else: - loss = None - # TODO MTP: allow a cross_entropy_splits that is not a divisor of the sequence length - grad_output /= self._config.cross_entropy_splits - logit_input = input_.flatten(0, -2) - if self.training: - logit_input_grad = torch.empty_like(logit_input) - else: - logit_input_grad = None - split_size = div( - get_unique(target.size(0) for target in targets if target is not None), - self._config.cross_entropy_splits, - ) + input_grad = torch.empty_like(input_) tensors_split = [ - [None] * self._config.cross_entropy_splits if tensor is None else tensor.split(split_size) - for tensor in [logit_input, *targets, logit_input_grad] + ( + [None] * self._config.cross_entropy_splits + if tensor is None + else tensor.chunk(self._config.cross_entropy_splits) + ) + for tensor in [input_, loss_mask, input_grad] ] - for logit_input_, *targets_, logit_input_grad_ in zip(*tensors_split, strict=True): - loss_, grad_ = self._logits_cross_entropy_forward_backward( - logit_input_, - targets_, - weight, - grad_output, + for split_index, (partial_input_, loss_mask_, input_grad_) in enumerate(zip(*tensors_split, strict=True)): + partial_losses_, grad_ = self._logits_loss_forward_backward_partial( + partial_input_, + loss_mask_, kwargs, + split_index=split_index, ) # TODO: Avoid copy with 
explicit out argument. - if self.training: - logit_input_grad_.copy_(grad_) - loss = loss_ if loss is None else loss + loss_ - del grad_, loss_ - loss_count = (self._config.cross_entropy_splits or 1) * ( - self._parallel_dim.size if self._sequence_parallel_logits else 1 + input_grad_.copy_(grad_) + if split_index == 0: + losses_ = partial_losses_ + else: + for name in self._config.losses: + losses_[name] += partial_losses_[name] + + loss: torch.Tensor = sum( + (self.config.losses[name].weight * self._loss_coefficient / self._config.cross_entropy_splits) * loss_ + for name, loss_ in losses_.items() ) - if loss_count != 1: - loss.div_(loss_count) if self._sequence_parallel_logits: # TODO: Async - all_reduce(loss, group=self._parallel_dim.group) - return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None - - def _logits_cross_entropy_forward_backward( + all_reduce(loss, op=ReduceOp.AVG, group=self._parallel_dim.group) + + if losses is not None: + losses[self.get_full_loss_name(LM_HEAD_LOSS_NAME)].append(loss) + if len(self._config.losses) > 1: + for name, loss_ in losses_.items(): + if self._config.cross_entropy_splits != 1: + loss_ /= self._config.cross_entropy_splits + if self._sequence_parallel_logits: + # TODO: Async + all_reduce(loss_, op=ReduceOp.AVG, group=self._parallel_dim.group) + losses[name].append(loss_) + + return loss, input_grad + + def _logits_loss_forward_backward_partial( self, input_: torch.Tensor, - targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None], - weight: torch.Tensor, - grad_output: float, + loss_mask: torch.Tensor | None, kwargs: dict, - losses: dict | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + split_index: int = 0, + return_logits: bool = False, + ) -> tuple[dict[str, torch.Tensor] | torch.Tensor, torch.Tensor | None]: group = self._parallel_dim.group if self._vocab_parallel else None logits, context = output_parallel_linear_forward( input_=input_, - weight=weight, + weight=self.output_weights, bias=None, group=group, sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) - if self._config.logit_z_loss > 0.0: - logits = z_loss( - logits, - self._config.logit_z_loss, - self.training, - grad_output, - losses, - self._z_loss_name, - logits_scale_factor=self._config.logits_scale_factor, - ) - sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else BlockDimNames.sequence_q if LanguageModelKwargs.hidden_dims in kwargs: batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] @@ -342,171 +263,62 @@ def _logits_cross_entropy_forward_backward( dims = None self._debug(logits, "logits", dims, kwargs, scale=self._config.logits_scale_factor) - if targets is None: - return logits * self._config.logits_scale_factor, None - dpo_target, lm_target, distillation_target, loss_mask = targets - - if dpo_target is not None: - dpo_loss, dpo_grad = compute_dpo_loss( + if return_logits: + return logits, None + + losses, grad = {}, None + for loss_name, loss_config in self._config.losses.items(): + # losses are returned unscaled but the grads are already scaled + # TODO: ====== grad_output can't be None? 
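+            # The scaling below folds the configured per-loss weight and the head's loss coefficient into
+            # grad_output, divides by the tensor-parallel size when logits are sequence-parallel (each rank
+            # only holds a shard of the tokens), and divides by cross_entropy_splits since each split only
+            # covers a fraction of the batch.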
+ grad_output = kwargs.get(LanguageModelKwargs.grad_output) + if grad_output is not None: + grad_output = ( + grad_output + * self._loss_coefficient + * loss_config.weight + / (self._parallel_dim.size if self._sequence_parallel_logits else 1) + / self._config.cross_entropy_splits + ) + loss, grad_ = loss_config.get_loss( logits, - dpo_target, - kwargs.get(f"{self._config.dpo_reference_model}_logits"), - kwargs[LanguageModelKwargs.chosen_spans], - kwargs[LanguageModelKwargs.rejected_spans], - self._config.dpo_beta, - grad_output * self._loss_coefficient, - ) - else: - dpo_loss, dpo_grad = None, None - - if lm_target is not None: - lm_loss, lm_grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - lm_target, - None, + loss_mask, + grad_output=None if grad_output == 0.0 else grad_output, group=group, - grad_output=grad_output * self._loss_coefficient * self._config.language_model_loss_factor, - implementation=self._cross_entropy_impl, logits_scale_factor=self._config.logits_scale_factor, - target_format=TargetFormat.labels, + kwargs=kwargs, + prediction_distance=self._prediction_distance, + prediction_heads=self._prediction_heads, + split_index=split_index, + num_splits=self._config.cross_entropy_splits, + sequence_parallel_logits=self._sequence_parallel_logits, ) - lm_loss = lm_loss * self._config.language_model_loss_factor - else: - lm_loss, lm_grad = None, None - - if distillation_target is not None and self._config.distillation_loss_factor > 0.0: - if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: - distillation_loss, distillation_grad = reverse_kl_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, - group=group, - logits_scale_factor=self._config.logits_scale_factor, - teacher_softmax_temperature=self._config.teacher_softmax_temperature, - target_format=( - TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits - ), - sequence_parallel_logits=self._sequence_parallel_logits, - ) + losses[loss_name] = loss.detach() + if grad_ is not None: + # TODO: Accumulate grads in-place to reduce memory and compute overhead. + grad = grad_ if grad is None else grad + grad_ - elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: - distillation_loss, distillation_grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - group=group, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, - implementation=self._cross_entropy_impl, - logits_scale_factor=self._config.logits_scale_factor, - target_format=TargetFormat.logits, - ) - else: - raise ValueError( - f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" - ) - distillation_loss = distillation_loss * self._config.distillation_loss_factor - else: - distillation_loss, distillation_grad = None, None - - # TODO: de-allocate earlier. - del logits - - # TODO: Accumulate grads in-place to reduce memory and compute overhead. - grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) - - # TODO: Return individual losses? 
-        loss = _add_tensors(dpo_loss, lm_loss, distillation_loss)
-        if self.training and losses is not None:
-            if dpo_loss is not None:
-                losses[self._dpo_loss_name].append(dpo_loss.detach())
-            if self._config.distillation_model is not None and distillation_loss is not None:
-                losses[self._distillation_loss_name].append(distillation_loss.detach())
-            if self._config.distillation_model is not None and lm_loss is not None:
-                losses[self._distillation_language_model_loss_name].append(lm_loss.detach())
-
-        return loss, output_parallel_linear_backward(grad, context) if self.training else None
-
-    @functools.cached_property
-    def _loss_name(self) -> str:
-        name = "language_model_loss"
-        if self._prediction_distance > 0:
-            name = f"{name}_{self._prediction_distance}"
-        return name
-
-    @functools.cached_property
-    def _z_loss_name(self) -> str:
-        name = "z_loss"
-        if self._prediction_distance > 0:
-            name = f"{name}_{self._prediction_distance}"
-        return name
-
-    @functools.cached_property
-    def _dpo_loss_name(self) -> str:
-        name = "dpo_loss"
-        if self._prediction_distance > 0:
-            name = f"{name}_{self._prediction_distance}"
-        return name
-
-    @functools.cached_property
-    def _distillation_language_model_loss_name(self) -> str:
-        name = "distillation_language_model_loss"
-        if self._prediction_distance > 0:
-            name = f"{name}_{self._prediction_distance}"
-        return name
-
-    @functools.cached_property
-    def _distillation_loss_name(self) -> str:
-        name = "distillation_loss"
-        if self._prediction_distance > 0:
-            name = f"{name}_{self._prediction_distance}"
-        return name
+        return losses, output_parallel_linear_backward(grad, context) if self.training else None

     def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
-        loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)]
-        if self._config.logit_z_loss:
-            loss_defs.append(
-                LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count)
-            )
-        if self._config.enable_dpo:
-            loss_defs.append(
-                LossDef(name=self._dpo_loss_name, formatted_name=_format_name(self._dpo_loss_name), count=count)
-            )
-
-        if self._config.distillation_model is not None:
-            loss_defs.append(
+        return [
+            LossDef(name=(name := self.get_full_loss_name(LM_HEAD_LOSS_NAME)), formatted_name=name, count=count),
+            *(
                 LossDef(
-                    name=self._distillation_loss_name,
-                    formatted_name=_format_name(self._distillation_loss_name),
+                    name=(name_ := self.get_full_loss_name(name)),
+                    formatted_name=name_,
                     count=count,
+                    dtype=DataType.float32,
                 )
-            )
-            if self._config.language_model_loss_factor > 0.0:
-                loss_defs.append(
-                    LossDef(
-                        name=self._distillation_language_model_loss_name,
-                        formatted_name=_format_name(self._distillation_language_model_loss_name),
-                        count=count,
-                    )
-                )
+                for name, loss_config in self._config.losses.items()
+            ),
+        ]
-        return loss_defs
+    def get_full_loss_name(self, name) -> str:
+        if self._prediction_distance > 0:
+            name = f"{name}_{self._prediction_distance}"
+        return name

     @property
     def heads(self):
         # For compatibility with MTP.
return [self] - - -def _format_name(name: str) -> str: - return name.replace("_", " ") - - -def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: - tensors = [tensor for tensor in tensors if tensor is not None] - if len(tensors) > 1: - return sum(tensors) - elif len(tensors) == 1: - return tensors[0] - else: - raise RuntimeError() diff --git a/fast_llm/layers/language_model/loss/__init__.py b/fast_llm/layers/language_model/loss/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py new file mode 100644 index 000000000..551554132 --- /dev/null +++ b/fast_llm/layers/language_model/loss/config.py @@ -0,0 +1,306 @@ +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + import torch + + +class LanguageModelLossKwargs(BlockKwargs): + labels = "labels" + chosen_spans = "chosen_spans" + rejected_spans = "rejected_spans" + + +@config_class(registry=True) +class LanguageModelLossConfig(Config): + _abstract: typing.ClassVar[bool] = True + + weight: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Weight for this loss in the total loss computation.", + valid=check_field(Assert.geq, 0.0), + ) + + def get_name(self, prediction_distance: int = 0) -> str: + return self._name if prediction_distance == 0 else f"{self._name}_{prediction_distance}" + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + *, + group: "torch.distributed.ProcessGroup|None" = None, + logits_scale_factor: float = 1.0, + prediction_distance: int = 0, + prediction_heads: int = 1, + split_index: int = 0, + num_splits: int = 1, + sequence_parallel_logits: bool = False, + kwargs: dict[str, typing.Any], + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + raise NotImplementedError() + + +@config_class(dynamic_type={LanguageModelLossConfig: "label"}) +class LanguageModelLabelEntropyLossConfig(LanguageModelLossConfig): + _abstract: typing.ClassVar[bool] = False + + loss_type: EntropyLossType = Field( + default=EntropyLossType.cross_entropy, + desc="Type of loss to use.", + hint=FieldHint.core, + ) + + implementation: EntropyLossImplementation = Field( + default=EntropyLossImplementation.auto, + desc="Loss implementation.", + hint=FieldHint.performance, + ) + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + *, + group: "torch.distributed.ProcessGroup|None" = None, + logits_scale_factor: float = 1.0, + prediction_distance: int = 0, + prediction_heads: int = 1, + split_index: int = 0, + num_splits: int = 1, + sequence_parallel_logits: bool = False, + kwargs: dict[str, typing.Any], + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + import torch + + from fast_llm.functional.entropy_loss import entropy_loss_forward_backward + + labels = kwargs[LanguageModelLossKwargs.labels] + + # MTP: Shift the labels + if prediction_heads > 1: + sequence_q_length = labels.size(1 - kwargs[LanguageModelLossKwargs.sequence_first]) + 1 - prediction_heads + if LanguageModelLossKwargs.sequence_q_dim in kwargs: + Assert.eq(sequence_q_length, kwargs[LanguageModelLossKwargs.sequence_q_dim].size) + label_slice = 
slice(prediction_distance, prediction_distance + sequence_q_length) + labels = labels[label_slice] if kwargs[LanguageModelLossKwargs.sequence_first] else labels[:, label_slice] + + labels = labels.flatten() + + # Get the local chunk. + if sequence_parallel_logits: + from fast_llm.core.ops import split_op + + labels = split_op(labels, group, 0) + + # Get the chunk for the current split. + if num_splits > 1: + labels = labels.chunk(num_splits)[split_index] + + implementation = self.implementation + if implementation == EntropyLossImplementation.auto: + if ( + TritonConfig.TRITON_ENABLED + and torch.cuda.is_available() + and group is None + and self.loss_type == EntropyLossType.cross_entropy + ): + implementation = EntropyLossImplementation.triton + else: + implementation = EntropyLossImplementation.fused + + return entropy_loss_forward_backward( + logits, + labels, + None, # Labels are already masked + grad_output=grad_output, + group=group, + implementation=implementation, + logits_scale_factor=logits_scale_factor, + target_format=TargetFormat.labels, + entropy_loss_type=self.loss_type, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "distillation"}) +class LanguageModelDistillationLossConfig(LanguageModelLossConfig): + _abstract: typing.ClassVar[bool] = False + + loss_type: EntropyLossType = Field( + default=EntropyLossType.cross_entropy, + desc="Type of loss to use.", + hint=FieldHint.core, + ) + implementation: EntropyLossImplementation = Field( + default=EntropyLossImplementation.auto, + desc="Loss implementation.", + hint=FieldHint.performance, + ) + reference_model: str = Field( + desc="Name of the reference model for knowledge distillation.", + hint=FieldHint.feature, + ) + temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax.", + valid=check_field(Assert.gt, 0.0), + ) + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + *, + group: "torch.distributed.ProcessGroup|None" = None, + logits_scale_factor: float = 1.0, + prediction_distance: int = 0, + prediction_heads: int = 1, + split_index: int = 0, + num_splits: int = 1, + sequence_parallel_logits: bool = False, + kwargs: dict[str, typing.Any], + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.entropy_loss import entropy_loss_forward_backward + + if prediction_distance > 0: + raise NotImplementedError() + + reference_model_logits = kwargs[f"{self.reference_model}_logits"].flatten(0, -2) + + # Get the local chunk. + if sequence_parallel_logits: + from fast_llm.core.ops import split_op + + reference_model_logits = split_op(reference_model_logits, group, 0) + + # Get the chunk for the current split. 
+ if num_splits > 1: + reference_model_logits = reference_model_logits.chunk(num_splits)[split_index] + + implementation = ( + EntropyLossImplementation.fused + if self.implementation == EntropyLossImplementation.auto + else self.implementation + ) + return entropy_loss_forward_backward( + logits, + reference_model_logits, + loss_mask, + grad_output=grad_output, + group=group, + implementation=implementation, + logits_scale_factor=logits_scale_factor, + temperature=self.temperature, + target_format=TargetFormat.logits, + entropy_loss_type=self.loss_type, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "dpo"}) +class LanguageModelDPOLossConfig(LanguageModelLossConfig): + """Direct Preference Optimization (DPO) loss for alignment.""" + + _abstract: typing.ClassVar[bool] = False + + beta: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Beta parameter for DPO loss (controls strength of preference optimization).", + valid=check_field(Assert.gt, 0.0), + ) + + reference_model: str = Field( + desc="Name of the reference model to use for dpo.", + hint=FieldHint.feature, + ) + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + *, + group: "torch.distributed.ProcessGroup|None" = None, + logits_scale_factor: float = 1.0, + prediction_distance: int = 0, + prediction_heads: int = 1, + split_index: int = 0, + num_splits: int = 1, + sequence_parallel_logits: bool = False, + kwargs: dict[str, typing.Any], + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.dpo import compute_dpo_loss + + if num_splits > 1: + raise NotImplementedError() + if prediction_distance > 0: + raise NotImplementedError() + + if logits_scale_factor != 1.0: + # TODO: Make more efficient. 
+ logits = logits * logits_scale_factor + + reference_model_logits = kwargs[f"{self.reference_model}_logits"].flatten(0, -2) + target = kwargs[LanguageModelLossKwargs.labels] + + if sequence_parallel_logits: + from fast_llm.core.ops import split_op + + reference_model_logits = split_op(reference_model_logits, group, 0) + target = split_op(target, group, 0) + + chosen_spans = kwargs[LanguageModelLossKwargs.chosen_spans] + rejected_spans = kwargs[LanguageModelLossKwargs.rejected_spans] + + return compute_dpo_loss( + logits=logits, + targets=target, + reference_model_logits=reference_model_logits, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, + beta=self.beta, + grad_output=grad_output, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "z_loss"}) +class LanguageModelZLossConfig(LanguageModelLossConfig): + """Z-loss regularization to prevent overconfidence.""" + + _abstract: typing.ClassVar[bool] = False + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + *, + group: "torch.distributed.ProcessGroup|None" = None, + logits_scale_factor: float = 1.0, + prediction_distance: int = 0, + prediction_heads: int = 1, + split_index: int = 0, + num_splits: int = 1, + sequence_parallel_logits: bool = False, + kwargs: dict[str, typing.Any], + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.layers.common.auxiliary_loss import z_loss_forward_backward + + # TODO: Support vocab_parallel + assert group is None + + return z_loss_forward_backward( + logits, + grad_output, + loss_mask, + logits_scale_factor, + ) diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py deleted file mode 100644 index 420316ce3..000000000 --- a/tests/functional/test_cross_entropy.py +++ /dev/null @@ -1,211 +0,0 @@ -import os -import sys -import tempfile -import traceback -import typing - -import pytest -import torch - -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward -from fast_llm.utils import Assert - - -def _get_cross_entropy_inputs( - num_columns: int, loss_masking: bool, target_format: TargetFormat -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - device = "cuda" if torch.cuda.is_available() else "cpu" - # We want something moderately close to the target for the test to be meaningful - logits_var = torch.randn(256, num_columns, dtype=torch.bfloat16, device=device) / 3 - loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device=device) if loss_masking else None - if target_format == TargetFormat.labels: - target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device=device) - logits = torch.nn.functional.one_hot(target, num_columns) + logits_var - if loss_masking: - logits = torch.where(loss_mask.unsqueeze(-1), logits, -100) - loss_mask = None - else: - target = torch.randn(256, num_columns, dtype=torch.bfloat16, device=device) - logits = target + logits_var - if target_format == TargetFormat.probabilities: - target = torch.softmax(target, -1) - return logits, target, loss_mask - - -def _compare_cross_entropy_outputs( - loss: torch.Tensor, - ref_loss: torch.Tensor, - has_grad: bool, - grad: torch.Tensor | None, - ref_grad: torch.Tensor | None, - threshold=1e-5, -): - Assert.rms_close_relative(loss, ref_loss, threshold, 1e-6) - if has_grad: - Assert.rms_close_relative(grad, ref_grad, threshold, 1e-8) - else: - 
assert grad is None - assert ref_grad is None - - -@pytest.mark.slow -@pytest.mark.parametrize( - ("num_columns", "grad_output", "logits_scale_factor", "loss_masking"), - ( - (8192, 1.0, 1.0, False), # Simple - (5000, 1.0, 1.0, False), # Not a power of 2 - (5000, None, 1.0, False), # No grad - (5000, 1.0, 4.0, False), # Loss scaling - (5000, 4.0, 1.0, False), # Grad scaling - (5000, 1.0, 1.0, True), # Loss masking - (65536, 1.0, 1.0, False), # Max block size - (65537, 1.0, 1.0, False), # Above max block size - ), -) -@pytest.mark.parametrize("target_format", (TargetFormat.labels, TargetFormat.logits, TargetFormat.probabilities)) -def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_masking, target_format): - # TODO: Test tensor-parallel implementation. - logits, target, loss_mask = _get_cross_entropy_inputs(num_columns, loss_masking, target_format) - kwargs = { - "logits": logits, - "target": target, - "loss_mask": loss_mask, - "grad_output": grad_output, - "logits_scale_factor": logits_scale_factor, - "target_format": target_format, - } - # Torch serves as the reference implementation. - out_torch, grad_torch = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.torch) - out_fused, grad_fused = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.fused) - - # TODO: Why is the error so high with logit scaling? - threshold = 2e-5 if logits_scale_factor == 1.0 else 1e-2 - _compare_cross_entropy_outputs(out_fused, out_torch, grad_output is not None, grad_fused, grad_torch, threshold) - - if not torch.cuda.is_available(): - return - assert TritonConfig.TRITON_ENABLED - if num_columns > 65536: - with pytest.raises(AssertionError): - cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) - else: - out_triton, grad_triton = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) - _compare_cross_entropy_outputs( - out_triton, out_torch, grad_output is not None, grad_triton, grad_torch, threshold - ) - - -def _reverse_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None): - # Manual reference: sum over vocab then average over valid tokens. - logits = logits.detach().requires_grad_() - per_sample = torch.nn.functional.kl_div( - torch.log_softmax(target.float(), dim=-1), - torch.log_softmax(logits.float(), dim=-1), - reduction="none", - log_target=True, - ).sum(dim=-1) - if loss_mask is not None: - per_sample = per_sample * loss_mask - output = per_sample.mean() - output.backward() - return output, logits.grad - - -@pytest.mark.slow -# TODO: Support the same parameterization as above in the reference implementation. 
-@pytest.mark.parametrize("loss_masking", [False, True]) -@pytest.mark.parametrize("target_format", (TargetFormat.logits,)) -def test_reverse_kl(loss_masking, target_format): - logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) - out_ref, grad_ref = _reverse_kl_forward_backward_torch(logits, target, loss_mask) - out, grad = reverse_kl_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=1.0, - target_format=TargetFormat.logits, - ) - _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) - - -def _mp_worker(rank: int, world_size: int, init_method: str, fn_args: tuple): - try: - torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method) - fn_args[0](rank, torch.distributed.group.WORLD, *fn_args[1:]) - finally: - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() - - -def _spawn_dist(world_size: int, *fn_args): - """ - Run `fn(rank, group, *fn_args)` across `world_size` ranks using torch.multiprocessing. - """ - with tempfile.NamedTemporaryFile(delete=False) as tmp: - init_method = f"file://{tmp.name}" - - try: - torch.multiprocessing.spawn( - _mp_worker, - args=(world_size, init_method, fn_args), - nprocs=world_size, - join=True, - start_method="spawn", - ) - finally: - if os.path.exists(tmp.name): - os.remove(tmp.name) - - -def _compare_parallel_cross_entropy( - rank: int, - group: torch.distributed.ProcessGroup, - target_format: TargetFormat, - function: typing.Callable, - loss_masking: bool, -): - # Ensure all workers have the same inputs. - torch.manual_seed(0) - world_size = torch.distributed.get_world_size(group) - logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) - - out, grad = function( - logits=logits.chunk(world_size, 1)[rank], - target=target.chunk(world_size, 1)[rank], - loss_mask=loss_mask, - grad_output=1, - group=group, - target_format=target_format, - ) - - out_ref, grad_ref = function( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=1, - target_format=target_format, - ) - _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4) - - -def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGroup): - success = True - for function in (reverse_kl_forward_backward, cross_entropy_forward_backward): - for target_format in (TargetFormat.logits,): - for loss_masking in [False, True]: - try: - _compare_parallel_cross_entropy(rank, group, target_format, function, loss_masking) - except Exception: - print( - f" >>>>>> Failed {function.__name__}, target_format, use_mask={loss_masking}", file=sys.stderr - ) - traceback.print_exc() - success = False - if not success: - raise RuntimeError("Test failed") - - -@pytest.mark.slow -def test_distillation_losses(): - _spawn_dist(2, compare_parallel_cross_entropy) diff --git a/tests/functional/test_entropy_loss.py b/tests/functional/test_entropy_loss.py new file mode 100644 index 000000000..4f3f5b6cb --- /dev/null +++ b/tests/functional/test_entropy_loss.py @@ -0,0 +1,179 @@ +import pathlib + +import pytest +import torch + +from fast_llm.engine.distributed.config import DistributedBackend +from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig +from fast_llm.functional.entropy_loss import entropy_loss_forward_backward +from fast_llm.utils import Assert +from tests.utils.subtest import 
DistributedTestContext + + +def _get_cross_entropy_inputs( + num_columns: int, loss_masking: bool, target_format: TargetFormat +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + device = "cuda" if torch.cuda.is_available() else "cpu" + # We want something moderately close to the target for the test to be meaningful + logits_var = torch.randn(256, num_columns, dtype=torch.float32, device=device) / 3 + loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device=device) if loss_masking else None + if target_format == TargetFormat.labels: + target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device=device) + logits = torch.nn.functional.one_hot(target, num_columns) + logits_var + if loss_masking: + logits = torch.where(loss_mask.unsqueeze(-1), logits, -100) + loss_mask = None + else: + target = torch.randn(256, num_columns, dtype=torch.float32, device=device) + logits = target + logits_var + if target_format == TargetFormat.probabilities: + target = torch.softmax(target, -1) + return logits, target, loss_mask + + +def _compare_entropy_loss_outputs( + loss: torch.Tensor, + ref_loss: torch.Tensor, + has_grad: bool, + grad: torch.Tensor | None, + ref_grad: torch.Tensor | None, + threshold=1e-5, + loss_min_threshold=1e-6, +): + Assert.rms_close_relative(loss, ref_loss, threshold, loss_min_threshold) + if has_grad: + Assert.rms_close_relative(grad, ref_grad, threshold, 1e-8) + else: + assert grad is None + assert ref_grad is None + + +@pytest.mark.slow +@pytest.mark.parametrize( + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking"), + ( + (8192, 1.0, 1.0, False), # Simple + (5000, 1.0, 1.0, False), # Not a power of 2 + (5000, None, 1.0, False), # No grad + (5000, 1.0, 4.0, False), # Loss scaling + (5000, 4.0, 1.0, False), # Grad scaling + (5000, 1.0, 1.0, True), # Loss masking + (65536, 1.0, 1.0, False), # Max block size + (65537, 1.0, 1.0, False), # Above max block size + ), +) +@pytest.mark.parametrize("target_format", TargetFormat) +@pytest.mark.parametrize("entropy_loss_type", EntropyLossType) +def test_entropy_loss(num_columns, grad_output, logits_scale_factor, loss_masking, target_format, entropy_loss_type): + if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: + pytest.skip(reason="Not implemented") + # TODO: Test tensor-parallel implementation. + logits, target, loss_mask = _get_cross_entropy_inputs(num_columns, loss_masking, target_format) + kwargs = { + "logits": logits, + "target": target, + "loss_mask": loss_mask, + "grad_output": grad_output, + "logits_scale_factor": logits_scale_factor, + "target_format": target_format, + "entropy_loss_type": entropy_loss_type, + } + # Torch serves as the reference implementation. + out_torch, grad_torch = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.torch) + out_fused, grad_fused = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.fused) + + # TODO: Why is the error so high with loss masking for reverse KL? + _compare_entropy_loss_outputs( + out_fused, + out_torch, + grad_output is not None, + grad_fused, + grad_torch, + loss_min_threshold=2e-4 if entropy_loss_type == EntropyLossType.reverse_kl and loss_masking else 5e-6, + ) + + if entropy_loss_type != EntropyLossType.cross_entropy or not torch.cuda.is_available(): + # Triton implementation only supports cross-entropy. 
+        return
+    assert TritonConfig.TRITON_ENABLED
+    if num_columns > 65536:
+        with pytest.raises(AssertionError):
+            entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.triton)
+    else:
+        out_triton, grad_triton = entropy_loss_forward_backward(
+            **kwargs, implementation=EntropyLossImplementation.triton
+        )
+        _compare_entropy_loss_outputs(out_triton, out_torch, grad_output is not None, grad_triton, grad_torch)
+
+
+def _entropy_loss_distributed(
+    target_format: TargetFormat,
+    entropy_loss_type: EntropyLossType,
+    loss_masking: bool,
+    group: torch.distributed.ProcessGroup,
+):
+    # Ensure all workers have the same inputs.
+    torch.manual_seed(0)
+    rank = group.rank()
+    world_size = group.size()
+    logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format)
+
+    kwargs = {
+        "loss_mask": loss_mask,
+        "grad_output": 1.0,
+        "target_format": target_format,
+        "implementation": EntropyLossImplementation.fused,
+        "entropy_loss_type": entropy_loss_type,
+    }
+    out_ref, grad_ref = entropy_loss_forward_backward(logits, target, **kwargs)
+
+    out, grad = entropy_loss_forward_backward(
+        logits.chunk(world_size, 1)[rank],
+        target if target_format == TargetFormat.labels else target.chunk(world_size, 1)[rank],
+        group=group,
+        **kwargs,
+    )
+    _compare_entropy_loss_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4)
+
+
+def _run_entropy_loss_distributed(test_context: DistributedTestContext, base_path: pathlib.Path):
+    for entropy_loss_type in EntropyLossType:
+        for target_format in TargetFormat:
+            if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl:
+                continue
+            for loss_masking in [False, True]:
+                name = f"{entropy_loss_type}_{target_format}_{loss_masking}"
+                with test_context.subtest(base_path, name, 2) as subtest:
+                    if subtest.do_run:
+                        _entropy_loss_distributed(target_format, entropy_loss_type, loss_masking, test_context.group)
+
+
+@pytest.mark.slow
+def test_entropy_loss_distributed_dependency():
+    # Mock test so the distributed subtests are placed in the same dependency group.
+    pass
+
+
+@pytest.mark.slow
+@pytest.mark.depends_on(on=["test_entropy_loss_distributed_dependency"])
+def test_run_entropy_loss_distributed(run_parallel_script, result_path):
+    run_parallel_script(
+        _run_entropy_loss_distributed,
+        (result_path / "test_entropy_loss",),
+        world_size=2,
+        backend=DistributedBackend.gloo,
+        use_cpu=True,  # Disable device count check.
+    )
+
+
+# We don't want to depend on `test_run_entropy_loss_distributed` because we still want to run this in case of failure.
+# This should still run after `test_run_entropy_loss_distributed` +@pytest.mark.slow +@pytest.mark.depends_on(on=["test_entropy_loss_distributed_dependency"]) +@pytest.mark.parametrize("target_format", TargetFormat) +@pytest.mark.parametrize("entropy_loss_type", EntropyLossType) +@pytest.mark.parametrize("loss_masking", (False, True)) +def test_entropy_loss_distributed(result_path, report_subtest, target_format, entropy_loss_type, loss_masking): + if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: + pytest.skip(reason="Not implemented") + report_subtest(result_path / f"test_entropy_loss/{entropy_loss_type}_{target_format}_{loss_masking}", 2) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 1a607b246..9aa53fcc4 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -1,347 +1,317 @@ +import collections +import dataclasses import typing import pytest import torch -from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LM_HEAD_LOSS_NAME, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage - -def _reverse_kl_loss( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - teacher_softmax_temperature: float = 1.0, -): - scaled_target = torch.clamp(target / teacher_softmax_temperature, min=-50, max=50) - teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) - - with torch.enable_grad(): - # Use log_softmax for consistency instead of _fused_softmax - logits = torch.clamp(logits, min=-50, max=50) - student_log_probs = torch.log_softmax(logits, dim=-1) - if loss_mask is None: - loss = torch.nn.functional.kl_div( - teacher_log_probs, # input = log(p) - student_log_probs, # target = log(q) - reduction="batchmean", - log_target=True, - ) - else: - # Apply loss mask - this requires some reshaping - loss_per_sample = torch.nn.functional.kl_div( - teacher_log_probs, student_log_probs, reduction="none", log_target=True - ).sum(dim=-1) - loss = (loss_per_sample * loss_mask.flatten()).mean() - return loss - - -def _lm_head( - input_: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - *, - # config:LanguageModelBaseConfig, - rms_weight: torch.Tensor, - logit_weight: torch.Tensor, - grad_output: float = 1.0, - logit_scale_factor: float = 1.0, - logit_z_loss=0.0, - distillation_loss_implementation: DistillationLossImpl = DistillationLossImpl.cross_entropy, -): - hidden = torch.rms_norm( - input_.to(rms_weight.dtype), - input_.shape[-1:], - rms_weight, - 1e-5, - ) - logits = torch.nn.functional.linear(hidden, logit_weight).float() - - if distillation_loss_implementation == DistillationLossImpl.reverse_kl: - Assert.eq(logits.shape, target.shape) - loss = _reverse_kl_loss( - (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask - ) - loss.backward(torch.full_like(loss, grad_output)) - return loss, None - - if logit_scale_factor != 1.0: - 
logits *= logit_scale_factor - z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None - if target.ndim == logits.ndim: - loss = torch.nn.functional.cross_entropy( - logits.flatten(0, -2), target.float().softmax(-1).flatten(0, -2), reduction="none" - ) - if loss_mask is not None: - loss = loss * loss_mask.flatten() - loss = loss.mean() - else: - loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) - loss.backward(torch.full_like(loss, grad_output)) - return loss, z_loss - - SEQUENCE_LENGTH = 200 BATCH_SIZE = 4 HIDDEN_SIZE = 256 VOCAB_SIZE = 500 -@pytest.mark.slow -@pytest.mark.parametrize("cross_entropy_impl", tuple(CrossEntropyImpl)) -@pytest.mark.parametrize( - ("config_dict", "distributed_config_dict", "loss_masking", "prediction_heads"), - ( - ({}, {}, False, 1), - ({}, {"compute_dtype": DataType.bfloat16}, False, 1), - ({"embeddings": {"full_precision_residual": True}}, {"compute_dtype": DataType.bfloat16}, False, 1), - ({"sequence_first": True}, {}, False, 1), - ({"head": {"logit_z_loss": 1e-3}}, {}, False, 1), - ({"head": {"logits_scale_factor": 5.0}}, {}, False, 1), - ({"tied_embedding_weight": True}, {}, False, 1), - ({}, {}, False, 2), - ({}, {}, True, 1), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - } - }, - {}, - False, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.reverse_kl, - } - }, - {}, - False, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - "language_model_loss_factor": 1.0, - } - }, - {}, - True, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.reverse_kl, - } - }, - {}, - True, - 1, - ), - ), -) -def test_lm_head( - cross_entropy_impl: CrossEntropyImpl, - config_dict: dict[str, typing.Any], - distributed_config_dict: dict[str, typing.Any], - loss_masking: bool, - prediction_heads: int, -): - if cross_entropy_impl in (CrossEntropyImpl.auto, CrossEntropyImpl.triton) and not torch.cuda.is_available(): - pytest.skip("Cuda is not available") - head_config = { - "cross_entropy_implementation": cross_entropy_impl, - "normalization": {"type": "rms_norm", "implementation": "auto" if torch.cuda.is_available() else "torch"}, - } - config = GPTBaseModelConfig.from_dict( - { - "decoder": {"num_blocks": 0}, - "embeddings": {"vocab_size": VOCAB_SIZE}, - "head": ( - head_config - if prediction_heads == 1 - else { - "type": "multi_token_prediction", - "head": head_config, - "prediction_heads": prediction_heads, - } - ), - "hidden_size": HIDDEN_SIZE, - }, - config_dict, - update_type=UpdateType.update, - ) - head_config: LanguageModelHeadConfig = config.head if prediction_heads == 1 else config.head.head +@dataclasses.dataclass +class LMHeadTestConfig: + name: str + label_loss: bool | float = False + distillation_loss: bool | float = False + z_loss: bool | float = False + logits_scale_factor: float = 1.0 + compute_dtype: DataType = DataType.float32 + full_precision_residual: bool = False + sequence_first: bool = False + loss_masking: bool = False + prediction_heads: int = 1 + tied_embedding_weight: bool = False + cross_entropy_splits: int = 1 + + @property + def actual_label_loss(self): + return ( + True + if self.label_loss is False and self.distillation_loss is False and 
self.z_loss is False + else self.label_loss + ) + + def get_config(self) -> GPTModelConfig: + head_config = { + "normalization": {"type": "rms_norm"}, + "logits_scale_factor": self.logits_scale_factor, + "cross_entropy_splits": self.cross_entropy_splits, + } + losses = {} + if self.label_loss is not False: + losses["label"] = {"type": "label"} + if isinstance(self.label_loss, float): + losses["label"]["weight"] = self.label_loss + if self.distillation_loss is not False: + losses["distillation"] = {"type": "distillation", "reference_model": "distillation"} + if isinstance(self.distillation_loss, float): + losses["distillation"]["weight"] = self.distillation_loss + if self.z_loss is not False: + losses["z_loss"] = {"type": "z_loss"} + if isinstance(self.z_loss, float): + losses["z_loss"]["weight"] = self.z_loss + if losses: + head_config["losses"] = losses - model, distributed = get_base_model( - GPTModelConfig.from_dict( + return GPTModelConfig.from_dict( { - "base_model": config, - "distributed": {**distributed_config_dict, "use_cuda": torch.cuda.is_available()}, + "base_model": { + "decoder": {"num_blocks": 0}, + "embeddings": {"vocab_size": VOCAB_SIZE, "full_precision_residual": self.full_precision_residual}, + "head": ( + head_config + if self.prediction_heads == 1 + else { + "type": "multi_token_prediction", + "head": head_config, + "prediction_heads": self.prediction_heads, + } + ), + "hidden_size": HIDDEN_SIZE, + "tied_embedding_weight": self.tied_embedding_weight, + }, + "distributed": {"compute_dtype": self.compute_dtype, "use_cuda": torch.cuda.is_available()}, }, ) - ) - sequence_first = config.sequence_first or ( - head_config.cross_entropy_splits is not None and head_config.cross_entropy_splits > 1 - ) - input_ = torch.randn( - (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), - dtype=( - distributed.config.optimization_dtype.torch - if config.embeddings.full_precision_residual - else distributed.config.compute_dtype.torch - ), - device=distributed.device, - requires_grad=True, - ) - label_shape = ( - (SEQUENCE_LENGTH + config.head.max_prediction_distance - 1, BATCH_SIZE) - if sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + config.head.max_prediction_distance - 1) - ) - if loss_masking: - loss_mask = torch.randint(0, 2, label_shape, dtype=torch.bool, device=distributed.device) - else: - loss_mask = None - kwargs = { - AttentionKwargs.sequence_first: sequence_first, - AttentionKwargs.grad_output: 1.0, - } - if head_config.distillation_model is None: - target = torch.randint( - 0, - VOCAB_SIZE, - label_shape, - dtype=torch.int64, - device=distributed.device, + def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: + device = "cuda" if torch.cuda.is_available() else "cpu" + input_ = torch.randn( + ( + (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) + if self.sequence_first + else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE) + ), + dtype=(torch.float32 if self.full_precision_residual else self.compute_dtype.torch), + device=device, + requires_grad=True, + ) + label_shape = ( + (SEQUENCE_LENGTH + self.prediction_heads - 1, BATCH_SIZE) + if self.sequence_first + else (BATCH_SIZE, SEQUENCE_LENGTH + self.prediction_heads - 1) ) - if loss_mask is not None: - target *= loss_mask - - kwargs[LanguageModelKwargs.labels] = target - else: - assert config.head.max_prediction_distance == 1 - target = torch.randn( - input_.shape[:-1] + (VOCAB_SIZE,), - dtype=input_.dtype, - device=distributed.device, + kwargs: dict[str, 
typing.Any] = { + AttentionKwargs.sequence_first: self.sequence_first, + AttentionKwargs.grad_output: 1.0, + } + if self.loss_masking: + kwargs[LanguageModelKwargs.loss_mask] = torch.randint(0, 2, label_shape, dtype=torch.bool, device=device) + if self.actual_label_loss is not False: + labels = torch.randint( + 0, + VOCAB_SIZE, + label_shape, + dtype=torch.int64, + device=device, + ) + if LanguageModelKwargs.loss_mask in kwargs: + labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], -100, labels) + kwargs[LanguageModelKwargs.labels] = labels + + if self.distillation_loss is not False: + assert self.prediction_heads == 1 + kwargs[f"distillation_logits"] = torch.randn( + input_.shape[:-1] + (VOCAB_SIZE,), + dtype=input_.dtype, + device=device, + ) + return input_, kwargs + + def get_reference_outputs( + self, + head: LanguageModelHead, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + tied_logit_weight: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]: + # Get reference outputs and grads + logit_weight = ( + (head.output_weights if tied_logit_weight is None else tied_logit_weight).detach().requires_grad_() ) - kwargs[f"{head_config.distillation_model}_logits"] = target - if loss_mask is not None: - kwargs[LanguageModelKwargs.loss_mask] = loss_mask + normalization_weight = head.final_norm.weight.detach().requires_grad_() + input_ = input_.detach().requires_grad_() + + hidden = torch.rms_norm(input_.to(normalization_weight.dtype), input_.shape[-1:], normalization_weight, 1e-5) + logits = torch.nn.functional.linear(hidden, logit_weight).float() + + if self.logits_scale_factor is not None: + logits = logits * self.logits_scale_factor + + total_loss = 0 + losses = {} - if config.tied_embedding_weight or config.head.max_prediction_distance > 1: - logit_weight = torch.nn.Parameter( + if self.actual_label_loss is not False: + if self.sequence_first: + labels = kwargs[LanguageModelKwargs.labels][ + head._prediction_distance : head._prediction_distance + logits.size(0) + ] + else: + labels = kwargs[LanguageModelKwargs.labels][ + :, head._prediction_distance : head._prediction_distance + logits.size(1) + ] + label_loss = torch.nn.functional.cross_entropy( + logits.flatten(0, -2), labels.flatten(), reduction="none" + ).mean() + losses["label"] = label_loss.detach() + total_loss = total_loss + float(self.actual_label_loss) * label_loss + + if self.distillation_loss is not False: + distillation_loss = torch.nn.functional.cross_entropy( + logits.flatten(0, -2), + torch.softmax(kwargs[f"distillation_logits"].flatten(0, -2), -1), + reduction="none", + ) + if LanguageModelKwargs.loss_mask in kwargs: + distillation_loss = distillation_loss * kwargs[LanguageModelKwargs.loss_mask].flatten() + distillation_loss = distillation_loss.mean() + losses["distillation"] = distillation_loss.detach() + total_loss = total_loss + float(self.distillation_loss) * distillation_loss + + if self.z_loss is not False: + z_loss = torch.logsumexp(logits, dim=-1) ** 2 + if LanguageModelKwargs.loss_mask in kwargs: + z_loss = z_loss * kwargs[LanguageModelKwargs.loss_mask] + z_loss = z_loss.mean() + losses["z_loss"] = z_loss.detach() + total_loss = total_loss + float(self.z_loss) * z_loss + + total_loss.backward() + + if len(losses) > 1: + losses[LM_HEAD_LOSS_NAME] = total_loss.detach() + else: + losses = {LM_HEAD_LOSS_NAME: total_loss.detach()} + + if head._prediction_distance > 0: + losses = {f"{name}_{head._prediction_distance}": loss for name, loss in 
losses.items()} + + return total_loss.detach(), input_.grad, logit_weight.grad, normalization_weight.grad, losses + + +_lm_head_test_configs = ( + # TODO: Test DPO loss. + # TODO: Add more configs + # TODO: Add distributed test + LMHeadTestConfig("default"), + LMHeadTestConfig("bfloat16", compute_dtype=DataType.bfloat16), + LMHeadTestConfig("full_precision_residual", full_precision_residual=True), + LMHeadTestConfig("sequence_first", sequence_first=True), + LMHeadTestConfig("logit_scaling", logits_scale_factor=5.0), + LMHeadTestConfig("tied_embedding_weight", tied_embedding_weight=True), + LMHeadTestConfig("multi_token_prediction", prediction_heads=2), + LMHeadTestConfig("cross_entropy_splits", cross_entropy_splits=2, sequence_first=True), + LMHeadTestConfig("loss_masking", loss_masking=True), + LMHeadTestConfig("label_loss", label_loss=True), + LMHeadTestConfig("distillation_loss", distillation_loss=True), + LMHeadTestConfig("distillation_loss_masked", distillation_loss=True, loss_masking=True), + LMHeadTestConfig("z_loss", z_loss=True), + LMHeadTestConfig("z_loss_masked", z_loss=True, loss_masking=True), + LMHeadTestConfig("label_and_distillation_loss", label_loss=True, distillation_loss=True), + LMHeadTestConfig("label_and_z_loss_weighted", label_loss=True, z_loss=0.5), + LMHeadTestConfig("label_and_distillation_loss_zero_weight", label_loss=True, distillation_loss=0.0), +) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "test_config", + [ + pytest.param(_lm_head_test_config, id=_lm_head_test_config.name) + for _lm_head_test_config in _lm_head_test_configs + ], +) +def test_lm_head(test_config): + model_config = test_config.get_config() + model, distributed = get_base_model(model_config) + input_, kwargs = test_config.get_inputs() + + tied_logit_weight = ( + torch.nn.Parameter( torch.empty( VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.compute_dtype.torch, device=distributed.device - ).normal_(config.hidden_size**-0.5) + ).normal_(HIDDEN_SIZE**-0.5) ) - else: - logit_weight = None + if test_config.tied_embedding_weight or test_config.prediction_heads > 1 + else None + ) for prediction_distance, head in enumerate(model.head.heads): # Prepare the LM head Assert.custom(isinstance, head, LanguageModelHead) Assert.eq(head._prediction_distance, prediction_distance) - is_duplicate = config.tied_embedding_weight or prediction_distance > 0 + is_duplicate = test_config.tied_embedding_weight or prediction_distance > 0 stage = get_stage( [head], distributed, tied_parameter_duplicates=[head.output_weights.tensor_name] if is_duplicate else [], - tied_parameter_duplicate_buffers={head.output_weights.tensor_name: logit_weight} if is_duplicate else {}, + tied_parameter_duplicate_buffers=( + {head.output_weights.tensor_name: tied_logit_weight} if is_duplicate else {} + ), # Names must be kept as-is for tied weights. 
set_names=False, ) - # Get reference outputs and grads - if is_duplicate: - logit_weight.grad_buffer = torch.full_like(logit_weight, float("nan")) - logit_weight.param_grad_is_zero = True - else: - logit_weight = head.output_weights - - ref_input = input_.detach().requires_grad_() - ref_rms_weight = head.final_norm.weight.detach().requires_grad_() - ref_logit_weight = logit_weight.detach().requires_grad_() - - ref_loss, ref_z_loss = _lm_head( - ref_input, - ( - target[prediction_distance : prediction_distance + SEQUENCE_LENGTH] - if sequence_first - else target[:, prediction_distance : prediction_distance + SEQUENCE_LENGTH] - ), - loss_mask, - rms_weight=ref_rms_weight, - logit_weight=ref_logit_weight, - logit_scale_factor=head_config.logits_scale_factor, - logit_z_loss=head_config.logit_z_loss, - distillation_loss_implementation=head_config.distillation_loss_implementation, + ref_total_loss, ref_input_grad, ref_logit_weight_grad, ref_normalization_weight_grad, ref_losses = ( + test_config.get_reference_outputs( + head, input_, kwargs, tied_logit_weight if prediction_distance > 0 else None + ) ) # Prepare LM head inputs if head._is_last_head: - head_input = input_ - output_grad = ref_input.new_full((), float("nan")) + head_input = input_.detach().requires_grad_() + output_grad = input_.new_full((), float("nan")) else: shared_hidden = torch.randn_like(input_) head_input = torch.stack((shared_hidden, input_.detach())).requires_grad_() output_grad = torch.randn_like(shared_hidden) - loss_name = f"language_model_loss_{prediction_distance}" if prediction_distance > 0 else "language_model_loss" - loss_keys = {loss_name} - if ref_z_loss is not None: - loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") - if head_config.distillation_model is not None: - loss_keys.add("distillation_loss") - if head_config.language_model_loss_factor > 0: - loss_keys.add("distillation_language_model_loss") - - Assert.eq( - {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, - {loss_key: 1 for loss_key in loss_keys}, - ) - losses = {key: [] for key in loss_keys} + if is_duplicate: + logit_weight = tied_logit_weight + logit_weight.grad_buffer = torch.full_like(logit_weight, float("nan")) + logit_weight.param_grad_is_zero = True + else: + logit_weight = head.output_weights + + losses = collections.defaultdict(list) output, context = stage.forward(head_input, kwargs, losses) + print(losses) stage.backward(output_grad, context) - threshold = 1e-5 if distributed.config.compute_dtype == DataType.float32 else 5e-3 min_threshold = ( 1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4 - ) * head_config.logits_scale_factor + ) * test_config.logits_scale_factor - Assert.eq(losses.keys(), loss_keys) - Assert.eq(len(losses[loss_name]), 1) - if ref_z_loss is not None: - Assert.eq(len(losses["z_loss"]), 1) - Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) + Assert.eq(losses.keys(), ref_losses.keys()) + for name, loss in losses.items(): + assert len(loss) == 1, name + losses = {name: loss[0] for name, loss in losses.items()} - Assert.rms_close_relative(losses[loss_name][0], ref_loss, threshold, min_threshold) + for name, loss in losses.items(): + Assert.rms_close_relative(loss, ref_losses[name], threshold, min_threshold, msg=name) if head._is_last_head: - Assert.all_equal(output, losses[loss_name][0]) + # Assert.all_equal(output, losses[lm_head_loss_name][0]) input_grad = head_input.grad else: 
Assert.all_equal(output, shared_hidden) shared_hidden_grad, input_grad = head_input.grad.unbind() Assert.all_equal(shared_hidden_grad, output_grad) - Assert.rms_close_relative(input_grad, ref_input.grad, threshold, min_threshold) - Assert.rms_close_relative(head.final_norm.weight.grad_buffer, ref_rms_weight.grad, threshold, min_threshold) - Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight.grad, threshold, min_threshold) + Assert.rms_close_relative(input_grad, ref_input_grad, threshold, min_threshold) + Assert.rms_close_relative( + head.final_norm.weight.grad_buffer, ref_normalization_weight_grad, threshold, min_threshold + ) + Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight_grad, threshold, min_threshold) diff --git a/tests/test_config.py b/tests/test_config.py index 4020b6fbc..2e900cb14 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -148,12 +148,15 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, }, + "head": {"losses": {"lm_loss": {"type": "cross_entropy"}}}, "hidden_size": 512, "tied_embedding_weight": False, "peft": {"freeze_others": False}, } else: expected_config["base_model"] = base_model_update + # added by default + expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy"}}} check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config)) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 84466fe29..4e1e9d507 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -565,6 +565,12 @@ def update_and_add_testing_config( "mistral_distill_logits", updates={ ("model", "base_model", "head", "distillation_model"): "teacher", + ("model", "base_model", "head", "losses"): { + "distillation_loss": { + "type": "reverse_kl_distillation", + "factor": 1.0, + }, + }, ("batch", "use_loss_masking_spans"): True, ("reference_models"): { "teacher": { @@ -587,32 +593,12 @@ def update_and_add_testing_config( skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"), ) -update_and_add_testing_config( - "mistral_distill_logits", - "mistral_reverse_kl", - updates={ - ("model", "base_model", "head", "distillation_loss_implementation"): "reverse_kl", - }, - megatron_args=None, - checkpoint_format=MistralCheckpointFormat, - groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant, - ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, - ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, - ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, # failing: fp16, tp2, stp2, stp2_ce4 - }, - compare_factor=2, - # Modes not supported with reference models - skip_tests=("sdp", "ms", "pp"), -) update_and_add_testing_config( "mistral_distill_logits", "mistral_distill_activations", updates={ - ("model", "base_model", "head", "distillation_loss_factor"): 0.001, + ("model", "base_model", "head", "losses", "distillation_loss", "factor"): 0.001, ("model", "base_model", "decoder", "block", "distillation_model"): "teacher", ("model", "base_model", "decoder", "block", "activation_distillation_factor"): 0.1, ("reference_models"): { diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py index 4fea1fbba..a30440ad1 100644 --- a/tests/utils/subtest.py +++ b/tests/utils/subtest.py @@ -2,6 +2,7 @@ import json import logging import math +import os import 
pathlib import sys import time @@ -27,11 +28,13 @@ def __init__( timeout: float = 20.0, init_method: str = "env://", backend: DistributedBackend = DistributedBackend.nccl, + use_cpu: bool = False, ) -> None: self._do_capture = do_capture self._timeout = timeout self._init_method = init_method self._backend = backend + self._use_cpu = use_cpu def __enter__(self): if self._do_capture: @@ -40,7 +43,7 @@ def __enter__(self): ) self._pool = ProcessGroupPool( - timeout=self._timeout, init_method=self._init_method, backend=self._backend + timeout=self._timeout, init_method=self._init_method, backend=self._backend, use_cpu=self._use_cpu ).__enter__() self._rank = self._pool.rank self._world_size = self._pool.world_size @@ -48,12 +51,12 @@ def __enter__(self): self._configure_logging() self._group = self._pool.get_process_group(range(self._world_size), self._rank) # TODO: Barriers needed? - safe_barrier(self._group, "start") + safe_barrier(self._group, "start", device=self._pool.device) return self def __exit__(self, exc_type, exc_val, exc_tb): # Final barrier to ensure everything is done before torchrun potentially kills workers. - safe_barrier(self._group, "testing end") + safe_barrier(self._group, "testing end", device=self._pool.device) # Let pytest know how things went. # These should already be reported above, we repeat for convenience. if self._failures: @@ -75,6 +78,10 @@ def rank(self) -> int: def world_size(self) -> int: return self._world_size + @property + def group(self) -> torch.distributed.ProcessGroup: + return self._group + class DistributedSubtestContext: def __init__( self, test_context: "DistributedTestContext", base_path: pathlib.Path, name: str, num_gpus: int @@ -83,7 +90,7 @@ def __init__( self._path = base_path / name self._name = name self._num_gpus = num_gpus - self._skip = self._test_context._world_size < self._num_gpus + self._skip = self._test_context._world_size < self._num_gpus and not self._test_context._use_cpu self._do_run = self._test_context._rank < num_gpus and not self._skip self._do_capture = self._test_context._do_capture and self._do_run self._success = False @@ -131,10 +138,15 @@ def __exit__(self, exc_type, exc_val, exc_tb): if (group := self._test_context._group) is not None: # Barrier so `allreduce_scalar` doesn't go crazy in case of desync. - safe_barrier(group, self._name) - self._success = allreduce_scalar(self._success, dtype=torch.int64, group=group) == group.size() + safe_barrier(group, self._name, device=self._test_context._pool.device) + self._success = ( + allreduce_scalar( + self._success, dtype=torch.int64, group=group, device=self._test_context._pool.device + ) + == group.size() + ) - if self._do_capture: + if self._do_capture and torch.cuda.is_available(): # Free resources to limit memory usage. 
report = get_and_reset_memory_usage_mib(clear_cache=True, global_stats=True, reset_global_stats=True) report["duration"] = time.perf_counter() - self._start @@ -233,13 +245,14 @@ def parallel_worker( init_method: str, backend: DistributedBackend, do_capture: bool, + use_cpu: bool, fn: typing.Callable, fn_args: typing.Sequence[typing.Any], ): DistributedConfig.default_rank = rank DistributedConfig.default_world_size = world_size DistributedConfig.default_local_world_size = world_size - with DistributedTestContext(do_capture, 60, init_method, backend) as test_context: + with DistributedTestContext(do_capture, 60, init_method, backend, use_cpu) as test_context: fn(test_context, *fn_args) @@ -251,14 +264,17 @@ def do_run_parallel_script( world_size: int, timeout: float = 240, backend: DistributedBackend = DistributedBackend.nccl, + use_cpu: bool = False, # Use CPU device in process group pool. May be used to disable device count check ): + if "PYTHONHASHSEED" not in os.environ: + os.environ["PYTHONHASHSEED"] = "0" if do_capture: logger.warning( "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." ) torch.multiprocessing.spawn( parallel_worker, - args=(world_size, f"tcp://localhost:{port}", backend, do_capture, fn, fn_args), + args=(world_size, f"tcp://localhost:{port}", backend, do_capture, use_cpu, fn, fn_args), nprocs=world_size, join=False, ).join(timeout, grace_period=5)
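
Note on the new configuration surface: the head now takes its objectives from a `losses` mapping on the head config, with loss types registered dynamically in `fast_llm/layers/language_model/loss/config.py` above ("label", "distillation", "dpo", "z_loss"). The sketch below shows what such a mapping might look like; it is illustrative only, the nesting and defaults are inferred from this diff (the forms actually exercised are in `tests/test_config.py` and `tests/utils/model_configs.py`), and "teacher" is a placeholder reference-model name.

# Hypothetical head config fragment combining several weighted losses.
head_config = {
    "normalization": {"type": "rms_norm"},
    "losses": {
        # The key is the reported loss name; "type" selects the registered LanguageModelLossConfig subclass.
        "label": {"type": "label", "weight": 1.0},
        "distillation": {"type": "distillation", "reference_model": "teacher", "weight": 0.5},
        "z_loss": {"type": "z_loss", "weight": 1e-3},
    },
}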
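The accumulation logic in the `LanguageModelHead` changes above reports each loss unscaled while folding the per-loss weight, the head's loss coefficient, the number of `cross_entropy_splits`, and (for sequence-parallel logits) the tensor-parallel size into the total loss and the gradient scale. A minimal arithmetic sketch of that bookkeeping, using made-up scalars rather than the real tensors:

# Illustration only: how the weighted total and per-loss grad scales are assembled.
losses = {"label": 2.0, "distillation": 0.8}   # unscaled per-loss values accumulated over splits
weights = {"label": 1.0, "distillation": 0.5}  # per-loss weights from the config
loss_coefficient = 1.0                         # head-level coefficient
num_splits = 2                                 # cross_entropy_splits
parallel_size = 1                              # tensor-parallel size when logits are sequence-parallel

total_loss = sum(weights[name] * loss_coefficient / num_splits * value for name, value in losses.items())

grad_output = 1.0
per_loss_grad_scale = {
    name: grad_output * loss_coefficient * weights[name] / parallel_size / num_splits for name in losses
}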
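For intuition on the `EntropyLossType` variants exercised in `tests/functional/test_entropy_loss.py`, the plain-PyTorch reference below spells out the usual definitions against a logits-format target (masking, reductions, and the fused/Triton kernels in `fast_llm.functional.entropy_loss` are omitted). It is a sketch of the expected semantics, not the project's implementation.

import torch

def reference_entropy_losses(logits: torch.Tensor, teacher_logits: torch.Tensor) -> dict[str, torch.Tensor]:
    # Student and teacher distributions in log space.
    log_q = torch.log_softmax(logits.float(), dim=-1)
    log_p = torch.log_softmax(teacher_logits.float(), dim=-1)
    p, q = log_p.exp(), log_q.exp()
    return {
        "cross_entropy": -(p * log_q).sum(-1).mean(),        # H(p, q)
        "forward_kl": (p * (log_p - log_q)).sum(-1).mean(),  # KL(p || q) = H(p, q) - H(p)
        "reverse_kl": (q * (log_q - log_p)).sum(-1).mean(),  # KL(q || p)
    }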