Skip to content

Commit 29b84c1

Browse files
[Common] Fix NVFP4 tuned-kernel numerics (#2639)
* Fixed scaling-factor computation for FP32 to match the reference implementation.

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Uncommented the tuned kernel path

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 94ba75d commit 29b84c1

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1168,10 +1168,10 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
11681168
// TODO(Frank): Is there a better way to do this?
11691169
bool return_transpose = output->has_columnwise_data();
11701170

1171-
// if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
1172-
// quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
1173-
// return;
1174-
// }
1171+
if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
1172+
quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
1173+
return;
1174+
}
11751175

11761176
constexpr bool COMPUTE_ACTIVATIONS = false;
11771177
using ParamOP = Empty;

transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -163,9 +163,24 @@ __device__ __forceinline__ float get_amax_of_pair(const IType2 pair) {
163163
template <typename SF_TYPE>
164164
__device__ __forceinline__ SF_TYPE
165165
compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const float S_enc) {
166-
constexpr float float_max = detail::TypeExtrema<SF_TYPE>::max;
167-
const float scale_rcp = fminf(S_enc / static_cast<float>(S_dec_block), float_max);
168-
return static_cast<SF_TYPE>(scale_rcp);
166+
NVTE_DEVICE_ERROR("Unsupported scaling-factor type. Only FP32 and BF16 are supported.");
167+
}
168+
169+
template <>
170+
__device__ __forceinline__ float compute_nvfp4_scaling_coefficient<float>(
171+
const nvfp4_scale_t S_dec_block, const float S_enc) {
172+
const float S_dec = 1.0f / S_enc;
173+
const float scale_rcp =
174+
fminf(1.0f / (static_cast<float>(S_dec_block) * S_dec), detail::TypeExtrema<float>::max);
175+
return scale_rcp;
176+
}
177+
178+
template <>
179+
__device__ __forceinline__ bf16
180+
compute_nvfp4_scaling_coefficient<bf16>(const nvfp4_scale_t S_dec_block, const float S_enc) {
181+
const float scale_rcp =
182+
fminf(S_enc / (static_cast<float>(S_dec_block)), detail::TypeExtrema<bf16>::max);
183+
return static_cast<bf16>(scale_rcp);
169184
}
170185

171186
template <bool USE_STOCHASTIC_ROUNDING, bool USE_FAST_MATH>

0 commit comments

Comments
 (0)