-
Notifications
You must be signed in to change notification settings - Fork 19.6k
allow TorchModuleWrapper compute output shape #21160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #21160 +/- ##
=======================================
Coverage 82.69% 82.69%
=======================================
Files 564 564
Lines 54223 54228 +5
Branches 8424 8425 +1
=======================================
+ Hits 44837 44842 +5
Misses 7310 7310
Partials 2076 2076
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
keras/src/utils/torch_utils.py
Outdated
@@ -25,6 +25,8 @@ class TorchModuleWrapper(Layer): | |||
instance, then its parameters must be initialized before | |||
passing the instance to `TorchModuleWrapper` (e.g. by calling | |||
it once). | |||
out_shape :The shape of the output of this layer. It helps Keras |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please name it output_shape
@@ -138,13 +141,18 @@ def load_own_variables(self, store): | |||
state_dict[key] = convert_to_tensor(store[key]) | |||
self.module.load_state_dict(state_dict) | |||
|
|||
def compute_output_shape(self, input_shape): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we consider validating this in __call__
?
The TorchModuleWrapper layer does not allow automatic inference of the output. Therefore, the model that owns this layer cannot be used like other keras.Model. For example, you cannot use model.summary().
And now the vast majority of models are based on Torch. By adding this feature, we can let ordinary Torch models enjoy the workflow of Keras without making too many modifications.