
WIP: Add PyTorch backend support for LSTM with CuDNN optimization #21135


Merged: 5 commits merged into keras-team:master on Apr 15, 2025

Conversation

@praveenhosdrug123 (Contributor) commented Apr 4, 2025

This is a work-in-progress implementation of LSTM for the PyTorch backend with CuDNN optimization.
This is a follow-up to GitHub pull requests #20875 and #20916.

A working example can be found in the following Colab notebook:
https://colab.research.google.com/drive/1Vciv4nulEAHpY8_wstNfNzx4GMjpxGyB

Current features:

  • Support for variable-length sequences with padding masks (see the sketch after this list)
  • Proper weight conversion between Keras/TF format and PyTorch format
  • CuDNN acceleration when available
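For reference, a minimal sketch (not the PR's actual code) of the general approach for the masking support above, assuming right-padded sequences and a boolean mask of shape (batch, timesteps): the mask is reduced to per-sample lengths and the batch is packed so torch can dispatch the LSTM to its CuDNN kernels. The helper name and shapes are hypothetical.

import torch

def lstm_with_mask(inputs, mask, lstm):
    # inputs: (batch, timesteps, features) float tensor, right-padded
    # mask:   (batch, timesteps) bool tensor, True where the timestep is valid
    # lstm:   a torch.nn.LSTM built with batch_first=True
    lengths = mask.sum(dim=1).to("cpu")  # pack_padded_sequence expects CPU lengths
    packed = torch.nn.utils.rnn.pack_padded_sequence(
        inputs, lengths, batch_first=True, enforce_sorted=False
    )
    packed_out, (h_n, c_n) = lstm(packed)  # CuDNN kernels are used when running on GPU
    outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
        packed_out, batch_first=True, total_length=mask.shape[1]
    )
    return outputs, h_n, c_n

# Hypothetical usage:
# lstm = torch.nn.LSTM(input_size=64, hidden_size=64, batch_first=True).cuda()
# outputs, h_n, c_n = lstm_with_mask(x, mask, lstm)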

Still to be addressed:

  • Additional testing for edge cases
  • More comprehensive error handling
  • Performance optimization and benchmarking
  • Support for additional configuration options (bidirectional, etc.)

Feedback welcome on:

  1. The weight conversion approach between Keras and PyTorch (a rough sketch of the idea follows this list)
  2. Handling of masks and variable-length sequences
  3. Error handling and fallback options when CuDNN is not available
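On point 1, for concreteness, a rough sketch (not necessarily the PR's exact mapping) of how Keras LSTM weights can be loaded into a torch.nn.LSTM. It assumes both frameworks concatenate the gates in (input, forget, cell/candidate, output) order; the function name is hypothetical.

import numpy as np
import torch

def load_keras_lstm_weights(torch_lstm, kernel, recurrent_kernel, bias):
    # Assumption: gate order is [i, f, c/g, o] on both sides, so only shapes differ.
    # kernel:           (input_dim, 4 * units) -> weight_ih_l0 is (4 * units, input_dim)
    # recurrent_kernel: (units, 4 * units)     -> weight_hh_l0 is (4 * units, units)
    # bias:             (4 * units,)           -> torch keeps two bias vectors (bias_ih_l0, bias_hh_l0)
    with torch.no_grad():
        torch_lstm.weight_ih_l0.copy_(torch.from_numpy(np.ascontiguousarray(kernel.T)))
        torch_lstm.weight_hh_l0.copy_(torch.from_numpy(np.ascontiguousarray(recurrent_kernel.T)))
        torch_lstm.bias_ih_l0.copy_(torch.from_numpy(bias))
        torch_lstm.bias_hh_l0.zero_()  # put the whole Keras bias on one side, zero the other
    return torch_lstm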

Note: I've added an exclusion for the keras/src/namex directory in the Ruff configuration to prevent linting errors in this third-party code. My actual implementation code passes all linting checks.


@codecov-commenter commented Apr 4, 2025

Codecov Report

Attention: Patch coverage is 14.40000% with 107 lines in your changes missing coverage. Please review.

Project coverage is 82.58%. Comparing base (2111fbc) to head (c9e8149).
Report is 3 commits behind head on master.

Files with missing lines         Patch %   Lines
keras/src/backend/torch/rnn.py   14.40%    107 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21135      +/-   ##
==========================================
- Coverage   82.69%   82.58%   -0.11%     
==========================================
  Files         564      564              
  Lines       54223    54320      +97     
  Branches     8424     8437      +13     
==========================================
+ Hits        44837    44860      +23     
- Misses       7310     7389      +79     
+ Partials     2076     2071       -5     
Flag               Coverage          Δ
keras              82.39% <14.40%>   (-0.11%) ⬇️
keras-jax          63.83% <8.80%>    (-0.10%) ⬇️
keras-numpy        58.94% <8.80%>    (-0.09%) ⬇️
keras-openvino     32.92% <0.00%>    (-0.06%) ⬇️
keras-tensorflow   64.21% <8.80%>    (-0.09%) ⬇️
keras-torch        63.92% <14.40%>   (-0.07%) ⬇️

Flags with carried forward coverage won't be shown.


@fchollet (Collaborator) commented Apr 9, 2025

Thank you for the PR! The code is looking good. Running GPU tests now.

Did you observe a good speed up?

@praveenhosdrug123 (Contributor, Author)

> Thank you for the PR! The code is looking good. Running GPU tests now.
>
> Did you observe a good speed up?

Thank you! I revised the code and have pushed the changes. In the current Colab, it looks good.

I also re-ran the majority of the GPU tests that were failing, and they pass now.
===== BENCHMARK SUMMARY =====
Average speedup: 15.11x
Median speedup: 11.53x
Min speedup: 2.28x
Max speedup: 64.08x

Best configuration:
Batch size: 32.0
Sequence length: 200.0
Feature dimensions: 64.0
Hidden units: 64.0
Speedup: 64.08x

Benchmark
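For context, a minimal timing sketch (not the author's benchmark script) for the reported best configuration, assuming the torch backend and a CUDA-capable GPU; the iteration count and random input are arbitrary choices for illustration.

import os
os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras

import time
import numpy as np
import keras
import torch

batch, timesteps, features, units = 32, 200, 64, 64
x = np.random.random((batch, timesteps, features)).astype("float32")
layer = keras.layers.LSTM(units)

layer(x)  # warm-up: builds the layer (and the CuDNN path, when available)
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(10):
    layer(x)
if torch.cuda.is_available():
    torch.cuda.synchronize()
print(f"Mean forward time: {(time.perf_counter() - start) / 10 * 1e3:.2f} ms")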

@praveenhosdrug123 (Contributor, Author)

Hi All,

I'm trying to push updates after the GPU tests to address the issues from the earlier runs, but I'm stuck on the pre-commit hooks since I don't have CUDA available on my machine. The error below shows up in the "check the code format" section. I have tried the skip and no-verify flags, but to no avail. Reaching out to request guidance on next steps.

Run pre-commit run --all-files --hook-stage manual
[INFO] Initializing environment for https://github.com/astral-sh/ruff-pre-commit.
[INFO] Installing environment for https://github.com/astral-sh/ruff-pre-commit.
[INFO] Once installed this environment will be reused.
[INFO] This may take a few minutes...
api_gen..................................................................Failed
- hook id: api-gen
- files were modified by this hook
Generating api directory with public APIs...
2025-04-13 19:35:45.504044: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-13 19:35:45.507247: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-13 19:35:45.516018: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1744572945.530281 2302 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744572945.534547 2302 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-13 19:35:45.550126: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
/opt/hostedtoolcache/Python/3.10.16/x64/lib/python3.10/site-packages/openvino/runtime/__init__.py:10: DeprecationWarning: The openvino.runtime module is deprecated and will be removed in the 2026.0 release. Please replace openvino.runtime with openvino.
warnings.warn(
Formatting api directory...
ruff.....................................................................Passed
ruff-format..............................................................Passed

@fchollet (Collaborator)

Seems to be some local environment issues -- we can just keep testing on CI in that case. Re-running now.

@fchollet (Collaborator)

It's working!

The code needs to be formatted, though -- can you do it?

@praveenhosdrug123 (Contributor, Author) commented Apr 15, 2025

> It's working!
>
> The code needs to be formatted, though -- can you do it?

Just submitted the formatted code. Thank you, Francois!

@fchollet (Collaborator)

Thank you for the contribution!

github-project-automation bot moved this from Assigned Reviewer to Approved by Reviewer in PR Queue on Apr 15, 2025
google-ml-butler bot added the "kokoro:force-run" and "ready to pull" (Ready to be merged into the codebase) labels on Apr 15, 2025
fchollet merged commit 128e280 into keras-team:master on Apr 15, 2025
7 checks passed
google-ml-butler bot removed the "ready to pull" (Ready to be merged into the codebase) label on Apr 15, 2025
github-project-automation bot moved this from Approved by Reviewer to Merged in PR Queue on Apr 15, 2025
praveenhosdrug123 deleted the lstm branch on April 17, 2025 08:55