Skip to content

allow TorchModuleWrapper compute output shape #21160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 14, 2025

Conversation

pass-lin
Copy link
Contributor

The TorchModuleWrapper layer does not allow automatic inference of the output. Therefore, the model that owns this layer cannot be used like other keras.Model. For example, you cannot use model.summary().

And now the vast majority of models are based on Torch. By adding this feature, we can let ordinary Torch models enjoy the workflow of Keras without making too many modifications.

@codecov-commenter
Copy link

codecov-commenter commented Apr 13, 2025

Codecov Report

Attention: Patch coverage is 71.42857% with 2 lines in your changes missing coverage. Please review.

Project coverage is 82.69%. Comparing base (2111fbc) to head (b688627).

Files with missing lines Patch % Lines
keras/src/utils/torch_utils.py 71.42% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##           master   #21160   +/-   ##
=======================================
  Coverage   82.69%   82.69%           
=======================================
  Files         564      564           
  Lines       54223    54228    +5     
  Branches     8424     8425    +1     
=======================================
+ Hits        44837    44842    +5     
  Misses       7310     7310           
  Partials     2076     2076           
Flag Coverage Δ
keras 82.50% <71.42%> (+<0.01%) ⬆️
keras-jax 63.91% <28.57%> (-0.01%) ⬇️
keras-numpy 59.02% <28.57%> (-0.01%) ⬇️
keras-openvino 32.98% <28.57%> (-0.01%) ⬇️
keras-tensorflow 64.29% <28.57%> (-0.01%) ⬇️
keras-torch 63.99% <71.42%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@gbaned gbaned added this to PR Queue Apr 14, 2025
@gbaned gbaned requested a review from fchollet April 14, 2025 07:36
@github-project-automation github-project-automation bot moved this to Assigned Reviewer in PR Queue Apr 14, 2025
@@ -25,6 +25,8 @@ class TorchModuleWrapper(Layer):
instance, then its parameters must be initialized before
passing the instance to `TorchModuleWrapper` (e.g. by calling
it once).
out_shape :The shape of the output of this layer. It helps Keras
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please name it output_shape

@@ -138,13 +141,18 @@ def load_own_variables(self, store):
state_dict[key] = convert_to_tensor(store[key])
self.module.load_state_dict(state_dict)

def compute_output_shape(self, input_shape):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we consider validating this in __call__?

@github-project-automation github-project-automation bot moved this from Assigned Reviewer to Approved by Reviewer in PR Queue Apr 14, 2025
@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Apr 14, 2025
@fchollet fchollet merged commit 6d52164 into keras-team:master Apr 14, 2025
7 checks passed
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase labels Apr 14, 2025
@github-project-automation github-project-automation bot moved this from Approved by Reviewer to Merged in PR Queue Apr 14, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
Status: Merged
Development

Successfully merging this pull request may close these issues.

4 participants