@@ -77,10 +77,8 @@ static_assert(CHUNK_DIM_Y % TILE_DIM_Y == 0,
7777static_assert(CHUNK_DIM_X % TILE_DIM_X == 0,
7878 "Chunk size X must be evenly divisible by the tile size X");
7979
80- static_assert(TILE_DIM_Y % SCALE_DIM == 0,
81- "Tile size Y must be evenly divisible by the scale dim");
82- static_assert(TILE_DIM_X % SCALE_DIM == 0,
83- "Tile size X must be evenly divisible by the scale dim");
80+ static_assert(TILE_DIM_Y % SCALE_DIM == 0, "Tile size Y must be evenly divisible by the scale dim");
81+ static_assert(TILE_DIM_X % SCALE_DIM == 0, "Tile size X must be evenly divisible by the scale dim");
8482
8583constexpr int TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y;
8684constexpr int TILES_X = CHUNK_DIM_X / TILE_DIM_X;
@@ -311,7 +309,7 @@ __device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_pt
311309 const float block_amax = get_amax_of_pair(thread_amax_2x);
312310
313311 const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
314- const scaling_coeff_type SFcoefficient =
312+ const scaling_coeff_type SFcoefficient =
315313 core::compute_scaling_coefficient<scaling_coeff_type>(S_dec_b_fp8, S_enc_rowwise);
316314
317315 // Store scaling factors to SMEM buffer (R2S)
@@ -458,7 +456,8 @@ __device__ __forceinline__ JobDescriptor decode_job(
458456 job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset,
459457 static_cast<size_t>(ctaid_Y), first_logical_dim,
460458 last_logical_dim, offsets_ptr);
461- job.rows = get_tensor_rows_num(job.tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
459+ job.rows =
460+ get_tensor_rows_num(job.tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
462461 job.cols = get_tensor_cols_num(job.tensor_id, shape_rep, last_logical_dim, last_dims_ptr);
463462 return job;
464463}
@@ -691,10 +690,11 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
691690 size_t rowwise_scale_base_acc = 0;
692691 size_t colwise_scale_base_acc = 0;
693692
694- for (size_t t = 0; t < num_tensors; ++t) {
695- const size_t rows = get_tensor_rows_num(t, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
693+ for (size_t t = 0; t < num_tensors; ++t) {
694+ const size_t rows =
695+ get_tensor_rows_num(t, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
696696 const size_t cols = get_tensor_cols_num(t, shape_rep, last_logical_dim, last_dims_ptr);
697-
697+
698698 rowwise_scale_base_acc += rows * get_nvfp4_scale_stride(cols);
699699 colwise_scale_base_acc += cols * get_nvfp4_scale_stride(rows);
700700 rowwise_scale_base[t + 1] = rowwise_scale_base_acc;
@@ -808,18 +808,19 @@ __global__ void __launch_bounds__(THREADS_NUM) group_quantize_transpose_nvfp4_tu
808808 const size_t scale_stride_t = get_nvfp4_scale_stride(rows);
809809
810810 const size_t amax_rowwise_idx = (amax_rowwise_numel > 1) ? current_job.tensor_id : 0;
811- const float S_enc_rowwise = (amax_rowwise_ptr == nullptr || amax_rowwise_numel == 0)
812- ? 1.0f
813- : core::compute_global_encode_scaling_factor(amax_rowwise_ptr[amax_rowwise_idx]);
811+ const float S_enc_rowwise =
812+ (amax_rowwise_ptr == nullptr || amax_rowwise_numel == 0)
813+ ? 1.0f
814+ : core::compute_global_encode_scaling_factor(amax_rowwise_ptr[amax_rowwise_idx]);
814815 const size_t amax_colwise_idx = (amax_colwise_numel > 1) ? current_job.tensor_id : 0;
815- const float S_enc_colwise = (amax_colwise_ptr == nullptr || amax_colwise_numel == 0)
816- ? S_enc_rowwise
817- : core::compute_global_encode_scaling_factor(amax_colwise_ptr[amax_colwise_idx]);
816+ const float S_enc_colwise =
817+ (amax_colwise_ptr == nullptr || amax_colwise_numel == 0)
818+ ? S_enc_rowwise
819+ : core::compute_global_encode_scaling_factor(amax_colwise_ptr[amax_colwise_idx]);
818820
819821 nvfp4_scale_t *const scales_rowwise = scales_ptr + rowwise_scale_base[current_job.tensor_id];
820- nvfp4_scale_t *const scales_colwise = RETURN_TRANSPOSE
821- ? (scales_t_ptr + colwise_scale_base[current_job.tensor_id])
822- : nullptr;
822+ nvfp4_scale_t *const scales_colwise =
823+ RETURN_TRANSPOSE ? (scales_t_ptr + colwise_scale_base[current_job.tensor_id]) : nullptr;
823824
824825 const CUtensorMap &tensor_map_input = g_tensor_maps_input[current_job.tensor_id];
825826 const CUtensorMap &tensor_map_output = g_tensor_maps_output[current_job.tensor_id];
@@ -1093,7 +1094,8 @@ inline void group_quantize_transpose(const GroupedTensor *input, const Tensor *n
10931094 const float *const amax_colwise_ptr =
10941095 reinterpret_cast<const float *>(output->columnwise_amax.dptr);
10951096 const size_t amax_rowwise_numel = output->amax.has_data() ? output->amax.numel() : 0;
1096- const size_t amax_colwise_numel = output->columnwise_amax.has_data() ? output->columnwise_amax.numel() : 0;
1097+ const size_t amax_colwise_numel =
1098+ output->columnwise_amax.has_data() ? output->columnwise_amax.numel() : 0;
10971099
10981100 if (output->amax.has_data()) {
10991101 NVTE_CHECK(amax_rowwise_numel == 1 || amax_rowwise_numel == num_tensors,
@@ -1179,9 +1181,8 @@ inline void group_quantize_transpose(const GroupedTensor *input, const Tensor *n
11791181 NVTE_CHECK_CUDA(cudaFuncSetAttribute(
11801182 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
11811183 kernel<<<grid, block_size, dshmem_size, stream>>>(
1182- shape_rep,
1183- num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr,
1184- last_dims_ptr, scales_ptr, scales_t_ptr, noop_ptr, amax_rowwise_ptr,
1184+ shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr,
1185+ first_dims_ptr, last_dims_ptr, scales_ptr, scales_t_ptr, noop_ptr, amax_rowwise_ptr,
11851186 amax_colwise_ptr, amax_rowwise_numel, amax_colwise_numel, work_blocks_X,
11861187 work_blocks_Y, rng_state);
11871188 NVTE_CHECK_CUDA(cudaGetLastError());
0 commit comments