diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index ecbe0d94120e..80023781f499 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -15,6 +15,12 @@ from deepspeed.module_inject.tp_shard import get_shard_size_list, set_num_kv_heads, get_num_kv_heads from deepspeed.utils import groups +try: + from torchembed._triton import fused_rope_forward as _torchembed_rope_forward + _torchembed_available = True +except ImportError: + _torchembed_available = False + def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input): """ @@ -100,9 +106,15 @@ def apply_rotary_pos_emb(t, freqs_cos, freqs_sin): # ideally t_pass is empty so rotary pos embedding is applied to all tensor t t, t_pass = t[..., :rot_dim], t[..., rot_dim:] - # first part is cosine component - # second part is sine component, need to change signs with _rotate_half method - t = (t * freqs_cos) + (_rotate_half(t) * freqs_sin) + if _torchembed_available and t.is_cuda and t.device.type == 'cuda' and rot_dim % 2 == 0: + orig_shape = t.shape + t_2d = t.reshape(-1, orig_shape[-2], rot_dim) + t_out, _ = _torchembed_rope_forward(t_2d, t_2d, freqs_cos, freqs_sin) + t = t_out.reshape(orig_shape) + else: + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * freqs_cos) + (_rotate_half(t) * freqs_sin) res = t if t_pass.shape[-1] == 0 else torch.cat((t, t_pass), dim=-1) return res diff --git a/setup.py b/setup.py index 67befe3da31b..d4819774ced0 100755 --- a/setup.py +++ b/setup.py @@ -92,6 +92,7 @@ def get_env_if_set(key, default: typing.Any = ""): 'sd': fetch_requirements('requirements/requirements-sd.txt'), 'triton': fetch_requirements('requirements/requirements-triton.txt'), 'deepcompile': fetch_requirements('requirements/requirements-deepcompile.txt'), + 'torchembed': ['torchembed'], } # Only install pynvml on nvidia gpus. 
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""Unit tests for ``deepspeed.sequence.layer.apply_rotary_pos_emb``.

All tensors here are created without a ``device`` argument, so any
CUDA-only branch inside ``apply_rotary_pos_emb`` is not covered by
these tests.
"""

import pytest
import torch

from deepspeed.sequence.layer import apply_rotary_pos_emb, _rotate_half


def _make_freqs(seq_len, rot_dim, theta=10000.0, device="cpu"):
    """Build cosine and sine rotary tables.

    Args:
        seq_len: number of positions (rows of the tables).
        rot_dim: width of the tables; the ``rot_dim / 2`` inverse
            frequencies are concatenated with themselves.
        theta: base of the geometric frequency progression.
        device: device on which the tables are created.

    Returns:
        A ``(cos, sin)`` tuple, each of shape ``(seq_len, rot_dim)``.
    """
    inv_freq = 1.0 / (theta**(torch.arange(0, rot_dim, 2, device=device).float() / rot_dim))
    positions = torch.arange(seq_len, device=device).float()
    freqs = torch.einsum("i,j->ij", positions, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()


def _ref_apply_rotary(t, freqs_cos, freqs_sin):
    """Eager reference: rotate the first ``freqs_cos.shape[-1]`` channels of ``t``.

    Channels past the rotary width are passed through unchanged and
    concatenated back onto the rotated part.
    """
    rot_dim = freqs_cos.shape[-1]
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs_cos) + (_rotate_half(t) * freqs_sin)
    return t if t_pass.shape[-1] == 0 else torch.cat((t, t_pass), dim=-1)


@pytest.mark.parametrize("seq_len", [1, 17, 128])
@pytest.mark.parametrize("dim", [32, 64, 128])
@pytest.mark.parametrize("rotary_dim", [None, 16, 32, 64])
def test_apply_rotary_pos_emb(seq_len, dim, rotary_dim):
    """``apply_rotary_pos_emb`` matches the eager reference for full and partial rotary widths."""
    rot_dim = rotary_dim if rotary_dim is not None else dim
    if rot_dim > dim or rot_dim % 2 != 0:
        pytest.skip("rotary_dim must be <= dim and even")

    torch.manual_seed(0)  # keep failures reproducible
    t = torch.randn(seq_len, 4, dim)
    freqs_cos, freqs_sin = _make_freqs(seq_len, rot_dim)
    # Insert a singleton head axis: (seq_len, rot_dim) -> (seq_len, 1, rot_dim).
    # Without it the tables cannot broadcast against t's (seq_len, n_heads, rot_dim)
    # slice, because n_heads would be aligned with seq_len.
    freqs_cos = freqs_cos[:, None, :]
    freqs_sin = freqs_sin[:, None, :]

    result = apply_rotary_pos_emb(t, freqs_cos, freqs_sin)
    expected = _ref_apply_rotary(t, freqs_cos, freqs_sin)

    assert result.shape == t.shape, f"output shape {result.shape} != input shape {t.shape}"
    assert torch.allclose(result, expected, atol=1e-6), (
        f"seq_len={seq_len}, dim={dim}, rot_dim={rot_dim}: max diff={((result - expected).abs().max()).item()}"
    )


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_apply_rotary_pos_emb_grad_flow(dtype):
    """Gradients propagate back to the input and are finite and correctly shaped."""
    seq_len, n_heads, dim = 8, 4, 64
    rot_dim = 64
    torch.manual_seed(0)  # keep failures reproducible
    t = torch.randn(seq_len, n_heads, dim, dtype=dtype, requires_grad=True)
    freqs_cos, freqs_sin = _make_freqs(seq_len, rot_dim)
    # Singleton head axis so the tables broadcast over n_heads.
    freqs_cos = freqs_cos[:, None, :]
    freqs_sin = freqs_sin[:, None, :]

    out = apply_rotary_pos_emb(t, freqs_cos, freqs_sin)
    loss = out.sum()
    loss.backward()

    assert t.grad is not None
    assert not torch.isnan(t.grad).any(), "NaNs in gradient"
    assert t.grad.shape == t.shape, f"grad shape {t.grad.shape} != {t.shape}"