
Is there a way to force jax into CPU-only mode? #28587


By default, JAX places new arrays on jax.devices()[0], which is the first GPU or TPU whenever one is available.
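The snippet below is a minimal sketch of that default placement. It assumes a recent JAX in which `jax.Array.devices()` returns a set of devices. It checks where a new array lands, then targets the CPU explicitly with `jax.device_put` and the `jax.default_device` context manager. This per-array approach keeps any GPU visible and only controls placement. It is not the same as hiding GPUs with `JAX_PLATFORMS="cpu"`.

```python
import jax
import jax.numpy as jnp

# With no explicit placement, new arrays land on jax.devices()[0]:
# the first GPU/TPU if one is available, otherwise the CPU.
default_dev = jax.devices()[0]
x = jnp.ones(3)
print("default device:", default_dev, "| x lives on:", x.devices())

# The CPU backend is normally initialized alongside any accelerator,
# so a CPU device can be requested by platform name.
cpu = jax.devices("cpu")[0]

# Per-array: copy an existing array to the CPU.
x_cpu = jax.device_put(x, cpu)

# Scoped: make the CPU the default device for computations in this block.
with jax.default_device(cpu):
    y = jnp.arange(4)

print("x_cpu:", x_cpu.devices(), "| y:", y.devices())
```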

To make JAX ignore GPUs entirely, you could try setting the environment variable JAX_PLATFORMS="cpu". It has to be set before JAX is imported and initializes its backends. Here are two common ways:

  1. From your shell:
```shell
$ JAX_PLATFORMS=cpu python -c "import jax; print(jax.devices())"
```
  2. In your Python script (before the first `import jax`, or after restarting your Python kernel if `jax` was already imported):
```python
import os
os.environ['JAX_PLATFORMS'] = 'cpu'
print(os.environ.get('JAX_PLATFORMS'))  # cpu

import jax
print(jax.devices())  # [CpuDevice(id=0)]
```
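If you would rather not touch `os.environ`, JAX also exposes the same setting through its config API under the name `jax_platforms`. The sketch below assumes your JAX version provides that option. Like the environment variable, it should run before anything initializes a backend, such as calling `jax.devices()`, creating an array, or compiling a function.

```python
import jax
import jax.numpy as jnp

# Assumed in-code counterpart of JAX_PLATFORMS=cpu. It must run before any
# backend is initialized (e.g. by jax.devices() or by creating an array).
jax.config.update("jax_platforms", "cpu")

devices = jax.devices()
print(devices)  # only CPU devices should be listed

x = jnp.arange(3.0)
print(x.devices())
```

The environment variable may still be the safer choice when you cannot control what runs first, for example in a notebook where another cell may already have used JAX.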

Regarding your question about a GPU being truly out of memory: if JAX (without JAX_PLATFORMS="cpu") …

Answer selected by Jacob-Stevens-Haas