14 changes: 14 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,14 @@
import pytest
import torch


# This fixture ensures torch's default device and dtype are not left modified
# between tests (e.g., when a test fails before restoring the original values),
# which can cause subsequent tests to fail.
@pytest.fixture(autouse=True)
def reset_torch_defaults():
orig_default_device = torch.get_default_device()
orig_default_dtype = torch.get_default_dtype()
yield
torch.set_default_dtype(orig_default_dtype)
torch.set_default_device(orig_default_device)
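
To illustrate what this fixture guards against, here is a minimal sketch (not part of this PR) of a test that mutates a process-wide default; without the autouse fixture, a failure partway through would leave the changed default active for every later test:

```python
import torch


def test_matmul_in_float64():
    # Mutates a process-wide default; the autouse reset_torch_defaults fixture
    # restores the original dtype even if the assertion below fails.
    torch.set_default_dtype(torch.float64)
    x = torch.randn(4, 4)
    assert (x @ x).dtype == torch.float64
```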
27 changes: 27 additions & 0 deletions tests/spatial/test_greenctx_stream.py
@@ -0,0 +1,27 @@
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import create_greenctx_stream_by_value, get_sm_available
import utils

device = utils.get_device()

def test_green_ctx():
A = torch.randn(5120, 5120).to(device)
B = torch.randn(5120, 5120).to(device)
C = torch.matmul(A, B)
sm_counts = get_sm_available(0)
stream_group = create_greenctx_stream_by_value(sm_counts // 2, sm_counts // 2, 0)
with torch.Stream(stream_group[0]):
for _ in range(100):
result_0 = torch.matmul(A, B)
with torch.Stream(stream_group[1]):
for _ in range(100):
result_1 = torch.matmul(A, B)
torch.accelerator.synchronize()
assert torch.allclose(result_0, C)
assert torch.allclose(result_1, C)


if __name__ == "__main__":
pytest.main([__file__])
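
A possible extra sanity check, assuming `get_sm_available(0)` reports the number of streaming multiprocessors usable on CUDA device 0, would be to compare it against what PyTorch reports for the same device. A sketch, not part of this PR:

```python
import torch
from sgl_kernel import get_sm_available


def test_sm_count_within_device_limit():
    # Assumption: get_sm_available(0) returns the SM count available on CUDA device 0.
    sm_counts = get_sm_available(0)
    props = torch.cuda.get_device_properties(0)
    assert 0 < sm_counts <= props.multi_processor_count
```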
30 changes: 16 additions & 14 deletions tests/speculative/test_eagle_utils.py
@@ -1,44 +1,47 @@
import pytest
import torch
import torch.nn.functional as F
import utils
from sgl_kernel import verify_tree_greedy

device = utils.get_device()


def test_verify_tree_greedy():
candidates = torch.tensor(
[
[0, 1, 2, 3, 4, 5],
[7, 8, 9, 10, 11, 12],
],
dtype=torch.int32,
device="cuda",
dtype=torch.int64,
device=device,
)
retrive_index = torch.tensor(
[
[0, 1, 2, 3, 4, 5],
[6, 7, 8, 9, 10, 11],
],
dtype=torch.int32,
device="cuda",
dtype=torch.int64,
device=device,
)
retrive_next_token = torch.tensor(
[
[1, 2, -1, 4, 5, -1],
[4, 2, 3, -1, 5, -1],
],
dtype=torch.int32,
device="cuda",
dtype=torch.int64,
device=device,
)
retrive_next_sibling = torch.tensor(
[
[-1, 3, -1, -1, -1, -1],
[-1, -1, -1, -1, 1, -1],
],
dtype=torch.int32,
device="cuda",
dtype=torch.int64,
device=device,
)

target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda")
target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device=device)
target_logits[0, 0, 3] = 10
target_logits[0, 3, 4] = 10
target_logits[0, 4, 5] = 10
@@ -49,20 +52,19 @@ def test_verify_tree_greedy():
if torch.max(target_logits[i][j]) < 10:
target_logits[i][j][18] = 10

target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
target_predict = torch.argmax(target_logits, dim=-1)
predict_shape = (12,)

bs = candidates.shape[0]
num_spec_step = 4
num_draft_tokens = candidates.shape[1]

predicts = torch.full(
predict_shape, -1, dtype=torch.int32, device="cuda"
predict_shape, -1, dtype=torch.int32, device=device
) # mutable
accept_index = torch.full(
(bs, num_spec_step), -1, dtype=torch.int32, device="cuda"
(bs, num_spec_step), -1, dtype=torch.int32, device=device
) # mutable
accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda") # mutable
accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device=device) # mutable

verify_tree_greedy(
predicts=predicts,
78 changes: 78 additions & 0 deletions tests/speculative/test_ngram_utils.py
@@ -0,0 +1,78 @@
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import reconstruct_indices_from_tree_mask
import utils

device = utils.get_device()

def test_reconstruct_indices_from_tree_mask():
bs = 1
num_branch_token = 4
seq_lens = torch.tensor([12], device=device, dtype=torch.int64)

retrive_index = torch.full(
(bs, num_branch_token), -1, device=device, dtype=torch.int64
)
retrive_next_token = torch.full(
(bs, num_branch_token), -1, device=device, dtype=torch.int64
)
retrive_next_sibling = torch.full(
(bs, num_branch_token), -1, device=device, dtype=torch.int64
)
positions = torch.empty((bs * num_branch_token), device=device, dtype=torch.int64)

tree_mask = torch.tensor(
[
1,
0,
0,
0,
1,
1,
0,
0,
1,
0,
1,
0,
1,
0,
1,
1,
],
device=device,
dtype=torch.int32,
).to(torch.bool)

reconstruct_indices_from_tree_mask(
tree_mask,
seq_lens,
positions, # mutable
retrive_index, # mutable
retrive_next_token, # mutable
retrive_next_sibling, # mutable
bs,
num_branch_token,
)
# print(f"debug: \n\n{tree_mask=}, {retrive_index=}, {retrive_next_token=}, {retrive_next_sibling=}, {positions=}\n\n")
assert retrive_index.tolist() == [
[0, 1, 2, 3],
], f"{retrive_index=}"
assert retrive_next_token.tolist() == [
[1, -1, 3, -1],
], f"{retrive_next_token=}"
assert retrive_next_sibling.tolist() == [
[-1, 2, -1, -1],
], f"{retrive_next_sibling=}"
assert positions.tolist() == [
12,
13,
13,
14,
], f"{positions=}"


if __name__ == "__main__":
test_reconstruct_indices_from_tree_mask()
pytest.main([__file__])
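
Consistent with the values asserted above, the output arrays encode the speculative tree in first-child/next-sibling form: `retrive_next_token[i]` points at node i's first child and `retrive_next_sibling[i]` at its next sibling, with -1 meaning none. The hypothetical helper below (a standalone sketch, not part of this PR) recovers each node's children from those two arrays, which makes the expected values easier to read:

```python
def children_from_fcns(next_token, next_sibling):
    """Recover node -> children lists from first-child / next-sibling arrays."""
    children = {i: [] for i in range(len(next_token))}
    for node, first_child in enumerate(next_token):
        child = first_child
        while child != -1:
            children[node].append(child)
            child = next_sibling[child]
    return children


# Matches the tree asserted above: node 0 has children [1, 2], node 2 has [3].
assert children_from_fcns([1, -1, 3, -1], [-1, 2, -1, -1]) == {
    0: [1, 2],
    1: [],
    2: [3],
    3: [],
}
```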
15 changes: 10 additions & 5 deletions tests/speculative/test_speculative_sampling.py
@@ -1,8 +1,11 @@
import pytest
import torch
import torch.nn.functional as F
import utils
from sgl_kernel import tree_speculative_sampling_target_only

device = utils.get_device()

test_cases = [
(
1,
@@ -35,38 +38,38 @@ def test_tree_speculative_sampling_target_only(
"""
Tests the tree_speculative_sampling_target_only function using Pytest parameterization.
"""
device = "cuda"
device = device

candidates = torch.tensor(
[
[0, 1, 2, 3, 4, 5],
[7, 8, 9, 10, 11, 12],
],
dtype=torch.int32,
dtype=torch.int64,
device=device,
)
retrive_index = torch.tensor(
[
[0, 1, 2, 3, 4, 5],
[6, 7, 8, 9, 10, 11],
],
dtype=torch.int32,
dtype=torch.int64,
device=device,
)
retrive_next_token = torch.tensor(
[
[1, 2, -1, 4, 5, -1],
[4, 2, 3, -1, 5, -1],
],
dtype=torch.int32,
dtype=torch.int64,
device=device,
)
retrive_next_sibling = torch.tensor(
[
[-1, 3, -1, -1, -1, -1],
[-1, -1, -1, -1, 1, -1],
],
dtype=torch.int32,
dtype=torch.int64,
device=device,
)

@@ -95,6 +98,7 @@ def test_tree_speculative_sampling_target_only(
target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device)
coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32)
coins_for_final_sampling = torch.rand(bs, device=device).to(torch.float32)

tree_speculative_sampling_target_only(
predicts=predicts,
@@ -105,6 +109,7 @@
retrive_next_token=retrive_next_token,
retrive_next_sibling=retrive_next_sibling,
uniform_samples=coins,
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=threshold_single,
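
For context on the `target_probs` construction above (`F.softmax(target_logits / expanded_temperature, dim=-1)`): dividing logits by a temperature below 1 sharpens the distribution toward the arg-max, while values above 1 flatten it. A standalone illustration, not part of this PR:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.0])
for temperature in (0.5, 1.0, 2.0):
    probs = F.softmax(logits / temperature, dim=-1)
    # Lower temperature -> more probability mass on the largest logit.
    print(f"{temperature=} {probs.tolist()}")
```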