Skip to content

Commit 97ec071

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 47be9b2 commit 97ec071

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh

Lines changed: 20 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -459,7 +459,8 @@ __device__ __forceinline__ size_t get_grouped_scale_base_offset(
459459
const int64_t *const __restrict__ last_dims_ptr, const bool rowwise) {
460460
size_t scale_base = 0;
461461
for (size_t t = 0; t < tensor_id; ++t) {
462-
const size_t rows = get_tensor_rows_num(t, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
462+
const size_t rows =
463+
get_tensor_rows_num(t, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
463464
const size_t cols = get_tensor_cols_num(t, shape_rep, last_logical_dim, last_dims_ptr);
464465

465466
const size_t scale_rows = rowwise ? rows : cols;
@@ -500,7 +501,8 @@ __device__ __forceinline__ JobDescriptor decode_job(
500501
job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset,
501502
static_cast<size_t>(ctaid_Y), first_logical_dim,
502503
last_logical_dim, offsets_ptr);
503-
job.rows = get_tensor_rows_num(job.tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
504+
job.rows =
505+
get_tensor_rows_num(job.tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
504506
job.cols = get_tensor_cols_num(job.tensor_id, shape_rep, last_logical_dim, last_dims_ptr);
505507
return job;
506508
}
@@ -720,12 +722,14 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
720722

721723
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
722724

723-
const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
724-
? 1.0f
725-
: core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
726-
const float S_enc_colwise = (amax_colwise_ptr == nullptr)
727-
? S_enc_rowwise
728-
: core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
725+
const float S_enc_rowwise =
726+
(amax_rowwise_ptr == nullptr)
727+
? 1.0f
728+
: core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
729+
const float S_enc_colwise =
730+
(amax_colwise_ptr == nullptr)
731+
? S_enc_rowwise
732+
: core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
729733

730734
__shared__ size_t rowwise_scale_base[MAX_SUPPORTED_TENSOR_DESCRIPTORS + 1];
731735
__shared__ size_t colwise_scale_base[MAX_SUPPORTED_TENSOR_DESCRIPTORS + 1];
@@ -739,10 +743,11 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
739743
size_t rowwise_scale_base_acc = 0;
740744
size_t colwise_scale_base_acc = 0;
741745

742-
for (size_t t = 0; t < num_tensors; ++t) {
743-
const size_t rows = get_tensor_rows_num(t, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
746+
for (size_t t = 0; t < num_tensors; ++t) {
747+
const size_t rows =
748+
get_tensor_rows_num(t, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
744749
const size_t cols = get_tensor_cols_num(t, shape_rep, last_logical_dim, last_dims_ptr);
745-
750+
746751
rowwise_scale_base_acc += rows * get_nvfp4_scale_stride(cols);
747752
colwise_scale_base_acc += cols * get_nvfp4_scale_stride(rows);
748753
rowwise_scale_base[t + 1] = rowwise_scale_base_acc;
@@ -856,9 +861,8 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
856861
const size_t scale_stride_t = get_nvfp4_scale_stride(rows);
857862

858863
nvfp4_scale_t *const scales_rowwise = scales_ptr + rowwise_scale_base[current_job.tensor_id];
859-
nvfp4_scale_t *const scales_colwise = RETURN_TRANSPOSE
860-
? (scales_t_ptr + colwise_scale_base[current_job.tensor_id])
861-
: nullptr;
864+
nvfp4_scale_t *const scales_colwise =
865+
RETURN_TRANSPOSE ? (scales_t_ptr + colwise_scale_base[current_job.tensor_id]) : nullptr;
862866

863867
const CUtensorMap &tensor_map_input = g_tensor_maps_input[current_job.tensor_id];
864868
const CUtensorMap &tensor_map_output = g_tensor_maps_output[current_job.tensor_id];
@@ -1205,9 +1209,8 @@ inline void group_quantize_transpose(const GroupedTensor *input, const Tensor *n
12051209
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
12061210
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
12071211
kernel<<<grid, block_size, dshmem_size, stream>>>(
1208-
shape_rep,
1209-
num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr,
1210-
last_dims_ptr, scales_ptr, scales_t_ptr, noop_ptr, amax_rowwise_ptr,
1212+
shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr,
1213+
first_dims_ptr, last_dims_ptr, scales_ptr, scales_t_ptr, noop_ptr, amax_rowwise_ptr,
12111214
amax_colwise_ptr, work_blocks_X, work_blocks_Y, rng_state);
12121215
NVTE_CHECK_CUDA(cudaGetLastError());
12131216
});););

0 commit comments

Comments (0)