diff --git a/optax/_src/linear_algebra_test.py b/optax/_src/linear_algebra_test.py index 4e6c5c23..e3545725 100644 --- a/optax/_src/linear_algebra_test.py +++ b/optax/_src/linear_algebra_test.py @@ -245,4 +245,5 @@ def test_nnls(self, m, n, k, seed, dtype, tol=0.0): if __name__ == '__main__': + jax.config.update('jax_threefry_partitionable', False) absltest.main()