Description
MLX has a GEMM function, `mlx.core.addmm`. We can dispatch our `Gemm` Op to it as follows:
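```python
import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify  # assumed location of the MLX dispatcher
from pytensor.tensor.blas import Gemm


@mlx_funcify.register(Gemm)
def mlx_funcify_Gemm(op, **kwargs):
    # PyTensor's Gemm computes:
    #   b * z + a * dot(x, y)
    def gemm(z, a, x, y, b):
        # mx.addmm(c, a, b, alpha, beta) computes:
        #   alpha * (a @ b) + beta * c
        # so z -> c, x @ y -> a @ b, a -> alpha, b -> beta.
        return mx.addmm(z, x, y, alpha=a, beta=b)

    return gemm
```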
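To double-check the argument reordering without needing MLX installed, here is a small NumPy sketch. `gemm_reference`, `addmm_reference`, and `gemm_via_addmm` are hypothetical helpers written only for this check: the first encodes the `Gemm` formula from the comment above, the second encodes the formula `mx.addmm` documents (`alpha * (a @ b) + beta * c`), and the third applies the same reordering as `mlx_funcify_Gemm`. This is an illustration of the math, not MLX or PyTensor code.

```python
import numpy as np


def gemm_reference(z, a, x, y, b):
    """Hypothetical helper: PyTensor Gemm semantics, b * z + a * dot(x, y)."""
    return b * z + a * np.dot(x, y)


def addmm_reference(c, lhs, rhs, alpha=1.0, beta=1.0):
    """Hypothetical helper: mx.addmm semantics, alpha * (lhs @ rhs) + beta * c."""
    return alpha * (lhs @ rhs) + beta * c


def gemm_via_addmm(z, a, x, y, b):
    """Same reordering as mlx_funcify_Gemm: (z, a, x, y, b) -> addmm(z, x, y, alpha=a, beta=b)."""
    return addmm_reference(z, x, y, alpha=a, beta=b)


rng = np.random.default_rng(0)
z = rng.standard_normal((3, 2))
x = rng.standard_normal((3, 4))
y = rng.standard_normal((4, 2))
a, b = np.float64(0.5), np.float64(-2.0)  # Gemm's scalars arrive as 0-d values

# Both formulations should agree elementwise.
assert np.allclose(gemm_reference(z, a, x, y, b), gemm_via_addmm(z, a, x, y, b))
```

With `b = 0` the `z` term drops out and the call reduces to a scaled plain matrix product, `a * (x @ y)`.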
What's tricky is that the BLAS rewrite machinery is quite C-specific. I'm not sure whether we can register just the `dot22_to_gemm` rewrite for MLX.