Skip to content

Commit 6d52164

Browse files
authored
allow TorchModuleWrapper compute output shape (#21160)
* allow TorchModuleWrapper compute output shape * modify
1 parent c90a3a5 commit 6d52164

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

Diff for: keras/src/utils/torch_utils.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ class TorchModuleWrapper(Layer):
2525
instance, then its parameters must be initialized before
2626
passing the instance to `TorchModuleWrapper` (e.g. by calling
2727
it once).
28+
output_shape :The shape of the output of this layer. It helps Keras
29+
perform automatic shape inference.
2830
name: The name of the layer (string).
2931
3032
Example:
@@ -80,7 +82,7 @@ def call(self, inputs):
8082
```
8183
"""
8284

83-
def __init__(self, module, name=None, **kwargs):
85+
def __init__(self, module, name=None, output_shape=None, **kwargs):
8486
super().__init__(name=name, **kwargs)
8587
import torch.nn as nn
8688

@@ -98,6 +100,7 @@ def __init__(self, module, name=None, **kwargs):
98100

99101
self.module = module.to(get_device())
100102
self._track_module_parameters()
103+
self.output_shape = output_shape
101104

102105
def parameters(self, recurse=True):
103106
return self.module.parameters(recurse=recurse)
@@ -138,13 +141,21 @@ def load_own_variables(self, store):
138141
state_dict[key] = convert_to_tensor(store[key])
139142
self.module.load_state_dict(state_dict)
140143

144+
def compute_output_shape(self, input_shape):
145+
if self.output_shape is None:
146+
return super().compute_output_shape(input_shape)
147+
return self.output_shape
148+
141149
def get_config(self):
142150
base_config = super().get_config()
143151
import torch
144152

145153
buffer = io.BytesIO()
146154
torch.save(self.module, buffer)
147-
config = {"module": buffer.getvalue()}
155+
config = {
156+
"module": buffer.getvalue(),
157+
"output_shape": self.output_shape,
158+
}
148159
return {**base_config, **config}
149160

150161
@classmethod

Diff for: keras/src/utils/torch_utils_test.py

+11
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
from absl.testing import parameterized
77

8+
import keras
89
from keras.src import backend
910
from keras.src import layers
1011
from keras.src import models
@@ -235,3 +236,13 @@ def test_from_config(self):
235236
new_mw = TorchModuleWrapper.from_config(config)
236237
for ref_w, new_w in zip(mw.get_weights(), new_mw.get_weights()):
237238
self.assertAllClose(ref_w, new_w, atol=1e-5)
239+
240+
def test_build_model(self):
241+
x = keras.Input([4])
242+
z = TorchModuleWrapper(torch.nn.Linear(4, 8), output_shape=[None, 8])(x)
243+
y = TorchModuleWrapper(torch.nn.Linear(8, 16), output_shape=[None, 16])(
244+
z
245+
)
246+
model = keras.Model(x, y)
247+
self.assertEqual(model.predict(np.zeros([5, 4])).shape, (5, 16))
248+
self.assertEqual(model(np.zeros([5, 4])).shape, (5, 16))

0 commit comments

Comments
 (0)