@@ -25,6 +25,8 @@ class TorchModuleWrapper(Layer):
25
25
instance, then its parameters must be initialized before
26
26
passing the instance to `TorchModuleWrapper` (e.g. by calling
27
27
it once).
28
+ output_shape :The shape of the output of this layer. It helps Keras
29
+ perform automatic shape inference.
28
30
name: The name of the layer (string).
29
31
30
32
Example:
@@ -80,7 +82,7 @@ def call(self, inputs):
80
82
```
81
83
"""
82
84
83
- def __init__ (self , module , name = None , ** kwargs ):
85
+ def __init__ (self , module , name = None , output_shape = None , ** kwargs ):
84
86
super ().__init__ (name = name , ** kwargs )
85
87
import torch .nn as nn
86
88
@@ -98,6 +100,7 @@ def __init__(self, module, name=None, **kwargs):
98
100
99
101
self .module = module .to (get_device ())
100
102
self ._track_module_parameters ()
103
+ self .output_shape = output_shape
101
104
102
105
def parameters (self , recurse = True ):
103
106
return self .module .parameters (recurse = recurse )
@@ -138,13 +141,21 @@ def load_own_variables(self, store):
138
141
state_dict [key ] = convert_to_tensor (store [key ])
139
142
self .module .load_state_dict (state_dict )
140
143
144
+ def compute_output_shape (self , input_shape ):
145
+ if self .output_shape is None :
146
+ return super ().compute_output_shape (input_shape )
147
+ return self .output_shape
148
+
141
149
def get_config (self ):
142
150
base_config = super ().get_config ()
143
151
import torch
144
152
145
153
buffer = io .BytesIO ()
146
154
torch .save (self .module , buffer )
147
- config = {"module" : buffer .getvalue ()}
155
+ config = {
156
+ "module" : buffer .getvalue (),
157
+ "output_shape" : self .output_shape ,
158
+ }
148
159
return {** base_config , ** config }
149
160
150
161
@classmethod
0 commit comments