-
Notifications
You must be signed in to change notification settings - Fork 68
Add Enzyme rules and move some logic into pullbacks #243
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Co-authored-by: Lukas Devos <[email protected]>
|
I'll gladly update the chainrules implementation once we are happy with the names and signatures etc |
|
Since I'm already modifying the Mooncake stuff let's also wait for Jutho/StridedViews.jl#19 to land so that we can test the complex stuff with those rules. |
lkdvos
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I saw you removed the import of the chainrules overloads for tensoralloc and tensorfree. I think we still actually need this, for example in the ManualAllocator case we are both capturing the intermediate tensors in the closures and we might be freeing them at the end of the contraction, so this would lead to a use after free (and possibly segfaults).
For Chainrules, we have taken the cautious approach and simply disabled having temporary tensors and freeing them (which is exactly these overloads, i.e. temp=Val(true) is replaced by temp = Val(false)), and I would think we need this here too?
TensorOperations.jl/ext/TensorOperationsChainRulesCoreExt.jl
Lines 26 to 41 in 22ef6dc
| function ChainRulesCore.rrule( | |
| ::typeof(TensorOperations.tensorfree!), allocator = DefaultAllocator() | |
| ) | |
| tensorfree!_pullback(Δargs...) = (NoTangent(), NoTangent()) | |
| return nothing, tensorfree!_pullback | |
| end | |
| function ChainRulesCore.rrule( | |
| ::typeof(TensorOperations.tensoralloc), ttype, structure, | |
| istemp, allocator = DefaultAllocator() | |
| ) | |
| output = TensorOperations.tensoralloc(ttype, structure, Val(false), allocator) | |
| function tensoralloc_pullback(Δargs...) | |
| return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()) | |
| end | |
| return output, tensoralloc_pullback | |
| end |
I'm guessing it will be better to start using the shared pullbacks for Mooncake and Chainrules in separate PRs, so probably best to leave that for now as is?
Otherwise very cool!
We can probably use these PB functions for ChainRules as well though I'm less familiar with that library, so I might need to ask someone else to do the plumbing :)