Skip to content

Commit b6d305f

Browse files
authored
Fix cloning of Sequential models w. input_tensors argument (#20550)
* Fix cloning for Sequential w. input tensor * Add missing test for input_tensor argument * Add Sequential wo. Input to test, build model to ensure defined inputs
1 parent 553521e commit b6d305f

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

Diff for: keras/src/models/cloning.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def _clone_sequential_model(model, clone_function, input_tensors=None):
298298
input_dtype = None
299299
input_batch_shape = None
300300

301-
if input_tensors:
301+
if input_tensors is not None:
302302
if isinstance(input_tensors, (list, tuple)):
303303
if len(input_tensors) != 1:
304304
raise ValueError(
@@ -310,7 +310,12 @@ def _clone_sequential_model(model, clone_function, input_tensors=None):
310310
"Argument `input_tensors` must be a KerasTensor. "
311311
f"Received invalid value: input_tensors={input_tensors}"
312312
)
313-
inputs = Input(tensor=input_tensors, name=input_name)
313+
inputs = Input(
314+
tensor=input_tensors,
315+
batch_shape=input_tensors.shape,
316+
dtype=input_tensors.dtype,
317+
name=input_name,
318+
)
314319
new_layers = [inputs] + new_layers
315320
else:
316321
if input_batch_shape is not None:

Diff for: keras/src/models/cloning_test.py

+26
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ def get_sequential_model(explicit_input=True):
6161
return model
6262

6363

64+
def get_cnn_sequential_model(explicit_input=True):
65+
model = models.Sequential()
66+
if explicit_input:
67+
model.add(layers.Input(shape=(7, 3)))
68+
model.add(layers.Conv1D(2, 2, padding="same"))
69+
model.add(layers.Conv1D(2, 2, padding="same"))
70+
return model
71+
72+
6473
def get_subclassed_model():
6574
class ExampleModel(models.Model):
6675
def __init__(self, **kwargs):
@@ -124,6 +133,23 @@ def clone_function(layer):
124133
if not isinstance(l1, layers.InputLayer):
125134
self.assertEqual(l2.name, l1.name + "_custom")
126135

136+
@parameterized.named_parameters(
137+
("cnn_functional", get_cnn_functional_model),
138+
("cnn_sequential", get_cnn_sequential_model),
139+
(
140+
"cnn_sequential_noinputlayer",
141+
lambda: get_cnn_sequential_model(explicit_input=False),
142+
),
143+
)
144+
def test_input_tensors(self, model_fn):
145+
ref_input = np.random.random((2, 7, 3))
146+
model = model_fn()
147+
model(ref_input) # Maybe needed to get model inputs if no Input layer
148+
input_tensor = model.inputs[0]
149+
new_model = clone_model(model, input_tensors=input_tensor)
150+
tree.assert_same_structure(model.inputs, new_model.inputs)
151+
tree.assert_same_structure(model.outputs, new_model.outputs)
152+
127153
def test_shared_layers_cloning(self):
128154
model = get_mlp_functional_model(shared_layers=True)
129155
new_model = clone_model(model)

0 commit comments

Comments
 (0)