Skip to content

Commit 2da19b8

Browse files
Fix gpu ci by adding flax in the requirements (#19349)
1 parent 08711c5 commit 2da19b8

4 files changed

+5
-0
lines changed

keras/utils/jax_layer_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ def test_jax_layer(
420420
"non_trainable_params": 536,
421421
},
422422
)
423+
@pytest.mark.skipif(flax is None, reason="Flax library is not available.")
423424
def test_flax_layer(
424425
self,
425426
flax_model_class,
@@ -575,6 +576,7 @@ def jax_fn(params, inputs):
575576
test_output = model(test_inputs)
576577
self.assertAllClose(test_output, np.ones((2, 60, 3)))
577578

579+
@pytest.mark.skipif(flax is None, reason="Flax library is not available.")
578580
def test_with_flax_state_no_params(self):
579581
class MyFlaxLayer(flax.linen.Module):
580582
@flax.linen.compact

requirements-jax-cuda.txt

+1
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ torchvision>=0.16.0
1010
# TODO: 0.4.24 has an updated Cuda version breaks Jax CI.
1111
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
1212
jax[cuda12_pip]==0.4.23
13+
flax
1314

1415
-r requirements-common.txt

requirements-tensorflow-cuda.txt

+1
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ torchvision>=0.16.0
88

99
# Jax cpu-only version (needed for testing).
1010
jax[cpu]
11+
flax
1112

1213
-r requirements-common.txt

requirements-torch-cuda.txt

+1
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ torchvision==0.17.1+cu121
88

99
# Jax cpu-only version (needed for testing).
1010
jax[cpu]
11+
flax
1112

1213
-r requirements-common.txt

0 commit comments

Comments
 (0)