Fix Torch backend LSTM cuDNN: remove blanket exception catch and forced cpu() transfer #22279
Conversation
Summary of Changes

Hello @rstar327, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request resolves critical issues in the Torch backend's cuDNN LSTM implementation that previously hindered its functionality. By addressing a broad exception catch and preventing the premature transfer of output tensors to the CPU, the changes enable proper error reporting and ensure that cuDNN LSTM operates correctly on the GPU, thereby improving the reliability and performance of recurrent neural networks.
Code Review
This pull request addresses two important issues in the Torch backend's cuDNN LSTM implementation. First, it correctly removes a broad exception handler that was masking underlying errors, which significantly improves debuggability. Second, it removes unnecessary .cpu() calls on output tensors, ensuring they remain on the correct device for further computation.
While reviewing, I noticed a critical issue that remains. The output tensors are still being detached from the computation graph via .detach(), which will prevent the LSTM layer from learning during training. I've added a comment with a suggestion to fix this by removing the .detach() calls.
Codecov Report

❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #22279 +/- ##
===========================================
+ Coverage 71.44% 83.04% +11.59%
===========================================
Files 594 596 +2
Lines 65029 66704 +1675
Branches 10174 10384 +210
===========================================
+ Hits 46461 55394 +8933
+ Misses 16105 8673 -7432
- Partials 2463 2637 +174
Hey rstar327, I think you're on the right track, but there are a couple of issues. I recently worked on the same thing on the GRU side in #22215, so I might be able to help. For the try/except part, you might want to narrow it rather than remove it entirely. The LSTM layer's call() method catches NotImplementedError to fall back to the non-cuDNN path, so removing the try/except means that if cuDNN fails for a legitimate config reason, it will hard-crash instead of falling back gracefully. Catching only RuntimeError from torch would preserve that behavior while still letting real bugs through. Also, I think Gemini's suggestion about .detach() is correct; I fixed the same thing in #22115 if you want a reference. You can drop both .detach() and .clone() entirely, since the tensors don't need to be copied. Let me know if you want a hand with any of this!
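The narrowed fallback described in the comment could look roughly like this. This is a sketch, not the actual Keras code: `_cudnn_lstm`, `lstm`, and `call` are hypothetical stand-ins that only illustrate the exception-flow contract.

```python
def _cudnn_lstm(inputs):
    # Stand-in for the torch cuDNN-backed call. Simulates cuDNN
    # rejecting a legitimate-but-unsupported configuration.
    raise RuntimeError("cuDNN does not support this configuration")

def lstm(inputs):
    try:
        return _cudnn_lstm(inputs)
    except RuntimeError as e:
        # Only RuntimeError is converted into NotImplementedError; any
        # other exception (a real bug) propagates instead of being
        # swallowed by a blanket `except Exception`.
        raise NotImplementedError("falling back to non-cuDNN path") from e

def call(inputs):
    # Mirrors the layer-level call(), which catches NotImplementedError
    # to select the generic (non-cuDNN) implementation.
    try:
        return lstm(inputs)
    except NotImplementedError:
        return "non-cudnn result"
```

With this shape, a config rejection degrades gracefully through `call`, while an unexpected exception inside `_cudnn_lstm` would surface to the caller unchanged.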
- Catch only RuntimeError instead of blanket Exception so real errors propagate while cuDNN still falls back gracefully
- Remove .detach().clone() as tensors don't need to be copied, matching the GRU implementation in keras-team#22115
Replace np.split/np.concatenate with torch.chunk/torch.cat in prepare_lstm_weights to fix CUDA tensor handling, matching the GRU implementation pattern.
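The split/concat swap in the commit above can be sketched with plain torch tensors. The weight shape here is illustrative, not the actual cuDNN parameter layout:

```python
import torch

# cuDNN expects the gate weight matrices stacked along dim 0.
# np.split/np.concatenate would force a CUDA tensor through NumPy
# (a host round-trip); torch.chunk/torch.cat stay on the tensor's
# device, whatever it is.
w = torch.arange(24.0).reshape(8, 3)   # illustrative stacked weights
i, f, g, o = torch.chunk(w, 4, dim=0)  # four (2, 3) gate blocks
restacked = torch.cat([i, f, g, o], dim=0)
```

The same two calls work unchanged whether `w` lives on CPU or GPU, which is the point of the fix.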
Done
hertschuh left a comment
Thanks a lot for the fix!
Can you look at the failing Torch tests and change the precision for those to 1e-5? It will be very similar to this: https://github.com/keras-team/keras/pull/22115/changes#diff-4b44f448b4e0809176c41ba1df1500998524f979e5f4a456ecc21dce3f935344
Done, added atol=1e-5, rtol=1e-5 to all assertAllClose calls in lstm_test.py, matching the pattern from the GRU fix in #22115.
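The loosened-tolerance pattern looks roughly like the following. The values here are made up for illustration; the real assertions live in lstm_test.py and use the test framework's assertAllClose:

```python
import numpy as np

# cuDNN kernels may reduce in a different order than the reference
# implementation, so bit-exact float32 equality is too strict;
# atol/rtol of 1e-5 absorbs that drift without hiding real errors.
expected = np.array([0.1234567, 0.7654321], dtype=np.float32)
actual = expected + np.float32(3e-6)  # simulated kernel-order drift
np.testing.assert_allclose(actual, expected, atol=1e-5, rtol=1e-5)
```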
Summary

Fixes two issues in the Torch backend's cuDNN LSTM implementation that cause `use_cudnn=True` to fail even when all cuDNN criteria are satisfied:

- Narrowed exception catch from `Exception` to `RuntimeError`. The blanket catch-all converted any exception into `NotImplementedError`, hiding real errors. Now only `RuntimeError` triggers the cuDNN fallback, preserving graceful degradation for legitimate config issues while letting real bugs propagate.
- Removed `.detach().clone().cpu()` on output tensors. The `.cpu()` call broke GPU computation, and `.detach().clone()` is unnecessary since the tensors don't need to be copied, matching the GRU implementation in Implement GRU for PyTorch backend #22115.

Fixes #22274
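The second fix can be illustrated independently of Keras with a minimal sketch; `lstm_output` below is just a stand-in tensor, not the real backend output:

```python
import torch

x = torch.randn(4, 3, requires_grad=True)
lstm_output = x * 2.0  # stand-in for the cuDNN LSTM output

# Before the fix the backend did, in effect:
#   lstm_output = lstm_output.detach().clone().cpu()
# .cpu() moved the result off the GPU mid-computation, and .detach()
# cut it out of the autograd graph, so the layer could not train.
# Returning the tensor untouched keeps both properties intact:
assert lstm_output.requires_grad        # still differentiable
assert lstm_output.device == x.device   # still on the input's device

detached = lstm_output.detach()
assert not detached.requires_grad       # what the old code produced
```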