[Feature]: Per tensor act + Per channel weight int8 gemm #1766

@ZhiweiYan-96

Description

Suggestion Description

We may want support for per-tensor activation + per-channel weight int8 GEMM.
The request stems from a CI failure in vLLM: quantization/test_quark.py::test_quark_int8_w_per_tensor_a_per_tensor.

For int8 GEMM, we follow the same logic as the CUTLASS integration, which expands the per-tensor scales of a fused QKV linear weight into per-channel scales; the scale shape changes from [3] to [out_channels]. The code is at https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py#L53.
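
For illustration, here is a minimal sketch of that expansion, assuming hypothetical output-channel counts (the actual vLLM implementation is at the link above):

```python
import torch

# Hypothetical output-channel counts for the fused QKV projection.
q_dim, k_dim, v_dim = 128, 64, 64

# One per-tensor scale per logical weight, shape [3]: [s_q, s_k, s_v].
logical_scales = torch.tensor([0.02, 0.015, 0.03])

# Expand to a per-channel scale of shape [out_channels] by repeating
# each logical scale across its partition's output channels.
per_channel_scale = torch.repeat_interleave(
    logical_scales, torch.tensor([q_dim, k_dim, v_dim])
)
assert per_channel_scale.shape == (q_dim + k_dim + v_dim,)
```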

For clarification, denote the per-tensor scales of the QKV linear weights as [s_q, s_k, s_v]. Without reshaping, we need three int8 GEMM kernels, denoted [h * w_q * s_q, h * w_k * s_k, h * w_v * s_v].

With the reshape, the scale is expanded to

[s_q, ..., s_q, s_k, ..., s_k, s_v, ..., s_v]

with s_q repeated once per output channel of q, s_k once per output channel of k, and s_v once per output channel of v.

Using the requested per-tensor activation + per-channel weight int8 kernel, a single kernel can finish the computation, denoted [Wqkv * scale_per_ch * h].
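
As a sanity check, the following sketch (illustrative shapes and scale values; the activation scale is omitted for brevity, as in the formulas above) verifies that one GEMM dequantized with the expanded per-channel scale matches the three per-tensor GEMMs:

```python
import torch

tokens, hidden = 4, 32
q_dim, k_dim, v_dim = 16, 8, 8
s_q, s_k, s_v = 0.02, 0.015, 0.03

h = torch.randint(-128, 128, (tokens, hidden), dtype=torch.int8)
w_qkv = torch.randint(-128, 128, (hidden, q_dim + k_dim + v_dim), dtype=torch.int8)

# Integer accumulation (as an int8 GEMM would do), then dequantize.
acc = (h.long() @ w_qkv.long()).float()

# Three per-tensor kernels: splitting the accumulator by partition is
# mathematically identical to running h @ w_q, h @ w_k, h @ w_v separately.
acc_q, acc_k, acc_v = acc.split([q_dim, k_dim, v_dim], dim=1)
three_kernels = torch.cat([acc_q * s_q, acc_k * s_k, acc_v * s_v], dim=1)

# One kernel with the expanded per-channel scale broadcast over columns.
per_channel_scale = torch.repeat_interleave(
    torch.tensor([s_q, s_k, s_v]), torch.tensor([q_dim, k_dim, v_dim])
)
one_kernel = acc * per_channel_scale

assert torch.allclose(three_kernels, one_kernel)
```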

Current status

The required case is currently explicitly rejected at https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py#L115.
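
Purely as a hypothetical sketch (not the actual aiter.py code at the link above), the rejected combination corresponds to a capability check of roughly this shape:

```python
def can_implement(per_tensor_activation: bool, per_channel_weight: bool) -> bool:
    # Hypothetical guard: per-tensor activation together with per-channel
    # weight scales is the combination this issue asks to support;
    # today it is rejected.
    if per_tensor_activation and per_channel_weight:
        return False
    return True
```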

Operating System

No response

GPU

MI325

ROCm Component

aiter
