diff --git a/torchao/prototype/blockwise_fp8_training/kernels.py b/torchao/prototype/blockwise_fp8_training/kernels.py index bc1491421e..fab5c0e4f7 100644 --- a/torchao/prototype/blockwise_fp8_training/kernels.py +++ b/torchao/prototype/blockwise_fp8_training/kernels.py @@ -318,7 +318,7 @@ def triton_fp8_blockwise_act_quant_lhs_kernel( k_offs = pid_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1 x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) - x = tl.load(x_ptr + x_offs, mask=x_mask) + x = tl.load(x_ptr + x_offs, mask=x_mask, other=0.0) # Scales for (1 x block_size) groups, shape will be (NUM_GROUPS, 1) amax = tl.clamp(tl.max(tl.abs(x), axis=1), min=EPS, max=float("inf")).to(tl.float64) @@ -333,7 +333,8 @@ def triton_fp8_blockwise_act_quant_lhs_kernel( # Write reciprocal scales scale_offs = m_offs[:, None] * s_stride_dim_0 + pid_k * s_stride_dim_1 - tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale)) + scale_mask = m_offs[:, None] < M + tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask) @triton_op("torchao::triton_fp8_blockwise_act_quant_lhs", mutates_args={}) @@ -412,7 +413,7 @@ def triton_fp8_blockwise_act_quant_rhs_kernel( k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS) x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1 x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) - x = tl.load(x_ptr + x_offs, mask=x_mask) + x = tl.load(x_ptr + x_offs, mask=x_mask, other=0.0) # Column-wise scales for RHS operand, shape (1, block_size) amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, max=float("inf")).to(tl.float64) @@ -427,7 +428,8 @@ def triton_fp8_blockwise_act_quant_rhs_kernel( # Write scales scale_offs = pid_m * s_stride_dim_0 + k_offs[None, :] * s_stride_dim_1 - tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale)) + scale_mask = k_offs[None, :] < K + tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask) @triton_op("torchao::triton_fp8_blockwise_act_quant_rhs", mutates_args={})