Skip to content

Commit f6b5928

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 9e37b4c commit f6b5928

File tree

3 files changed

+38
-41
lines changed

3 files changed

+38
-41
lines changed

transformer_engine/common/cast/nvfp4/core_nvfp4.cuh

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -90,23 +90,23 @@ __device__ __forceinline__ float compute_global_encode_scaling_factor(const floa
9090

9191
// Compute "correct" per-block encoding scaling factor
9292
template <typename SF_TYPE>
93-
__device__ __forceinline__ SF_TYPE
94-
compute_scaling_coefficient(const nvfp4_scale_t S_dec_block, const float S_enc) {
93+
__device__ __forceinline__ SF_TYPE compute_scaling_coefficient(const nvfp4_scale_t S_dec_block,
94+
const float S_enc) {
9595
NVTE_DEVICE_ERROR("Unsupported scaling-factor type. Only FP32 and BF16 are supported.");
9696
}
9797

9898
template <>
99-
__device__ __forceinline__ float compute_scaling_coefficient<float>(
100-
const nvfp4_scale_t S_dec_block, const float S_enc) {
99+
__device__ __forceinline__ float compute_scaling_coefficient<float>(const nvfp4_scale_t S_dec_block,
100+
const float S_enc) {
101101
const float S_dec = 1.0f / S_enc;
102102
const float scale_rcp =
103103
fminf(1.0f / (static_cast<float>(S_dec_block) * S_dec), detail::TypeExtrema<float>::max);
104104
return scale_rcp;
105105
}
106106

107107
template <>
108-
__device__ __forceinline__ bf16
109-
compute_scaling_coefficient<bf16>(const nvfp4_scale_t S_dec_block, const float S_enc) {
108+
__device__ __forceinline__ bf16 compute_scaling_coefficient<bf16>(const nvfp4_scale_t S_dec_block,
109+
const float S_enc) {
110110
const float scale_rcp =
111111
fminf(S_enc / (static_cast<float>(S_dec_block)), detail::TypeExtrema<bf16>::max);
112112
return static_cast<bf16>(scale_rcp);

transformer_engine/common/cast/nvfp4/specialized/group_quantize_transpose_nvfp4_tuned_1D.cuh

Lines changed: 23 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -77,10 +77,8 @@ static_assert(CHUNK_DIM_Y % TILE_DIM_Y == 0,
7777
static_assert(CHUNK_DIM_X % TILE_DIM_X == 0,
7878
"Chunk size X must be evenly divisible by the tile size X");
7979

80-
static_assert(TILE_DIM_Y % SCALE_DIM == 0,
81-
"Tile size Y must be evenly divisible by the scale dim");
82-
static_assert(TILE_DIM_X % SCALE_DIM == 0,
83-
"Tile size X must be evenly divisible by the scale dim");
80+
static_assert(TILE_DIM_Y % SCALE_DIM == 0, "Tile size Y must be evenly divisible by the scale dim");
81+
static_assert(TILE_DIM_X % SCALE_DIM == 0, "Tile size X must be evenly divisible by the scale dim");
8482

8583
constexpr int TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y;
8684
constexpr int TILES_X = CHUNK_DIM_X / TILE_DIM_X;
@@ -311,7 +309,7 @@ __device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_pt
311309
const float block_amax = get_amax_of_pair(thread_amax_2x);
312310

313311
const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
314-
const scaling_coeff_type SFcoefficient =
312+
const scaling_coeff_type SFcoefficient =
315313
core::compute_scaling_coefficient<scaling_coeff_type>(S_dec_b_fp8, S_enc_rowwise);
316314

317315
// Store scaling factors to SMEM buffer (R2S)
@@ -458,7 +456,8 @@ __device__ __forceinline__ JobDescriptor decode_job(
458456
job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset,
459457
static_cast<size_t>(ctaid_Y), first_logical_dim,
460458
last_logical_dim, offsets_ptr);
461-
job.rows = get_tensor_rows_num(job.tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
459+
job.rows =
460+
get_tensor_rows_num(job.tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
462461
job.cols = get_tensor_cols_num(job.tensor_id, shape_rep, last_logical_dim, last_dims_ptr);
463462
return job;
464463
}
@@ -691,10 +690,11 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
691690
size_t rowwise_scale_base_acc = 0;
692691
size_t colwise_scale_base_acc = 0;
693692

694-
for (size_t t = 0; t < num_tensors; ++t) {
695-
const size_t rows = get_tensor_rows_num(t, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
693+
for (size_t t = 0; t < num_tensors; ++t) {
694+
const size_t rows =
695+
get_tensor_rows_num(t, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
696696
const size_t cols = get_tensor_cols_num(t, shape_rep, last_logical_dim, last_dims_ptr);
697-
697+
698698
rowwise_scale_base_acc += rows * get_nvfp4_scale_stride(cols);
699699
colwise_scale_base_acc += cols * get_nvfp4_scale_stride(rows);
700700
rowwise_scale_base[t + 1] = rowwise_scale_base_acc;
@@ -808,18 +808,19 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
808808
const size_t scale_stride_t = get_nvfp4_scale_stride(rows);
809809

810810
const size_t amax_rowwise_idx = (amax_rowwise_numel > 1) ? current_job.tensor_id : 0;
811-
const float S_enc_rowwise = (amax_rowwise_ptr == nullptr || amax_rowwise_numel == 0)
812-
? 1.0f
813-
: core::compute_global_encode_scaling_factor(amax_rowwise_ptr[amax_rowwise_idx]);
811+
const float S_enc_rowwise =
812+
(amax_rowwise_ptr == nullptr || amax_rowwise_numel == 0)
813+
? 1.0f
814+
: core::compute_global_encode_scaling_factor(amax_rowwise_ptr[amax_rowwise_idx]);
814815
const size_t amax_colwise_idx = (amax_colwise_numel > 1) ? current_job.tensor_id : 0;
815-
const float S_enc_colwise = (amax_colwise_ptr == nullptr || amax_colwise_numel == 0)
816-
? S_enc_rowwise
817-
: core::compute_global_encode_scaling_factor(amax_colwise_ptr[amax_colwise_idx]);
816+
const float S_enc_colwise =
817+
(amax_colwise_ptr == nullptr || amax_colwise_numel == 0)
818+
? S_enc_rowwise
819+
: core::compute_global_encode_scaling_factor(amax_colwise_ptr[amax_colwise_idx]);
818820

819821
nvfp4_scale_t *const scales_rowwise = scales_ptr + rowwise_scale_base[current_job.tensor_id];
820-
nvfp4_scale_t *const scales_colwise = RETURN_TRANSPOSE
821-
? (scales_t_ptr + colwise_scale_base[current_job.tensor_id])
822-
: nullptr;
822+
nvfp4_scale_t *const scales_colwise =
823+
RETURN_TRANSPOSE ? (scales_t_ptr + colwise_scale_base[current_job.tensor_id]) : nullptr;
823824

824825
const CUtensorMap &tensor_map_input = g_tensor_maps_input[current_job.tensor_id];
825826
const CUtensorMap &tensor_map_output = g_tensor_maps_output[current_job.tensor_id];
@@ -1093,7 +1094,8 @@ inline void group_quantize_transpose(const GroupedTensor *input, const Tensor *n
10931094
const float *const amax_colwise_ptr =
10941095
reinterpret_cast<const float *>(output->columnwise_amax.dptr);
10951096
const size_t amax_rowwise_numel = output->amax.has_data() ? output->amax.numel() : 0;
1096-
const size_t amax_colwise_numel = output->columnwise_amax.has_data() ? output->columnwise_amax.numel() : 0;
1097+
const size_t amax_colwise_numel =
1098+
output->columnwise_amax.has_data() ? output->columnwise_amax.numel() : 0;
10971099

10981100
if (output->amax.has_data()) {
10991101
NVTE_CHECK(amax_rowwise_numel == 1 || amax_rowwise_numel == num_tensors,
@@ -1179,9 +1181,8 @@ inline void group_quantize_transpose(const GroupedTensor *input, const Tensor *n
11791181
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
11801182
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
11811183
kernel<<<grid, block_size, dshmem_size, stream>>>(
1182-
shape_rep,
1183-
num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr,
1184-
last_dims_ptr, scales_ptr, scales_t_ptr, noop_ptr, amax_rowwise_ptr,
1184+
shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr,
1185+
first_dims_ptr, last_dims_ptr, scales_ptr, scales_t_ptr, noop_ptr, amax_rowwise_ptr,
11851186
amax_colwise_ptr, amax_rowwise_numel, amax_colwise_numel, work_blocks_X,
11861187
work_blocks_Y, rng_state);
11871188
NVTE_CHECK_CUDA(cudaGetLastError());

transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh

Lines changed: 9 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -57,10 +57,8 @@ static_assert(TunableConfig::CHUNK_DIM_Y % TILE_DIM_Y == 0,
5757
static_assert(TunableConfig::CHUNK_DIM_X % TILE_DIM_X == 0,
5858
"Chunk size X must be evenly divisible by the tile size X");
5959

60-
static_assert(TILE_DIM_Y % SCALE_DIM == 0,
61-
"Tile size Y must be evenly divisible by the scale dim");
62-
static_assert(TILE_DIM_X % SCALE_DIM == 0,
63-
"Tile size X must be evenly divisible by the scale dim");
60+
static_assert(TILE_DIM_Y % SCALE_DIM == 0, "Tile size Y must be evenly divisible by the scale dim");
61+
static_assert(TILE_DIM_X % SCALE_DIM == 0, "Tile size X must be evenly divisible by the scale dim");
6462

6563
constexpr int TILES_Y = TunableConfig::CHUNK_DIM_Y / TILE_DIM_Y;
6664
constexpr int TILES_X = TunableConfig::CHUNK_DIM_X / TILE_DIM_X;
@@ -390,15 +388,13 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D
390388
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
391389

392390
// Compute the global encoding/decoding scaling factors for all S_dec_b
393-
const float S_enc_rowwise =
394-
(amax_rowwise_ptr == nullptr)
395-
? 1.0f
396-
: core::compute_global_encode_scaling_factor(*amax_rowwise_ptr);
397-
398-
const float S_enc_colwise =
399-
(amax_colwise_ptr == nullptr)
400-
? S_enc_rowwise
401-
: core::compute_global_encode_scaling_factor(*amax_colwise_ptr);
391+
const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
392+
? 1.0f
393+
: core::compute_global_encode_scaling_factor(*amax_rowwise_ptr);
394+
395+
const float S_enc_colwise = (amax_colwise_ptr == nullptr)
396+
? S_enc_rowwise
397+
: core::compute_global_encode_scaling_factor(*amax_colwise_ptr);
402398

403399
__shared__ uint64_t workID_mbar;
404400
__shared__ __uint128_t workID_response;

0 commit comments

Comments
 (0)