Skip to content

Commit b9a8155

Browse files
Lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
1 parent 0a86b73 commit b9a8155

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

transformer_engine/jax/cpp_extensions/gemm.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,21 +1623,20 @@ def lowering(
16231623
scaling_mode=scaling_mode.value,
16241624
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
16251625
)
1626-
else:
1627-
ffi_name = GroupedGemmPrimitive.name
1628-
return jax.ffi.ffi_lowering(ffi_name)(
1629-
ctx,
1630-
*args,
1631-
M=M,
1632-
N=N,
1633-
K=K,
1634-
lhs_is_trans=lhs_is_trans,
1635-
rhs_is_trans=rhs_is_trans,
1636-
scaling_mode=scaling_mode.value,
1637-
has_bias=has_bias,
1638-
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
1639-
use_async_d2h_group_sizes=use_async_d2h_group_sizes,
1640-
)
1626+
ffi_name = GroupedGemmPrimitive.name
1627+
return jax.ffi.ffi_lowering(ffi_name)(
1628+
ctx,
1629+
*args,
1630+
M=M,
1631+
N=N,
1632+
K=K,
1633+
lhs_is_trans=lhs_is_trans,
1634+
rhs_is_trans=rhs_is_trans,
1635+
scaling_mode=scaling_mode.value,
1636+
has_bias=has_bias,
1637+
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
1638+
use_async_d2h_group_sizes=use_async_d2h_group_sizes,
1639+
)
16411640

16421641
@staticmethod
16431642
def impl(

transformer_engine/jax/csrc/extensions/gemm.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,7 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data,
559559
return std::move(grouped_tensor_wrapper);
560560
}
561561

562+
// This FFI is EXPERIMENTAL and subject to change without deprecation, intended for use in JAX's internal implementation of grouped GEMM.
562563
Error_Type GroupedGemmV2FFI(
563564
cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data,
564565
Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type alpha,

0 commit comments

Comments
 (0)