17
17
"""
18
18
19
19
import collections
20
+ import functools
20
21
import inspect
21
22
import math
22
23
import warnings
@@ -1043,11 +1044,13 @@ def stateless_call(
1043
1044
if self ._remat_mode is not None :
1044
1045
outputs = self .rematerialized_call (
1045
1046
self .quantized_call , * args , ** kwargs
1046
- )
1047
+ )( * args , ** kwargs )
1047
1048
else :
1048
1049
outputs = self .quantized_call (* args , ** kwargs )
1049
1050
elif self ._remat_mode is not None :
1050
- outputs = self .rematerialized_call (self .call , * args , ** kwargs )
1051
+ outputs = self .rematerialized_call (self .call , * args , ** kwargs )(
1052
+ * args , ** kwargs
1053
+ )
1051
1054
else :
1052
1055
outputs = self .call (* args , ** kwargs )
1053
1056
if return_losses :
@@ -1601,13 +1604,13 @@ def compute_size(x):
1601
1604
1602
1605
# Full rematerialization
1603
1606
if self ._remat_mode .mode == "full" :
1604
- return remat .remat (layer_call )( * args , ** kwargs )
1607
+ return remat .remat (layer_call )
1605
1608
1606
1609
# Apply rematerialization to specific layers
1607
1610
elif self ._remat_mode .mode == "list_of_layers" and (
1608
1611
self .name in self ._remat_mode .layer_names
1609
1612
):
1610
- return remat .remat (layer_call )( * args , ** kwargs )
1613
+ return remat .remat (layer_call )
1611
1614
1612
1615
# Apply rematerialization based on output size threshold
1613
1616
elif self ._remat_mode .mode == "larger_than" :
@@ -1619,20 +1622,24 @@ def compute_size(x):
1619
1622
output_size
1620
1623
and output_size > self ._remat_mode .output_size_threshold
1621
1624
):
1622
- return remat .remat (layer_call )( * args , ** kwargs )
1625
+ return remat .remat (layer_call )
1623
1626
elif self ._remat_mode .mode == "activations" :
1624
1627
has_activation = (
1625
1628
hasattr (self , "activation" ) and self .activation is not None
1626
1629
)
1627
1630
if has_activation :
1628
- not_rematted_activation = self .activation
1629
- try :
1630
- self .activation = remat .remat (not_rematted_activation )
1631
- return layer_call (* args , ** kwargs )
1632
- finally :
1633
- self .activation = not_rematted_activation
1634
1631
1635
- return layer_call (* args , ** kwargs )
1632
+ @functools .wraps (layer_call )
1633
+ def rematerialized_activation_call_wrapper (* args , ** kwargs ):
1634
+ original_activation = self .activation
1635
+ self .activation = remat .remat (original_activation )
1636
+ try :
1637
+ return layer_call (* args , ** kwargs )
1638
+ finally :
1639
+ self .activation = original_activation
1640
+
1641
+ return rematerialized_activation_call_wrapper
1642
+ return layer_call
1636
1643
1637
1644
1638
1645
def is_backend_tensor_or_symbolic (x , allow_none = False ):
0 commit comments