47 changes: 36 additions & 11 deletions bitsandbytes/backends/triton/kernels_4bit.py
@@ -66,7 +66,8 @@ def quantize_fp4_blockwise_kernel(
 
     packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
     out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
-    out_mask = out_offsets < n_elements // 2
+    # Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
+    out_mask = out_offsets < (n_elements - n_elements // 2)
     tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask)
 
 
@@ -148,7 +149,8 @@ def quantize_nf4_blockwise_kernel(
 
     packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
     out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
-    out_mask = out_offsets < n_elements // 2
+    # Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
+    out_mask = out_offsets < (n_elements - n_elements // 2)
     tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask)
 
 
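Note on the new bound (a standalone sketch, not part of the diff): for non-negative n, `n - n // 2` and `(n + 1) // 2` are the same ceiling, but the former never evaluates `n + 1`, so it cannot wrap when `n` already sits at the maximum of its integer type. Below, numpy's int32 stands in for the kernel's 32-bit index arithmetic (an assumption made purely for illustration):

```python
import numpy as np

# The two expressions agree for every non-negative n: both are ceil(n / 2).
for n in range(10_000):
    assert n - n // 2 == (n + 1) // 2

# But in 32-bit arithmetic, (n + 1) wraps at INT32_MAX.
n = np.int32(2**31 - 1)
with np.errstate(over="ignore"):
    wrapped = (n + np.int32(1)) // np.int32(2)  # -2**30: a garbage mask bound
safe = n - n // np.int32(2)                     # 2**30: the correct ceiling
print(wrapped, safe)
```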
@@ -330,7 +332,14 @@ def dequant_nf4_body_util(a, offsets, absmax_ptr, n_elems, QUANT_BLOCK: tl.const
 # )
 @triton.jit
 def dequant_4bit_kernel(
-    a_ptr, c_ptr, quant_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
+    a_ptr,
+    c_ptr,
+    quant_ptr,
+    absmax_ptr,
+    num_paired_elements,
+    num_output_elements,
+    QUANT_BLOCK: tl.constexpr,
+    SPLIT_SIZE: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
     block_start = pid * SPLIT_SIZE
@@ -350,7 +359,7 @@ def dequant_4bit_kernel(
 
     out_block_start = pid * SPLIT_SIZE * 2
     offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
-    mask = offs < num_paired_elements * 2
+    mask = offs < num_output_elements
     tl.store(c_ptr + offs, out_dq, mask)
 
 
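Why the mask needs the true output count (a sketch with made-up numbers, not code from the PR): for an odd element count, the packed buffer holds ceil(n/2) bytes and each byte unpacks to two values, so `num_paired_elements * 2` over-counts by one and the old bound permitted a single out-of-bounds store:

```python
# Hypothetical sizes, chosen to hit the odd-n edge case.
n = 4097                   # elements in the original tensor
num_paired = -(n // -2)    # ceil(n / 2) = 2049 packed uint8 bytes
unpacked = num_paired * 2  # 4098 values produced by the kernel
print(unpacked - n)        # 1: the old bound `offs < num_paired * 2` let one
                           # store land past the output buffer, while
                           # `offs < num_output_elements` (out.numel() == 4097) does not.
```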
@@ -367,7 +376,13 @@ def dequant_4bit_kernel(
 # )
 @triton.jit
 def dequant_fp4_kernel(
-    a_ptr, c_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
+    a_ptr,
+    c_ptr,
+    absmax_ptr,
+    num_paired_elements,
+    num_output_elements,
+    QUANT_BLOCK: tl.constexpr,
+    SPLIT_SIZE: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
     block_start = pid * SPLIT_SIZE
@@ -386,7 +401,7 @@ def dequant_fp4_kernel(
 
     out_block_start = pid * SPLIT_SIZE * 2
     offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
-    mask = offs < num_paired_elements * 2
+    mask = offs < num_output_elements
     tl.store(c_ptr + offs, out_dq, mask)
 
 
@@ -403,7 +418,13 @@ def dequant_fp4_kernel(
 # )
 @triton.jit
 def dequant_nf4_kernel(
-    a_ptr, c_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
+    a_ptr,
+    c_ptr,
+    absmax_ptr,
+    num_paired_elements,
+    num_output_elements,
+    QUANT_BLOCK: tl.constexpr,
+    SPLIT_SIZE: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
     block_start = pid * SPLIT_SIZE
@@ -422,7 +443,7 @@ def dequant_nf4_kernel(
 
     out_block_start = pid * SPLIT_SIZE * 2
     offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
-    mask = offs < num_paired_elements * 2
+    mask = offs < num_output_elements
     tl.store(c_ptr + offs, out_dq, mask)
 
 
@@ -439,15 +460,16 @@ def dequantize_4bit_impl(
     # Elements are in uint8 format, so interleaved
     # so total amount of data is 2 * elem_count
     number_of_paired_elements = A.numel()
+    num_output_elements = out.numel()
     # we assume that split_size > quant_blocksize
 
     SPLIT_SIZE = 256
     # grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), )
     grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
     if quant_type == "fp4":
-        dequant_fp4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
+        dequant_fp4_kernel[grid](A, out, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE)
     else:
-        dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
+        dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE)
 
 
 def dequantize_4bit_impl_passing_code(
@@ -459,12 +481,15 @@ def dequantize_4bit_impl_passing_code(
     out: torch.Tensor,
 ) -> None:
     number_of_paired_elements = A.numel()
+    num_output_elements = out.numel()
     # we assume that split_size > quant_blocksize
 
     SPLIT_SIZE = 256
     # grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), )
     grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
-    dequant_4bit_kernel[grid](A, out, code, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
+    dequant_4bit_kernel[grid](
+        A, out, code, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE
+    )
 
 
 ######################### Fallback dequantization functions #########################
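A minimal round-trip check for the odd-length case, sketched against the public `bitsandbytes.functional` API (my own example, not a test from this PR; it assumes a device served by the Triton backend, `"xpu"` here, and the default 4-bit quantization entry points):

```python
import torch
import bitsandbytes.functional as F

# An odd element count exercises the ceil(n/2) packing path.
A = torch.randn(4097, device="xpu", dtype=torch.float16)

packed, state = F.quantize_4bit(A, blocksize=64, quant_type="nf4")
assert packed.numel() == A.numel() - A.numel() // 2  # ceil(n/2) packed bytes

out = F.dequantize_4bit(packed, state)
assert out.shape == A.shape  # the masked store must not spill past out.numel()
```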
Expand Down
3 changes: 2 additions & 1 deletion bitsandbytes/backends/triton/ops.py
@@ -82,7 +82,8 @@ def quantize_4bit(
     blocks = -(n // -(blocksize * 2))
 
     absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
-    out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
+    # Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
+    out = torch.empty((n - n // 2, 1), device=A.device, dtype=torch.uint8)
 
     with torch_accelerator_module.device(A.device):
         kernels_4bit.quantize_4bit_blockwise_triton(
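The allocation follows the same ceiling arithmetic as the kernel's store mask; a trivial standalone check:

```python
import torch

for n in (4096, 4097):
    out = torch.empty((n - n // 2, 1), dtype=torch.uint8)
    print(n, out.numel())  # 4096 -> 2048, 4097 -> 2049: always ceil(n / 2)
```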
15 changes: 8 additions & 7 deletions csrc/xpu_kernels.cpp
@@ -95,20 +95,21 @@ inline float dDequantizeNF4(unsigned char val) {
 
 template <typename T, int TILE_SIZE, int NUM_PER_TH, int DATA_TYPE>
 SYCL_EXTERNAL void kDequantizeBlockwise<T, TILE_SIZE, NUM_PER_TH, DATA_TYPE>::operator()(sycl::nd_item<1> item) const {
-    const int base_idx = item.get_group(0) * TILE_SIZE;
-    size_t local_idx = item.get_local_id(0) * NUM_PER_TH;
+    const int64_t base_idx = static_cast<int64_t>(item.get_group(0)) * TILE_SIZE;
+    int64_t local_idx = static_cast<int64_t>(item.get_local_id(0)) * NUM_PER_TH;
     float local_abs_max = -FLT_MAX;
-    int local_load_idx = 0;
-    int local_store_idx = 0;
+    int64_t local_load_idx = 0;
+    int64_t local_store_idx = 0;
 
     uint8_t qvals[NUM_PER_TH];
     T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)];
 
     if (DATA_TYPE > 0) {
-        local_load_idx = sycl::min(TILE_SIZE, (n + 1) / 2 - base_idx);
-        local_store_idx = sycl::min(TILE_SIZE * 2, n - base_idx * 2);
+        // Cast n to int64_t to avoid overflow for large n (same as CUDA)
+        local_load_idx = sycl::min(static_cast<int64_t>(TILE_SIZE), (static_cast<int64_t>(n) + 1) / 2 - base_idx);
+        local_store_idx = sycl::min(static_cast<int64_t>(TILE_SIZE * 2), static_cast<int64_t>(n) - base_idx * 2);
     } else {
-        local_load_idx = sycl::min(TILE_SIZE, n - base_idx);
+        local_load_idx = sycl::min(static_cast<int64_t>(TILE_SIZE), static_cast<int64_t>(n) - base_idx);
         local_store_idx = local_load_idx;
     }
 
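What the int64_t casts buy, sketched in Python with numpy's fixed-width scalars standing in for the C++ types (an illustration, not the kernel itself): in 32-bit arithmetic, `(n + 1) / 2` wraps when n is INT32_MAX, so the remaining-load extent for a tile near the end of the buffer goes hugely negative; in 64-bit it stays exact:

```python
import numpy as np

TILE_SIZE = 1024
n = np.int32(2**31 - 1)          # element count at INT32_MAX
base_idx = np.int32(1073740800)  # packed-byte offset of a tile near the end

# int32: n + 1 wraps to INT32_MIN, so the remaining-load count is garbage.
with np.errstate(over="ignore"):
    bad = min(np.int32(TILE_SIZE), (n + np.int32(1)) // np.int32(2) - base_idx)

# int64 (what the static_casts achieve): the count is exact.
good = min(np.int64(TILE_SIZE), (np.int64(n) + 1) // 2 - np.int64(base_idx))
print(bad, good)  # bad is a large negative number; good == 1024
```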
6 changes: 4 additions & 2 deletions csrc/xpu_ops.cpp
@@ -10,15 +10,17 @@ void dequantizeBlockwise(
     const int num_per_th = 4;
     const int tile_size = workgroup_size * num_per_th;
     if (DATA_TYPE > 0) {
-        const int workgroup_num = (n + tile_size * 2 - 1) / (tile_size * 2);
+        // Upcast to int64 to avoid overflow for large n (same as CUDA)
+        const int workgroup_num = (static_cast<int64_t>(n) + tile_size * 2 - 1) / (tile_size * 2);
         sycl::range<1> local_range{(size_t)workgroup_size};
         sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size};
         kDequantizeBlockwise<T, tile_size, num_per_th, DATA_TYPE> kfn(code, A, absmax, out, blocksize / 2, n);
         sycl_kernel_submit<decltype(kfn), 1, 32>(
             sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn
         );
     } else {
-        const int workgroup_num = (n + tile_size - 1) / tile_size;
+        // Upcast to int64 to avoid overflow for large n (same as CUDA)
+        const int workgroup_num = (static_cast<int64_t>(n) + tile_size - 1) / tile_size;
         sycl::range<1> local_range{(size_t)workgroup_size};
         sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size};
         kDequantizeBlockwise<T, tile_size, num_per_th, DATA_TYPE> kfn(code, A, absmax, out, blocksize, n);
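The grid-size math above has the same failure mode; a numpy sketch of the two ceil-divisions (illustrative values, int32 standing in for C `int`):

```python
import numpy as np

tile_size = np.int32(4096)  # e.g. workgroup_size * num_per_th
n = np.int32(2**31 - 4096)  # a large but representable element count

# int32: the numerator n + tile_size * 2 - 1 exceeds INT32_MAX and wraps.
with np.errstate(over="ignore"):
    bad = (n + tile_size * np.int32(2) - np.int32(1)) // (tile_size * np.int32(2))

# int64 upcast: the intended ceiling division.
good = (np.int64(n) + np.int64(tile_size) * 2 - 1) // (np.int64(tile_size) * 2)
print(bad, good)  # bad == -262144 (wrapped); good == 262144
```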