Skip to content

Commit 4fd738f

Browse files
committed
register lowering for conv_fwd_jvp_p and conv_bwd_jvp_p
1 parent 64ef9f8 commit 4fd738f

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

openequivariance/openequivariance/jax/jvp/conv_prim.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,12 @@ def conv_fwd_jvp_abstract_eval(
161161

162162
conv_fwd_jvp_p.def_impl(conv_fwd_jvp_impl)
163163
conv_fwd_jvp_p.def_abstract_eval(conv_fwd_jvp_abstract_eval)
164+
mlir.register_lowering(
165+
conv_fwd_jvp_p, mlir.lower_fun(conv_fwd_jvp_impl, multiple_results=False), platform="cuda"
166+
)
167+
mlir.register_lowering(
168+
conv_fwd_jvp_p, mlir.lower_fun(conv_fwd_jvp_impl, multiple_results=False), platform="rocm"
169+
)
164170

165171

166172
# ==============================================================================
@@ -285,6 +291,12 @@ def conv_bwd_jvp_abstract_eval(
285291

286292
conv_bwd_jvp_p.def_impl(conv_bwd_jvp_impl)
287293
conv_bwd_jvp_p.def_abstract_eval(conv_bwd_jvp_abstract_eval)
294+
mlir.register_lowering(
295+
conv_bwd_jvp_p, mlir.lower_fun(conv_bwd_jvp_impl, multiple_results=True), platform="cuda"
296+
)
297+
mlir.register_lowering(
298+
conv_bwd_jvp_p, mlir.lower_fun(conv_bwd_jvp_impl, multiple_results=True), platform="rocm"
299+
)
288300

289301

290302
# ==============================================================================

0 commit comments

Comments
 (0)