Skip to content

Commit 9067d80

Browse files
divyashreepathihallilaxmareddyp
authored andcommitted
Fix Remat error when called with a model (#21094)
* add print * fix remat issue * simplify code * enable traceback filtering and update the function sig * add a wrapper for activations * change to except * add layer call decorator * fix remat call
1 parent c0545ac commit 9067d80

File tree

3 files changed

+37
-19
lines changed

3 files changed

+37
-19
lines changed

keras/src/layers/layer.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""
1818

1919
import collections
20+
import functools
2021
import inspect
2122
import math
2223
import warnings
@@ -1043,11 +1044,13 @@ def stateless_call(
10431044
if self._remat_mode is not None:
10441045
outputs = self.rematerialized_call(
10451046
self.quantized_call, *args, **kwargs
1046-
)
1047+
)(*args, **kwargs)
10471048
else:
10481049
outputs = self.quantized_call(*args, **kwargs)
10491050
elif self._remat_mode is not None:
1050-
outputs = self.rematerialized_call(self.call, *args, **kwargs)
1051+
outputs = self.rematerialized_call(self.call, *args, **kwargs)(
1052+
*args, **kwargs
1053+
)
10511054
else:
10521055
outputs = self.call(*args, **kwargs)
10531056
if return_losses:
@@ -1601,13 +1604,13 @@ def compute_size(x):
16011604

16021605
# Full rematerialization
16031606
if self._remat_mode.mode == "full":
1604-
return remat.remat(layer_call)(*args, **kwargs)
1607+
return remat.remat(layer_call)
16051608

16061609
# Apply rematerialization to specific layers
16071610
elif self._remat_mode.mode == "list_of_layers" and (
16081611
self.name in self._remat_mode.layer_names
16091612
):
1610-
return remat.remat(layer_call)(*args, **kwargs)
1613+
return remat.remat(layer_call)
16111614

16121615
# Apply rematerialization based on output size threshold
16131616
elif self._remat_mode.mode == "larger_than":
@@ -1619,20 +1622,24 @@ def compute_size(x):
16191622
output_size
16201623
and output_size > self._remat_mode.output_size_threshold
16211624
):
1622-
return remat.remat(layer_call)(*args, **kwargs)
1625+
return remat.remat(layer_call)
16231626
elif self._remat_mode.mode == "activations":
16241627
has_activation = (
16251628
hasattr(self, "activation") and self.activation is not None
16261629
)
16271630
if has_activation:
1628-
not_rematted_activation = self.activation
1629-
try:
1630-
self.activation = remat.remat(not_rematted_activation)
1631-
return layer_call(*args, **kwargs)
1632-
finally:
1633-
self.activation = not_rematted_activation
16341631

1635-
return layer_call(*args, **kwargs)
1632+
@functools.wraps(layer_call)
1633+
def rematerialized_activation_call_wrapper(*args, **kwargs):
1634+
original_activation = self.activation
1635+
self.activation = remat.remat(original_activation)
1636+
try:
1637+
return layer_call(*args, **kwargs)
1638+
finally:
1639+
self.activation = original_activation
1640+
1641+
return rematerialized_activation_call_wrapper
1642+
return layer_call
16361643

16371644

16381645
def is_backend_tensor_or_symbolic(x, allow_none=False):

keras/src/layers/layer_test.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from keras.src.backend.common import remat
1818
from keras.src.backend.common.remat import RematScope
1919
from keras.src.models import Model
20+
from keras.src.utils import traceback_utils
2021

2122

2223
class LayerTest(testing.TestCase):
@@ -219,9 +220,11 @@ def test_functional_model_with_remat(self):
219220
self.skipTest(
220221
"remat is not supported in openvino and numpy backends."
221222
)
222-
with patch(
223-
"keras.src.backend.common.remat.remat", wraps=remat.remat
224-
) as mock_remat:
223+
traceback_utils.enable_traceback_filtering()
224+
mock_remat = MockRemat()
225+
with mock.patch(
226+
"keras.src.backend.common.remat.remat", wraps=mock_remat
227+
):
225228
# Define model inputs
226229
inputs = Input(shape=(32, 32, 3))
227230

keras/src/ops/operation.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,15 @@ def __call__(self, *args, **kwargs):
3737
else:
3838
if getattr(self, "_remat_mode", None) is not None:
3939
if getattr(self, "quantization_mode", None) is not None:
40-
call_fn = self.rematerialized_call(self.quantized_call)
40+
call_fn = self.rematerialized_call(
41+
self.quantized_call,
42+
*args,
43+
**kwargs,
44+
)
4145
else:
42-
call_fn = self.rematerialized_call(self.call)
46+
call_fn = self.rematerialized_call(
47+
self.call, *args, **kwargs
48+
)
4349
else:
4450
if getattr(self, "quantization_mode", None) is not None:
4551
call_fn = self.quantized_call
@@ -58,9 +64,11 @@ def __call__(self, *args, **kwargs):
5864
if getattr(self, "quantization_mode", None) is not None:
5965
return self.rematerialized_call(
6066
self.quantized_call, *args, **kwargs
61-
)
67+
)(*args, **kwargs)
6268
else:
63-
return self.rematerialized_call(self.call, *args, **kwargs)
69+
return self.rematerialized_call(self.call, *args, **kwargs)(
70+
*args, **kwargs
71+
)
6472
else:
6573
if getattr(self, "quantization_mode", None) is not None:
6674
return self.quantized_call(*args, **kwargs)

0 commit comments

Comments
 (0)