Skip to content

Commit 2d0a628

Browse files
kshyatt authored and Katharine Hyatt committed
Add Enzyme rules
1 parent a9b9a4b commit 2d0a628

File tree

5 files changed

+305
-2
lines changed

5 files changed

+305
-2
lines changed

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2424
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2525
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2626
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
27+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
2728

2829
[extensions]
2930
TensorOperationsBumperExt = "Bumper"
3031
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
3132
TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]
3233
TensorOperationsMooncakeExt = "Mooncake"
34+
TensorOperationsEnzymeExt = ["Enzyme", "ChainRulesCore"]
3335

3436
[compat]
3537
Aqua = "0.6, 0.7, 0.8"
@@ -38,6 +40,8 @@ CUDA = "5"
3840
ChainRulesCore = "1"
3941
ChainRulesTestUtils = "1"
4042
DynamicPolynomials = "0.5, 0.6"
43+
Enzyme = "0.13.115"
44+
EnzymeTestUtils = "0.2"
4145
LRUCache = "1"
4246
LinearAlgebra = "1.6"
4347
Logging = "1.6"
@@ -59,13 +63,16 @@ julia = "1.10"
5963
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
6064
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
6165
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
66+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6267
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
6368
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
69+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
70+
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
6471
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
6572
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
6673
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6774
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6875
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
6976

7077
[targets]
71-
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake"]
78+
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "ChainRulesCore", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake", "Enzyme", "EnzymeTestUtils"]
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
module TensorOperationsEnzymeExt
2+
3+
using TensorOperations
4+
using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator
5+
using VectorInterface
6+
using TupleTools
7+
using Enzyme, ChainRulesCore
8+
using Enzyme.EnzymeCore
9+
using Enzyme.EnzymeCore: EnzymeRules
10+
11+
# Calls to `tensorfree!` are declared inactive for Enzyme.
@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorfree!), ::Any) = true

# `tensoralloc` (four positional arguments) gets its Enzyme rule imported from the
# ChainRulesCore `rrule`.
Enzyme.@import_rrule(typeof(TensorOperations.tensoralloc), Any, Any, Any, Any)

# Types that Enzyme should treat as inactive: backends, allocators and index tuples.
@inline EnzymeRules.inactive_type(::Type{<:AbstractBackend}) = true
@inline EnzymeRules.inactive_type(::Type{DefaultAllocator}) = true
@inline EnzymeRules.inactive_type(::Type{<:CUDAAllocator}) = true
@inline EnzymeRules.inactive_type(::Type{ManualAllocator}) = true
@inline EnzymeRules.inactive_type(::Type{<:Index2Tuple}) = true
19+
20+
# Forward sweep of the reverse-mode Enzyme rule for `tensorcontract!`.
#
# Calls `TensorOperations.tensorcontract!` on the primal values and returns an
# `AugmentedReturn(primal, shadow, tape)` where
# - `primal` is `C_dC.val` when `needs_primal(config)` holds, otherwise `nothing`;
# - `shadow` is `C_dC.dval` when `needs_shadow(config)` holds, otherwise `nothing`;
# - `tape` is `(cache_A, cache_B, cache_C)`, consumed by the matching `reverse` rule.
#
# The index permutations and conjugation flags must be `Const`; any trailing arguments
# (`ba_dba`, e.g. backend / allocator) must be `Const` as well and are forwarded unchanged.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(TensorOperations.tensorcontract!)},
        ::Type{RT},
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        pA_dpA::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        B_dB::Annotation{<:AbstractArray{TB}},
        pB_dpB::Const{<:Index2Tuple},
        conjB_dconjB::Const{Bool},
        pAB_dpAB::Const{<:Index2Tuple},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
    # `overwritten(config)` has one entry per argument, counting the function itself as
    # entry 1, so `A` is entry 3 and `B` is entry 6.
    may_be_overwritten = EnzymeRules.overwritten(config)
    # The reverse rule reads the primal values of `A` and `B` whatever their activity, so
    # a copy is taped whenever Enzyme reports a possible overwrite -- also for `Const`
    # inputs, whose values would otherwise be stale by the time the pullback runs.
    cache_A = may_be_overwritten[3] ? copy(A_dA.val) : nothing
    cache_B = may_be_overwritten[6] ? copy(B_dB.val) : nothing
    # Snapshot `C` before the in-place call below; the reverse rule hands this copy to the
    # pullback. NOTE(review): the original author asked whether this copy is required when
    # the primal is not needed -- left unconditional here.
    cache_C = copy(C_dC.val)
    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
    TensorOperations.tensorcontract!(
        C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val,
        B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val,
        α_dα.val, β_dβ.val, ba...,
    )
    primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing
    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B, cache_C))
end
50+
51+
# Reverse sweep of the Enzyme rule for `tensorcontract!`.
#
# Unpacks the tape `(cache_A, cache_B, cache_C)` built by the matching `augmented_primal`
# and forwards it, together with the shadows of `C`, `A` and `B`, to
# `TensorOperations.tensorcontract_pullback!`. The returned tuple has one entry per primal
# argument: `nothing` everywhere except the `α` and `β` slots, which carry the `dα` and `dβ`
# produced by the pullback.
#
# NOTE(review): `.dval` is read unconditionally on `C_dC`, `A_dA` and `B_dB`, so this rule
# presumably expects array arguments that carry a shadow (e.g. `Duplicated`) -- confirm
# before relying on `Const` arrays.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(TensorOperations.tensorcontract!)},
        ::Type{RT},
        cache,
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        pA_dpA::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        B_dB::Annotation{<:AbstractArray{TB}},
        pB_dpB::Const{<:Index2Tuple},
        conjB_dconjB::Const{Bool},
        pAB_dpAB::Const{<:Index2Tuple},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
    taped_A, taped_B, C_prev = cache
    # Use the live arrays when the forward sweep did not tape a copy.
    A = isnothing(taped_A) ? A_dA.val : taped_A
    B = isnothing(taped_B) ? B_dB.val : taped_B
    extra_args = map(arg -> getfield(arg, :val), ba_dba)
    _, _, _, dα, dβ = TensorOperations.tensorcontract_pullback!(
        C_dC.dval, A_dA.dval, B_dB.dval,
        C_prev, A, pA_dpA.val, conjA_dconjA.val,
        B, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val,
        α_dα.val, β_dβ.val, extra_args...,
    )
    # Eight leading `nothing`s: C, A, pA, conjA, B, pB, conjB, pAB.
    return (
        nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing,
        dα, dβ, map(Returns(nothing), ba_dba)...,
    )
end
81+
82+
# Forward sweep of the reverse-mode Enzyme rule for `tensoradd!`.
#
# Calls `TensorOperations.tensoradd!` on the primal values and returns an
# `AugmentedReturn(primal, shadow, tape)` where
# - `primal` is `C_dC.val` when `needs_primal(config)` holds, otherwise `nothing`;
# - `shadow` is `C_dC.dval` when `needs_shadow(config)` holds, otherwise `nothing`;
# - `tape` is `(cache_A, cache_C)`, consumed by the matching `reverse` rule.
#
# The index permutation and conjugation flag must be `Const`; any trailing arguments
# (`ba_dba`) must be `Const` as well and are forwarded unchanged.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        ::Annotation{typeof(tensoradd!)},
        ::Type{RT},
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        pA_dpA::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    # `overwritten(config)` counts the function itself as entry 1, so `A` is entry 3.
    # The reverse rule reads the primal value of `A` whatever its activity, so a copy is
    # taped whenever Enzyme reports a possible overwrite -- also for a `Const` input,
    # whose value would otherwise be stale by the time the pullback runs.
    cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
    # Snapshot `C` before the in-place call below; the reverse rule hands this copy to the
    # pullback.
    cache_C = copy(C_dC.val)
    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
    TensorOperations.tensoradd!(
        C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, α_dα.val, β_dβ.val, ba...,
    )
    primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing
    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C))
end
110+
111+
# Reverse sweep of the Enzyme rule for `tensoradd!`.
#
# Unpacks the tape `(cache_A, cache_C)` built by the matching `augmented_primal` and
# forwards it, together with the shadows of `C` and `A`, to
# `TensorOperations.tensoradd_pullback!`. The returned tuple has one entry per primal
# argument: `nothing` everywhere except the `α` and `β` slots, which carry the `dα` and `dβ`
# produced by the pullback.
#
# NOTE(review): `.dval` is read unconditionally on `C_dC` and `A_dA`, so this rule
# presumably expects array arguments that carry a shadow (e.g. `Duplicated`) -- confirm
# before relying on `Const` arrays.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        ::Annotation{typeof(tensoradd!)},
        ::Type{RT},
        cache,
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        pA_dpA::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    taped_A, C_prev = cache
    # Use the live array when the forward sweep did not tape a copy.
    A = isnothing(taped_A) ? A_dA.val : taped_A
    extra_args = map(arg -> getfield(arg, :val), ba_dba)
    _, _, dα, dβ = TensorOperations.tensoradd_pullback!(
        C_dC.dval, A_dA.dval,
        C_prev, A, pA_dpA.val, conjA_dconjA.val,
        α_dα.val, β_dβ.val, extra_args...,
    )
    # Four leading `nothing`s: C, A, pA, conjA.
    return (
        nothing, nothing, nothing, nothing,
        dα, dβ, map(Returns(nothing), ba_dba)...,
    )
end
137+
138+
# Forward sweep of the reverse-mode Enzyme rule for `tensortrace!`.
#
# Calls `TensorOperations.tensortrace!` on the primal values and returns an
# `AugmentedReturn(primal, shadow, tape)` where
# - `primal` is `C_dC.val` when `needs_primal(config)` holds, otherwise `nothing`;
# - `shadow` is `C_dC.dval` when `needs_shadow(config)` holds, otherwise `nothing`;
# - `tape` is `(cache_A, cache_C)`, consumed by the matching `reverse` rule.
#
# The index tuples `p`, `q` and the conjugation flag must be `Const`; any trailing
# arguments (`ba_dba`) must be `Const` as well and are forwarded unchanged.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        ::Annotation{typeof(tensortrace!)},
        ::Type{RT},
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        p_dp::Const{<:Index2Tuple},
        q_dq::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    # `overwritten(config)` counts the function itself as entry 1, so `A` is entry 3.
    # The reverse rule reads the primal value of `A` whatever its activity, so a copy is
    # taped whenever Enzyme reports a possible overwrite -- also for a `Const` input,
    # whose value would otherwise be stale by the time the pullback runs.
    cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
    # Snapshot `C` before the in-place call below; the reverse rule hands this copy to the
    # pullback.
    cache_C = copy(C_dC.val)
    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
    TensorOperations.tensortrace!(
        C_dC.val, A_dA.val, p_dp.val, q_dq.val, conjA_dconjA.val,
        α_dα.val, β_dβ.val, ba...,
    )
    primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing
    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C))
end
167+
168+
# Reverse sweep of the Enzyme rule for `tensortrace!`.
#
# Unpacks the tape `(cache_A, cache_C)` built by the matching `augmented_primal` and
# forwards it, together with the shadows of `C` and `A`, to
# `TensorOperations.tensortrace_pullback!`. The returned tuple has one entry per primal
# argument: `nothing` everywhere except the `α` and `β` slots, which carry the `dα` and `dβ`
# produced by the pullback.
#
# NOTE(review): `.dval` is read unconditionally on `C_dC` and `A_dA`, so this rule
# presumably expects array arguments that carry a shadow (e.g. `Duplicated`) -- confirm
# before relying on `Const` arrays.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        ::Annotation{typeof(tensortrace!)},
        ::Type{RT},
        cache,
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        p_dp::Const{<:Index2Tuple},
        q_dq::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    taped_A, C_prev = cache
    # Use the live array when the forward sweep did not tape a copy.
    A = isnothing(taped_A) ? A_dA.val : taped_A
    extra_args = map(arg -> getfield(arg, :val), ba_dba)
    _, _, dα, dβ = TensorOperations.tensortrace_pullback!(
        C_dC.dval, A_dA.dval,
        C_prev, A, p_dp.val, q_dq.val, conjA_dconjA.val,
        α_dα.val, β_dβ.val, extra_args...,
    )
    # Five leading `nothing`s: C, A, p, q, conjA.
    return (
        nothing, nothing, nothing, nothing, nothing,
        dα, dβ, map(Returns(nothing), ba_dba)...,
    )
end
196+
197+
end

test/enzyme.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
using TensorOperations, VectorInterface
2+
using Enzyme, ChainRulesCore, EnzymeTestUtils
3+
4+
@testset "tensorcontract! ($T₁, $T₂)" for (T₁, T₂) in
5+
(
6+
(Float64, Float64),
7+
(Float32, Float64),
8+
(ComplexF64, ComplexF64),
9+
(Float64, ComplexF64),
10+
(ComplexF64, Float64),
11+
)
12+
T = promote_type(T₁, T₂)
13+
atol = max(precision(T₁), precision(T₂))
14+
rtol = max(precision(T₁), precision(T₂))
15+
16+
pAB = ((3, 2, 4, 1), ())
17+
pA = ((2, 4, 5), (1, 3))
18+
pB = ((2, 1), (3,))
19+
20+
A = rand(T₁, (2, 3, 4, 2, 5))
21+
B = rand(T₂, (4, 2, 3))
22+
C = rand(T, (5, 2, 3, 3))
23+
@testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T)))
24+
Tα = α === Zero() ? Const : Active
25+
Tβ = β === Zero() ? Const : Active
26+
test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)
27+
test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)
28+
test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)
29+
30+
test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
31+
test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)
32+
end
33+
end
34+
35+
@testset "tensoradd! ($T₁, $T₂)" for (T₁, T₂) in (
36+
(Float64, Float64),
37+
(Float32, Float64),
38+
(ComplexF64, ComplexF64),
39+
(Float64, ComplexF64),
40+
)
41+
T = promote_type(T₁, T₂)
42+
atol = max(precision(T₁), precision(T₂))
43+
rtol = max(precision(T₁), precision(T₂))
44+
45+
pA = ((2, 1, 4, 3, 5), ())
46+
A = rand(T₁, (2, 3, 4, 2, 1))
47+
C = rand(T₂, size.(Ref(A), pA[1]))
48+
@testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T)))
49+
Tα = α === Zero() ? Const : Active
50+
Tβ = β === Zero() ? Const : Active
51+
test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol)
52+
test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol)
53+
54+
test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
55+
test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)
56+
end
57+
end
58+
59+
@testset "tensortrace! ($T₁, $T₂)" for (T₁, T₂) in
60+
(
61+
(Float64, Float64),
62+
(Float32, Float64),
63+
(ComplexF64, ComplexF64),
64+
(Float64, ComplexF64),
65+
)
66+
T = promote_type(T₁, T₂)
67+
atol = max(precision(T₁), precision(T₂))
68+
rtol = max(precision(T₁), precision(T₂))
69+
70+
p = ((3, 5, 2), ())
71+
q = ((1,), (4,))
72+
A = rand(T₁, (2, 3, 4, 2, 5))
73+
C = rand(T₂, size.(Ref(A), p[1]))
74+
@testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T)))
75+
Tα = α === Zero() ? Const : Active
76+
Tβ = β === Zero() ? Const : Active
77+
78+
test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol)
79+
test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol)
80+
81+
test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
82+
test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)
83+
end
84+
end
85+
86+
# Reverse-mode check of `tensorscalar` on a zero-dimensional array holding one random
# value, with the same tolerance used for both `atol` and `rtol`.
@testset "tensorscalar ($T)" for T in (Float32, Float64, ComplexF64)
    tol = precision(T)
    C = fill(rand(T))
    test_reverse(tensorscalar, Active, (C, Duplicated); atol = tol, rtol = tol)
end

test/mooncake.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ is_primitive = false
1414
(Float32, Float64),
1515
#(ComplexF64, ComplexF64),
1616
#(Float64, ComplexF64),
17+
#(ComplexF64, Float64),
1718
)
1819
T = promote_type(T₁, T₂)
1920
atol = max(precision(T₁), precision(T₂))

test/runtests.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-8
1515
# specific ones
1616
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1717
if !is_buildkite
18-
1918
@testset "tensoropt" verbose = true begin
2019
include("tensoropt.jl")
2120
end
@@ -37,6 +36,12 @@ if !is_buildkite
3736
@testset "mooncake" verbose = false begin
3837
include("mooncake.jl")
3938
end
39+
# unexplained segfault on Julia 1.10 for now, so the Enzyme tests only run on 1.11+
40+
@static if VERSION >= v"1.11.0"
41+
@testset "enzyme" verbose = false begin
42+
include("enzyme.jl")
43+
end
44+
end
4045
end
4146

4247
if is_buildkite

0 commit comments

Comments
 (0)