Is there a way to force jax into CPU-only mode? #28587
At the top of my file I have:

```python
import jax
jax.config.update("jax_default_device", "cpu")
```

Yet somehow my computation stalls with a warning that the GPU is out of memory:

I verified that
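For reference, the `jax_default_device` option only changes where new arrays and `jit`-compiled computations are placed. It does not, by itself, keep JAX from initializing the GPU backend. The sketch below shows how the option is typically used. It assumes a recent JAX version in which `jax_default_device` accepts a `Device` object (rather than a platform string) and in which `jax.Array.devices()` is available:

```python
import jax
import jax.numpy as jnp

# Pick the first CPU device explicitly (a Device object, not the string "cpu").
cpu = jax.devices("cpu")[0]

# Make it the default placement for new arrays and jit'd computations.
# Note: querying jax.devices() above has already initialized every available
# backend, including a GPU if one is present.
jax.config.update("jax_default_device", cpu)

x = jnp.arange(4.0)
print(x.devices())  # a set containing the CPU device

# Alternatively, scope the default to a block with the context manager.
with jax.default_device(cpu):
    y = jnp.ones(3)
print(y.devices())
```

Because the GPU backend still starts up in this setup, this option alone is not a way to make JAX CPU-only.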
By default, JAX places new arrays on `jax.devices()[0]`, which is often the first GPU or TPU. To make JAX ignore GPUs entirely, you could try setting the environment variable `JAX_PLATFORMS="cpu"`. This needs to be done before JAX initializes. Here are two common ways:

```python
import os

# Set the variable before importing JAX so it applies when the backends initialize.
os.environ['JAX_PLATFORMS'] = 'cpu'
print(os.environ.get('JAX_PLATFORMS'))  # cpu

import jax

print(jax.devices())  # [CpuDevice(id=0)]
```

Regarding your question about a GPU being truly out of memory: if JAX (without `JAX_PLATFORMS="cpu"`) …
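As an alternative to the environment variable, JAX also exposes the same setting as the `jax_platforms` config option. Like `JAX_PLATFORMS`, it only helps if it is set before JAX initializes its backends, that is, before the first call that queries devices or creates an array. The following is a minimal sketch that assumes it runs at the very top of the program:

```python
import jax
import jax.numpy as jnp

# Restrict JAX to the CPU platform. This must happen before anything queries
# devices or creates arrays; importing jax does not yet initialize backends.
jax.config.update("jax_platforms", "cpu")

print(jax.default_backend())  # cpu
print(jax.devices())          # only CPU devices are listed

x = jnp.zeros(3)
print(x.devices())            # a set containing the CPU device
```

Because the setting lives in the script itself, it sits next to the code it affects. The trade-off is that you have to make sure no earlier import has already touched a device.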