Description
I'm trying to use both PyCUDA and JAX in the same Python script, but I'm running into a frustrating issue. Everything works fine when I:
- Run PyCUDA code first (works perfectly)
- Then run JAX code (also works fine)
But when I go back to PyCUDA after using JAX, it crashes with an error saying the GPU is "busy or unavailable". It's as if JAX won't let go of the GPU even after it's done.
Here is a small snippet of code I used to try to understand what was going on:
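```python
import numpy as np
from pycuda import driver as cuda, compiler
import jax.numpy as jnp


def run_pycuda():
    cuda.init()  # safe to call more than once
    # Create a new context on GPU 0; it becomes current on this thread
    ctx = cuda.Device(0).make_context()
    try:
        # Simple CUDA kernel that doubles each element in place
        mod = compiler.SourceModule("""
        __global__ void test(float *x) { x[threadIdx.x] *= 2; }
        """)
        kernel = mod.get_function("test")
        arr = np.random.randn(100).astype(np.float32)
        # InOut copies arr to the GPU before the launch and back afterwards
        kernel(cuda.InOut(arr), block=(100, 1, 1))
        # Unload the module while its context is still alive
        del kernel, mod
    finally:
        # Release the context created above, even if the kernel fails
        ctx.detach()
    print("PyCUDA run worked!")


def run_jax():
    # Allocates a small array on the default JAX device (the GPU here)
    x = jnp.array([1.0, 2.0, 3.0])


if __name__ == "__main__":
    run_pycuda()  # Works fine
    run_jax()     # Also works
    run_pycuda()  # Crashes here
```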
I also tried calling `.block_until_ready()` on the JAX array and decorating the `run_jax` function with `@jax.block_until_ready`, but neither releases the GPU. I also tried adding a `time.sleep(3)` after the function call to see whether JAX would release the GPU after some time, but that didn't work either!
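The attempts described above can be sketched as follows. This is a reconstruction from the description, so where exactly each call sat in the original script is an assumption, and `run_jax_decorated` is a hypothetical name used only for illustration. As far as I understand, none of these is expected to release the GPU: `block_until_ready` only waits for pending work to finish, and `jax.block_until_ready` is an ordinary function that blocks on the arrays inside its argument, so applied to a function object it has nothing to wait on and returns the function unchanged.

```python
import time

import jax
import jax.numpy as jnp


def run_jax():
    x = jnp.array([1.0, 2.0, 3.0])
    # Attempt 1: wait until the device has finished producing x.
    # This only synchronizes; it does not free x or tear down JAX's GPU backend.
    return x.block_until_ready()


# Attempt 2: jax.block_until_ready used as a decorator. Given a function
# object (not an array), it has nothing to block on and returns that same
# function, so the "decorated" version behaves exactly like run_jax.
run_jax_decorated = jax.block_until_ready(run_jax)

if __name__ == "__main__":
    x = run_jax()
    # Attempt 3: give JAX some time to "let go" of the GPU.
    time.sleep(3)
    print(x)
```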
Any idea how I can make JAX release the GPU?
Thank you in advance!
System info (python version, jaxlib version, accelerator, etc.)
jax version: 0.6.1 (installed with `pip install -U "jax[cuda12]"`)
CUDA version: 12.9 (driver version 575.51.03, as reported by nvidia-smi)
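The environment details below appear to be the output of `jax.print_environment_info()`, which also appends a `$ nvidia-smi` section when `nvidia-smi` is available. A minimal way to regenerate them, assuming a JAX version whose `print_environment_info` accepts `return_string` (0.6.1 should, as far as I know); `collect_environment_info` is a hypothetical helper name:

```python
import jax


def collect_environment_info() -> str:
    """Return the summary that jax.print_environment_info() would print.

    It lists the jax/jaxlib/numpy/python versions, the device kind and
    count, the process count and platform, plus nvidia-smi output when
    that tool is available on the machine.
    """
    return jax.print_environment_info(return_string=True)


if __name__ == "__main__":
    print(collect_environment_info())
```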
jax: 0.6.1
jaxlib: 0.6.1
numpy: 1.26.4
python: 3.12.10 | packaged by conda-forge | (main, Apr 10 2025, 22:21:13) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 4090-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='vm-did', release='6.1.0-34-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.1.135-1 (2025-04-25)', machine='x86_64')
$ nvidia-smi
Mon Jun 2 09:18:29 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.51.03 Driver Version: 575.51.03 CUDA Version: 12.9 |
|-----------------------------------------+------------------------+----------------------+