Suggestion Description
We would like feature support for per-tensor activation + per-channel weight int8 GEMM.
The request stems from a CI failure in vLLM: quantization/test_quark.py::test_quark_int8_w_per_tensor_a_per_tensor.
For int8 GEMM, we follow the same logic as the CUTLASS integration, which expands the per-tensor scale of the fused QKV linear weight to per-channel; the shape changes from [3] to [out_channel]. The code is at https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py#L53.
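A minimal sketch of that expansion (not the actual vLLM code; the function name `expand_to_per_channel` and the `logical_widths` values are illustrative only): each per-tensor scale is repeated over the output channels of its sub-matrix, turning a [3]-element tensor into an [out_channel]-element tensor.

```python
import torch

def expand_to_per_channel(scales: torch.Tensor, logical_widths: list[int]) -> torch.Tensor:
    """Repeat each per-tensor scale over the output channels of its fused sub-matrix."""
    assert scales.numel() == len(logical_widths)
    return torch.cat([scales[i].repeat(w) for i, w in enumerate(logical_widths)])

# e.g. q/k/v output widths of 4096/1024/1024 channels (hypothetical sizes)
per_tensor = torch.tensor([0.02, 0.03, 0.04])            # [s_q, s_k, s_v], shape [3]
per_channel = expand_to_per_channel(per_tensor, [4096, 1024, 1024])
print(per_channel.shape)                                  # torch.Size([6144]) == [out_channel]
```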
For clarity, denote the per-tensor scales of the QKV linear weights as [s_q, s_k, s_v]. Without reshaping, we need three int8 GEMM kernels, denoted as [h * w_q * s_q, h * w_k * s_k, h * w_v * s_v].
With the reshape, the scale is expanded to

[s_q, ..., s_q, s_k, ..., s_k, s_v, ..., s_v]

where s_q is repeated for the number of output channels of q, s_k for the channels of k, and s_v for the channels of v.
With the requested per-tensor activation + per-channel weight int8 kernel, a single kernel call covers the whole computation, denoted as [h * W_qkv * scale_per_ch].
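A small numeric check of the argument above (a torch-only sketch, not the aiter kernel; the weight sizes and scale values are made up): one GEMM over the fused W_qkv with the expanded per-channel scale matches the three separate per-tensor scaled GEMMs. The per-tensor activation scale is omitted here since it multiplies both sides identically.

```python
import torch

torch.manual_seed(0)
# int8-valued operands, held in int32 so torch matmul works on CPU
h   = torch.randint(-128, 128, (8, 256),  dtype=torch.int32)
w_q = torch.randint(-128, 128, (256, 64), dtype=torch.int32)
w_k = torch.randint(-128, 128, (256, 32), dtype=torch.int32)
w_v = torch.randint(-128, 128, (256, 32), dtype=torch.int32)
s_q, s_k, s_v = 0.02, 0.03, 0.04

# Three kernels: [h * w_q * s_q, h * w_k * s_k, h * w_v * s_v]
ref = torch.cat([(h @ w_q) * s_q, (h @ w_k) * s_k, (h @ w_v) * s_v], dim=-1)

# One kernel: h * W_qkv, then scale each output channel by the expanded scale
w_qkv = torch.cat([w_q, w_k, w_v], dim=-1)
scale_per_ch = torch.cat([torch.full((w.shape[1],), s)
                          for w, s in [(w_q, s_q), (w_k, s_k), (w_v, s_v)]])
out = (h @ w_qkv) * scale_per_ch

assert torch.allclose(ref.float(), out.float())
```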
Current status
The requested case is explicitly rejected at https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py#L115.
Operating System
No response
GPU
MI325
ROCm Component
aiter