Skip to content

Commit 57b97d3

Browse files
SigureMo and Copilot authored
[Cherry-Pick][Optimization] Use a separate driver when using Triton with Paddle (#6897) (#7114)
--------- Co-authored-by: Copilot <[email protected]>
1 parent 6051d12 commit 57b97d3

File tree

11 files changed

+87
-2
lines changed

11 files changed

+87
-2
lines changed

fastdeploy/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,8 @@
1717
import os
1818
import uuid
1919

20+
import paddle
21+
2022
# suppress warning log from paddlepaddle
2123
os.environ["GLOG_minloglevel"] = "2"
2224
# suppress log from aistudio
@@ -44,7 +46,13 @@
4446

4547
from fastdeploy.engine.sampling_params import SamplingParams
4648
from fastdeploy.entrypoints.llm import LLM
47-
from fastdeploy.utils import current_package_version, envs
49+
from fastdeploy.utils import _is_package_installed, current_package_version, envs
50+
51+
# We can use enable_compat only when torch is not installed, otherwise it will
52+
# cause some unexpected issues in triton kernels. We use enable_compat_on_triton_kernel
53+
# for these cases.
54+
if not _is_package_installed("torch"):
55+
paddle.enable_compat(scope={"triton"})
4856

4957
if envs.FD_DEBUG != 1:
5058
import logging

fastdeploy/model_executor/guided_decoding/kernels/xgrammar_apply_token_bitmask.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,10 +23,15 @@
2323
try:
2424
import triton
2525
import triton.language as tl
26+
27+
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
28+
enable_compat_on_triton_kernel,
29+
)
2630
except ImportError as err:
2731
raise ImportError("Triton is not installed") from err
2832

2933

34+
@enable_compat_on_triton_kernel
3035
@triton.jit
3136
def apply_token_bitmask_inplace_kernel(
3237
logits_ptr,

fastdeploy/model_executor/layers/backends/dcu/triton_moe_kernels.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,12 @@
1717
import triton
1818
import triton.language as tl
1919

20+
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
21+
enable_compat_on_triton_kernel,
22+
)
2023

24+
25+
@enable_compat_on_triton_kernel
2126
@triton.jit
2227
def fused_moe_kernel_paddle(
2328
a_ptr,

fastdeploy/model_executor/layers/backends/metax/moe/triton_moe_kernels.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,12 @@
1515
import triton
1616
import triton.language as tl
1717

18+
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
19+
enable_compat_on_triton_kernel,
20+
)
1821

22+
23+
@enable_compat_on_triton_kernel
1924
@triton.jit()
2025
def fused_moe_kernel_paddle(
2126
a_ptr,

fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,10 @@
1010
import triton
1111
import triton.language as tl
1212

13+
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
14+
enable_compat_on_triton_kernel,
15+
)
16+
1317
__all__ = [
1418
"set_batch_invariant_mode",
1519
"is_batch_invariant_mode_enabled",
@@ -33,6 +37,7 @@ def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any, args: Dict[st
3337
return ret
3438

3539

40+
@enable_compat_on_triton_kernel
3641
@triton.jit
3742
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
3843
group_id = tile_id // num_pid_in_group
@@ -43,6 +48,7 @@ def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
4348
return pid_m, pid_n
4449

4550

51+
@enable_compat_on_triton_kernel
4652
@triton.jit(launch_metadata=_matmul_launch_metadata)
4753
def matmul_kernel_persistent(
4854
a_ptr,
@@ -220,6 +226,7 @@ def grid(META):
220226
return c
221227

222228

229+
@enable_compat_on_triton_kernel
223230
@triton.jit
224231
def _log_softmax_kernel(
225232
input_ptr,
@@ -324,6 +331,7 @@ def log_softmax(input: paddle.Tensor, axis: int = -1) -> paddle.Tensor:
324331
return output.reshape(original_shape)
325332

326333

334+
@enable_compat_on_triton_kernel
327335
@triton.jit
328336
def mean_kernel(
329337
input_ptr,

fastdeploy/model_executor/layers/moe/routing_indices_cache.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,8 +36,12 @@
3636
from paddleformers.utils.log import logger
3737

3838
from fastdeploy.config import FDConfig, RoutingReplayConfig
39+
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
40+
enable_compat_on_triton_kernel,
41+
)
3942

4043

44+
@enable_compat_on_triton_kernel
4145
@triton.jit
4246
def _save_routing_kernel(
4347
ROUTING_REPLAY_TABLE_PTR,

fastdeploy/model_executor/layers/moe/triton_moe_kernels.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,12 @@
1717
import triton
1818
import triton.language as tl
1919

20+
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
21+
enable_compat_on_triton_kernel,
22+
)
2023

24+
25+
@enable_compat_on_triton_kernel
2126
@triton.jit()
2227
def fused_moe_kernel_paddle(
2328
a_ptr,

fastdeploy/model_executor/layers/sample/logprobs.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,13 @@
1818
import triton
1919
import triton.language as tl
2020

21+
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
22+
enable_compat_on_triton_kernel,
23+
)
2124
from fastdeploy.platforms import current_platform
2225

2326

27+
@enable_compat_on_triton_kernel
2428
@triton.jit
2529
def count_greater_kernel(
2630
x_ptr, # [num_tokens, n_elements]

fastdeploy/model_executor/ops/triton_ops/repetition_early_stop_kernel.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,12 @@
1717
import triton
1818
import triton.language as tl
1919

20+
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
21+
enable_compat_on_triton_kernel,
22+
)
2023

24+
25+
@enable_compat_on_triton_kernel
2126
@triton.jit
2227
def repetition_early_stopper_kernel(
2328
trunc_ptr, # float32[B, W]

fastdeploy/model_executor/ops/triton_ops/triton_utils.py

Lines changed: 27 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,19 +24,45 @@
2424
from paddle.base.framework import OpProtoHolder
2525

2626
from fastdeploy import envs
27+
from fastdeploy.utils import _is_package_installed
2728

2829
compile_file = triton.__path__[0] + "/tools/compile.py"
2930
link_file = triton.__path__[0] + "/tools/link.py"
3031
python_path = sys.executable
3132

33+
if _is_package_installed("torch"):
34+
with paddle.use_compat_guard(enable=True, silent=True):
35+
from triton.runtime.driver import _create_driver
36+
37+
paddle_driver = _create_driver()
38+
39+
40+
def swap_driver_guard(fn):
41+
from triton.runtime.driver import driver
42+
43+
# A lightweight wrapper to enable compatibility for triton kernel
44+
def wrapped_fn(*args, **kwargs):
45+
driver.set_active(paddle_driver)
46+
try:
47+
return fn(*args, **kwargs)
48+
finally:
49+
driver.reset_active()
50+
51+
return wrapped_fn
52+
3253

3354
def enable_compat_on_triton_kernel(triton_kernel):
55+
# When torch is not installed, this decorator does not do anything, just return the original triton kernel.
56+
# Because the `paddle.enable_compat(scope={"triton"})` already enabled in `__init__.py`, it will take zero runtime overhead.
57+
if not _is_package_installed("torch"):
58+
return triton_kernel
59+
3460
class WrappedTritonKernel:
3561
def __init__(self, kernel):
3662
self.kernel = kernel
3763

3864
def __getitem__(self, index):
39-
return paddle.use_compat_guard(enable=True, silent=True)(self.kernel[index])
65+
return swap_driver_guard(self.kernel[index])
4066

4167
return WrappedTritonKernel(triton_kernel)
4268

0 commit comments

Comments (0)