Fix Torch backend LSTM cuDNN: remove blanket exception catch and forced cpu() transfer#22279

Merged
hertschuh merged 5 commits into keras-team:master from
rstar327:fix-lstm-cudnn-torch-backend
Mar 13, 2026

Conversation

@rstar327 rstar327 commented Feb 24, 2026

Summary

Fixes two issues in the Torch backend's cuDNN LSTM implementation that cause use_cudnn=True to fail even when all cuDNN criteria are satisfied:

  1. Narrowed exception catch from Exception to RuntimeError — The blanket catch-all converted any exception into NotImplementedError, hiding real errors. Now only RuntimeError triggers the cuDNN fallback, preserving graceful degradation for legitimate config issues while letting real bugs propagate.

  2. Removed .detach().clone().cpu() on output tensors — The .cpu() broke GPU computation, and .detach().clone() is unnecessary since the tensors don't need to be copied, matching the GRU implementation in Implement GRU for PyTorch backend #22115.

Fixes #22274
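Conceptually, the narrowed fallback in the first fix looks something like the sketch below. The function and argument names are illustrative, not the actual keras/src/backend/torch/rnn.py code:

```python
# Illustrative sketch only: names here are hypothetical, not the
# actual keras/src/backend/torch/rnn.py implementation.
def lstm_with_cudnn_fallback(run_cudnn_lstm, *args, **kwargs):
    try:
        # Attempt the fast cuDNN path first.
        return run_cudnn_lstm(*args, **kwargs)
    except RuntimeError as e:
        # Only RuntimeError (e.g. a legitimate cuDNN config issue) is
        # converted; the LSTM layer catches NotImplementedError and
        # falls back to the generic implementation. Any other
        # exception is a real bug and now propagates unchanged.
        raise NotImplementedError(
            f"cuDNN LSTM unavailable, falling back: {e}"
        ) from e
```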

@gemini-code-assist

Summary of Changes

Hello @rstar327, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves critical issues in the Torch backend's cuDNN LSTM implementation that previously hindered its functionality. By addressing a broad exception catch and preventing the premature transfer of output tensors to the CPU, the changes enable proper error reporting and ensure that cuDNN LSTM operates correctly on the GPU, thereby improving the reliability and performance of recurrent neural networks.

Highlights

  • Removed blanket exception catch: Eliminated the try/except Exception: raise NotImplementedError block in the lstm function. This change allows real errors during cuDNN LSTM execution to propagate, instead of being masked by a generic NotImplementedError, which previously led to misleading diagnostics.
  • Stopped forced CPU transfer: Removed the .cpu() calls applied to outputs, h_n, and c_n tensors within the _cudnn_lstm function. This ensures that tensors processed by cuDNN LSTM remain on the GPU, preventing unnecessary device transfers and maintaining consistency for subsequent GPU-based operations.


Changelog
  • keras/src/backend/torch/rnn.py
    • Removed the try...except Exception block surrounding the _cudnn_lstm call to allow specific errors to surface.
    • Removed .cpu() calls from the outputs, h_n, and c_n tensors in _cudnn_lstm to keep them on the original device.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses two important issues in the Torch backend's cuDNN LSTM implementation. First, it correctly removes a broad exception handler that was masking underlying errors, which significantly improves debuggability. Second, it removes unnecessary .cpu() calls on output tensors, ensuring they remain on the correct device for further computation.

While reviewing, I noticed a critical issue that remains. The output tensors are still being detached from the computation graph via .detach(), which will prevent the LSTM layer from learning during training. I've added a comment with a suggestion to fix this by removing the .detach() calls.
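To see why the remaining .detach() matters, here is a minimal standalone torch demonstration (not Keras code): detaching an output severs it from the autograd graph, so upstream parameters stop receiving gradients.

```python
import torch

w = torch.ones(3, requires_grad=True)
x = torch.tensor([1.0, 2.0, 3.0])

# Without detach: backward() reaches w.
(w * x).sum().backward()
print(w.grad)  # tensor([1., 2., 3.])

# With detach: the result is cut off from the graph, so no gradient
# can ever flow back to w through it.
y = (w * x).detach()
print(y.requires_grad)  # False
```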

@codecov-commenter

codecov-commenter commented Feb 24, 2026

Codecov Report

❌ Patch coverage is 0% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 83.04%. Comparing base (0ddf962) to head (70365ba).
⚠️ Report is 96 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/backend/torch/rnn.py 0.00% 5 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master   #22279       +/-   ##
===========================================
+ Coverage   71.44%   83.04%   +11.59%     
===========================================
  Files         594      596        +2     
  Lines       65029    66704     +1675     
  Branches    10174    10384      +210     
===========================================
+ Hits        46461    55394     +8933     
+ Misses      16105     8673     -7432     
- Partials     2463     2637      +174     
Flag Coverage Δ
keras 82.87% <0.00%> (+11.54%) ⬆️
keras-jax 60.53% <0.00%> (-0.91%) ⬇️
keras-numpy 54.77% <0.00%> (-0.87%) ⬇️
keras-openvino 49.95% <0.00%> (?)
keras-tensorflow 61.78% <0.00%> (?)
keras-torch 60.60% <0.00%> (-0.95%) ⬇️

Flags with carried forward coverage won't be shown.


@MarcosAsh

Hey rstar327,

I think you're on the right track, but there are a couple of issues. I recently worked on the same thing on the GRU side in #22215, so I might be able to help. For the try/except part, you might want to narrow it rather than remove it entirely. The LSTM layer's call() method catches NotImplementedError to fall back to the non-cuDNN path, so removing the try/except means that if cuDNN fails for a legitimate config reason, it will hard crash instead of falling back gracefully. Catching only RuntimeError from torch would preserve that behavior while still letting real bugs through.

Also, I think Gemini's suggestion about .detach() is correct; I fixed the same thing in #22115 if you want a reference. You can drop both .detach() and .clone() entirely, since the tensors don't need to be copied.

Let me know if you want a hand with any of this!
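The fallback contract described above can be sketched from the caller's side (schematic only; the real logic lives in the LSTM layer's call() method, and the function names here are illustrative):

```python
# Schematic of the layer-side fallback: only NotImplementedError from
# the cuDNN path triggers the generic implementation; any other
# exception is a real bug and propagates to the user.
def call(inputs, cudnn_lstm, generic_lstm):
    try:
        return cudnn_lstm(inputs)
    except NotImplementedError:
        return generic_lstm(inputs)
```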

rstar327 added 2 commits March 2, 2026 03:46

  • Catch only RuntimeError instead of blanket Exception so real errors propagate while cuDNN still falls back gracefully; remove .detach().clone() as tensors don't need to be copied, matching the GRU implementation in keras-team#22115
  • Replace np.split/np.concatenate with torch.chunk/torch.cat in prepare_lstm_weights to fix CUDA tensor handling, matching the GRU implementation pattern
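The motivation for the second commit: torch.chunk and torch.cat operate directly on tensors regardless of device, whereas np.split/np.concatenate would require moving CUDA tensors to the host first. A small illustration (the shapes here are made up, not the real LSTM weight layout):

```python
import torch

# Stand-in for a fused LSTM weight matrix (4 gates along the last axis).
kernel = torch.arange(24.0).reshape(2, 12)

i, f, c, o = torch.chunk(kernel, 4, dim=1)  # split into four gate slices
rebuilt = torch.cat([i, f, c, o], dim=1)    # concatenate back together
print(torch.equal(rebuilt, kernel))  # True
```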

rstar327 commented Mar 5, 2026

Done


@hertschuh hertschuh left a comment


Thanks a lot for the fix!

Can you look at the failing Torch tests and change the precision for those to 1e-5? It will be very similar to this: https://github.com/keras-team/keras/pull/22115/changes#diff-4b44f448b4e0809176c41ba1df1500998524f979e5f4a456ecc21dce3f935344

@rstar327

Done, added atol=1e-5, rtol=1e-5 to all assertAllClose calls in lstm_test.py, matching the pattern from the GRU fix in #22115.
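For reference, the same loosened-tolerance pattern with NumPy's assert_allclose (illustrative only; the actual tests use Keras's assertAllClose helper, and the values below are made up):

```python
import numpy as np

# cuDNN and the generic kernel accumulate in different orders, so
# float32 outputs can legitimately differ by around 1e-6.
cudnn_out = np.float32([0.1234567, 0.7654321])
reference = cudnn_out + np.float32(3e-6)

# The default tolerance (rtol=1e-7, atol=0) would reject this tiny
# discrepancy; atol/rtol of 1e-5 accepts it.
np.testing.assert_allclose(reference, cudnn_out, atol=1e-5, rtol=1e-5)
```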


@hertschuh hertschuh left a comment


Thanks for the fix!

@google-ml-butler google-ml-butler bot added the ready to pull Ready to be merged into the codebase label Mar 13, 2026
@hertschuh hertschuh merged commit 3b13145 into keras-team:master Mar 13, 2026
13 of 14 checks passed
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase kokoro:force-run labels Mar 13, 2026

Development

Successfully merging this pull request may close these issues.

Torch backend: LSTM fails to use cuDNN when criteria are satisfied

6 participants