File tree Expand file tree Collapse file tree 1 file changed +12
-0
lines changed
openequivariance/openequivariance/jax/jvp Expand file tree Collapse file tree 1 file changed +12
-0
lines changed Original file line number Diff line number Diff line change @@ -161,6 +161,12 @@ def conv_fwd_jvp_abstract_eval(
161161
162162conv_fwd_jvp_p .def_impl (conv_fwd_jvp_impl )
163163conv_fwd_jvp_p .def_abstract_eval (conv_fwd_jvp_abstract_eval )
164+ mlir .register_lowering (
165+ conv_fwd_jvp_p , mlir .lower_fun (conv_fwd_jvp_impl , multiple_results = False ), platform = "cuda"
166+ )
167+ mlir .register_lowering (
168+ conv_fwd_jvp_p , mlir .lower_fun (conv_fwd_jvp_impl , multiple_results = False ), platform = "rocm"
169+ )
164170
165171
166172# ==============================================================================
@@ -285,6 +291,12 @@ def conv_bwd_jvp_abstract_eval(
285291
286292conv_bwd_jvp_p .def_impl (conv_bwd_jvp_impl )
287293conv_bwd_jvp_p .def_abstract_eval (conv_bwd_jvp_abstract_eval )
294+ mlir .register_lowering (
295+ conv_bwd_jvp_p , mlir .lower_fun (conv_bwd_jvp_impl , multiple_results = True ), platform = "cuda"
296+ )
297+ mlir .register_lowering (
298+ conv_bwd_jvp_p , mlir .lower_fun (conv_bwd_jvp_impl , multiple_results = True ), platform = "rocm"
299+ )
288300
289301
290302# ==============================================================================
You can’t perform that action at this time.
0 commit comments