diff --git a/cvxpylayers/torch/test_cvxpylayer.py b/cvxpylayers/torch/test_cvxpylayer.py index adc3f67..47dba50 100755 --- a/cvxpylayers/torch/test_cvxpylayer.py +++ b/cvxpylayers/torch/test_cvxpylayer.py @@ -88,8 +88,8 @@ def test_least_squares(self): def lstsq( A, b): return torch.solve( - (A_th.t() @ b_th).unsqueeze(1), - A_th.t() @ A_th + + (A.t() @ b).unsqueeze(1), + A.t() @ A + torch.eye(n).double())[0] x_lstsq = lstsq(A_th, b_th)