Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions deepspeed/sequence/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
from deepspeed.module_inject.tp_shard import get_shard_size_list, set_num_kv_heads, get_num_kv_heads
from deepspeed.utils import groups

try:
from torchembed._triton import fused_rope_forward as _torchembed_rope_forward
_torchembed_available = True
except ImportError:
_torchembed_available = False


def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
"""
Expand Down Expand Up @@ -100,9 +106,15 @@ def apply_rotary_pos_emb(t, freqs_cos, freqs_sin):
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
t = (t * freqs_cos) + (_rotate_half(t) * freqs_sin)
if _torchembed_available and t.is_cuda and t.device.type == 'cuda' and rot_dim % 2 == 0:
orig_shape = t.shape
t_2d = t.reshape(-1, orig_shape[-2], rot_dim)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve the actual sequence dimension for fused RoPE

When torchembed is installed and the tensor is on CUDA, this reshapes every input as if the sequence length were orig_shape[-2], which is the number-of-heads axis for [seq,b,heads,dim] or [b,seq,heads,dim] tensors. The existing function contract says t is [seq_length, ..., dim], and callers such as fpdt_layer.py also pass [b,l,nh,hd]. The fused kernel therefore receives a fake sequence length equal to the number of heads while freqs_cos/freqs_sin still describe the real sequence length, producing wrong rotations or shape/indexing failures only in the optional fused path.

Useful? React with 👍 / 👎.

t_out, _ = _torchembed_rope_forward(t_2d, t_2d, freqs_cos, freqs_sin)
t = t_out.reshape(orig_shape)
else:
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
t = (t * freqs_cos) + (_rotate_half(t) * freqs_sin)

res = t if t_pass.shape[-1] == 0 else torch.cat((t, t_pass), dim=-1)
return res
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def get_env_if_set(key, default: typing.Any = ""):
'sd': fetch_requirements('requirements/requirements-sd.txt'),
'triton': fetch_requirements('requirements/requirements-triton.txt'),
'deepcompile': fetch_requirements('requirements/requirements-deepcompile.txt'),
'torchembed': ['torchembed'],
}

# Only install pynvml on nvidia gpus.
Expand Down
63 changes: 63 additions & 0 deletions tests/unit/sequence/test_apply_rotary_pos_emb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch

from deepspeed.sequence.layer import apply_rotary_pos_emb, _rotate_half


def _make_freqs(seq_len, rot_dim, theta=10000.0, device="cpu"):
    """Build the cosine and sine rotary tables for ``seq_len`` positions.

    One inverse frequency ``1 / theta ** (k / rot_dim)`` is produced for every
    even ``k`` below ``rot_dim``. Each position index ``0 .. seq_len - 1`` is
    multiplied by each inverse frequency, and the resulting angle table is
    repeated twice along the last axis before taking cos/sin.

    Args:
        seq_len: number of positions (rows) in the tables.
        rot_dim: rotary feature size used to derive the inverse frequencies.
        theta: base of the frequency progression.
        device: device on which the tensors are created.

    Returns:
        A ``(cos, sin)`` tuple; for an even ``rot_dim`` each tensor has shape
        ``[seq_len, rot_dim]``.
    """
    even_idx = torch.arange(0, rot_dim, 2, device=device).float()
    inv_freq = 1.0 / (theta ** (even_idx / rot_dim))
    positions = torch.arange(seq_len, device=device).float()
    # outer product position x inverse-frequency -> [seq_len, len(inv_freq)]
    angles = torch.einsum("i,j->ij", positions, inv_freq)
    table = torch.cat((angles, angles), dim=-1)
    return table.cos(), table.sin()


def _ref_apply_rotary(t, freqs_cos, freqs_sin):
    """Reference rotary embedding computed with plain tensor arithmetic.

    The leading ``freqs_cos.shape[-1]`` features of ``t`` are combined with
    the cosine table and, through ``_rotate_half``, with the sine table. Any
    remaining trailing features are appended back unchanged.
    """
    rot_dim = freqs_cos.shape[-1]
    rotary_part = t[..., :rot_dim]
    passthrough = t[..., rot_dim:]
    rotated = (rotary_part * freqs_cos) + (_rotate_half(rotary_part) * freqs_sin)
    if passthrough.shape[-1] == 0:
        # every feature was rotated, nothing to stitch back on
        return rotated
    return torch.cat((rotated, passthrough), dim=-1)


@pytest.mark.parametrize("seq_len", [1, 17, 128])
@pytest.mark.parametrize("dim", [32, 64, 128])
@pytest.mark.parametrize("rotary_dim", [None, 16, 32, 64])
def test_apply_rotary_pos_emb(seq_len, dim, rotary_dim):
    """Compare ``apply_rotary_pos_emb`` with the eager reference formula.

    The input is a CPU tensor laid out as ``[seq_len, heads, dim]``. When
    ``rotary_dim`` is ``None`` the whole feature axis is rotated; otherwise
    only the first ``rotary_dim`` features are. Combinations where the rotary
    size exceeds ``dim`` or is odd are skipped.
    """
    rot_dim = rotary_dim if rotary_dim is not None else dim
    if rot_dim > dim or rot_dim % 2 != 0:
        pytest.skip("rotary_dim must be <= dim and even")

    t = torch.randn(seq_len, 4, dim)
    freqs_cos, freqs_sin = _make_freqs(seq_len, rot_dim)
    # The tables come back as [seq_len, rot_dim]. Without a singleton heads
    # axis, broadcasting would line seq_len up with the heads axis of ``t``
    # and fail for any seq_len other than 1 or the head count.
    freqs_cos = freqs_cos[:, :rot_dim].unsqueeze(1)
    freqs_sin = freqs_sin[:, :rot_dim].unsqueeze(1)

    result = apply_rotary_pos_emb(t, freqs_cos, freqs_sin)
    expected = _ref_apply_rotary(t, freqs_cos, freqs_sin)

    assert result.shape == t.shape, f"output shape {result.shape} != {t.shape}"
    assert torch.allclose(result, expected, atol=1e-6), (
        f"seq_len={seq_len}, dim={dim}, rot_dim={rot_dim}: max diff={((result - expected).abs().max()).item()}"
    )


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_apply_rotary_pos_emb_grad_flow(dtype):
    """Check that gradients reach the input of ``apply_rotary_pos_emb``.

    A ``[seq_len, n_heads, dim]`` leaf tensor of the given ``dtype`` is
    rotated, the output is summed, and the resulting gradient must exist, be
    free of NaNs and match the input shape.
    """
    seq_len, n_heads, dim = 8, 4, 64
    rot_dim = 64
    t = torch.randn(seq_len, n_heads, dim, dtype=dtype, requires_grad=True)
    freqs_cos, freqs_sin = _make_freqs(seq_len, rot_dim)
    # Add a singleton heads axis so the [seq_len, rot_dim] tables broadcast
    # over heads; otherwise seq_len (8) would be matched against n_heads (4)
    # and the multiplication would raise a size-mismatch error.
    freqs_cos = freqs_cos[:, :rot_dim].unsqueeze(1)
    freqs_sin = freqs_sin[:, :rot_dim].unsqueeze(1)

    out = apply_rotary_pos_emb(t, freqs_cos, freqs_sin)
    loss = out.sum()
    loss.backward()

    assert t.grad is not None
    assert not torch.isnan(t.grad).any(), "NaNs in gradient"
    assert t.grad.shape == t.shape, f"grad shape {t.grad.shape} != {t.shape}"
Loading