diff --git a/lion/lion_pytorch.py b/lion/lion_pytorch.py index ffd52d99..8377b451 100644 --- a/lion/lion_pytorch.py +++ b/lion/lion_pytorch.py @@ -78,7 +78,7 @@ def step(self, closure=None): # Weight update update = exp_avg * beta1 + grad * (1 - beta1) - p.add_(torch.sign(update), alpha=-group['lr'], inplace=True) + p.add_(torch.sign(update), alpha=-group['lr']) #This has been made more efficient by using the torch.sign function's inplace mode. #This will prevent the need to create a new tensor for the updated parameter, #which can save a significant amount of time for large models.