Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 1 addition & 20 deletions csrc/selective_scan/selective_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,8 @@

#pragma once

////////////////////////////////////////////////////////////////////////////////////////////////////

// Parameter bundle for a scan over three batched buffers (a, b, out) plus an
// auxiliary buffer x. Plain aggregate: no constructors, no invariants enforced.
struct SSMScanParamsBase {
    // 64-bit unsigned type used for the per-batch strides below.
    using index_t = uint64_t;

    // Problem dimensions.
    int batch;
    int seqlen;
    int n_chunks;

    // Per-batch strides for the a / b / out buffers.
    // NOTE(review): presumably counted in elements rather than bytes -- confirm
    // against the code that fills this struct.
    index_t a_batch_stride;
    index_t b_batch_stride;
    index_t out_batch_stride;

    // Common data pointers. Untyped here; the element type is not visible in
    // this header. Each is declared __restrict__.
    void *__restrict__ a_ptr;
    void *__restrict__ b_ptr;
    void *__restrict__ out_ptr;
    void *__restrict__ x_ptr;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct SSMParamsBase {
using index_t = uint32_t;
using index_t = uint64_t;

int batch, dim, seqlen, dstate, n_groups, n_chunks;
int dim_ngroups_ratio;
Expand Down
2 changes: 1 addition & 1 deletion csrc/selective_scan/selective_scan_bwd_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
scan_t *x = params.x_ptr == nullptr
? nullptr
: reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
: reinterpret_cast<scan_t *>(params.x_ptr) + (int64_t)(batch_id * params.dim + dim_id) * params.n_chunks * params.dstate;
float dD_val = 0;
float ddelta_bias_val = 0;

Expand Down
6 changes: 3 additions & 3 deletions csrc/selective_scan/selective_scan_fwd_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,16 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const int batch_id = blockIdx.x;
const int dim_id = blockIdx.y;
const int group_id = dim_id / (params.dim_ngroups_ratio);
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + (int64_t)batch_id * params.u_batch_stride
+ (int64_t)dim_id * kNRows * params.u_d_stride;
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
+ dim_id * kNRows * params.u_d_stride;
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
+ dim_id * kNRows * params.delta_d_stride;
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (int64_t)(batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;

float D_val[kNRows] = {0};
if (params.D_ptr != nullptr) {
Expand Down