
Unable to Reuse GPU After JAX Execution #29159

Closed
@NeyKoZv

Description

I'm trying to use both PyCUDA and JAX in the same Python script, but I'm running into a frustrating issue. Everything works fine when I:

- Run PyCUDA code first (works perfectly)
- Then run JAX code (also works fine)

But when I try to go back to PyCUDA after using JAX, it crashes with an error saying the GPU is "busy or unavailable". It looks as if JAX won't let go of the GPU even after it has finished.
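A possible explanation worth ruling out first (an assumption, not confirmed in this report): "busy or unavailable" is the text CUDA uses for `CUDA_ERROR_DEVICE_UNAVAILABLE`, which the driver API documentation attributes mostly to the GPU being in `EXCLUSIVE_PROCESS` or `PROHIBITED` compute mode. In `Exclusive_Process` mode only one CUDA context may exist on the device, so a context still held by JAX would make PyCUDA's `make_context()` fail, while the first PyCUDA run would succeed because its context is detached before JAX starts. This sketch reads the mode through `nvidia-smi --query-gpu=compute_mode`; the helper names are hypothetical.

```python
import subprocess

# Compute modes reported by nvidia-smi. Only "Default" allows several CUDA
# contexts (e.g. one from JAX and one from PyCUDA) on the same GPU at once.
MULTI_CONTEXT_MODES = {"Default"}


def parse_compute_modes(csv_output: str) -> list[str]:
    """Parse `nvidia-smi --query-gpu=compute_mode --format=csv,noheader` output.

    nvidia-smi prints one line per GPU, e.g. "Default" or "Exclusive_Process".
    Blank lines are ignored.
    """
    return [line.strip() for line in csv_output.splitlines() if line.strip()]


def query_compute_modes() -> list[str]:
    """Ask nvidia-smi for the compute mode of every visible GPU."""
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=compute_mode", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=True,  # raise if nvidia-smi exits with an error
    )
    return parse_compute_modes(result.stdout)


def allows_multiple_contexts(mode: str) -> bool:
    """True if the compute mode lets more than one context share the GPU."""
    return mode in MULTI_CONTEXT_MODES
```

Calling `query_compute_modes()` on the affected machine returns one entry per GPU. If it reports `Exclusive_Process`, switching back with `nvidia-smi -c DEFAULT` (needs root, and may not be allowed on a virtual GPU) would be the thing to test; if it reports `Default`, this explanation is probably wrong.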

Here is a small snippet I used to try to understand what was going on:

import numpy as np
from pycuda import driver as cuda, compiler
import jax.numpy as jnp

def run_pycuda():
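    cuda.init()
    # make_context() creates a new context on GPU 0 and makes it current
    ctx = cuda.Device(0).make_context()
    try:
        # Simple CUDA kernel that doubles each element in place
        mod = compiler.SourceModule("""
        __global__ void test(float *x) { x[threadIdx.x] *= 2; }
        """)
        kernel = mod.get_function("test")

        arr = np.random.randn(100).astype(np.float32)
        kernel(cuda.InOut(arr), block=(100, 1, 1))
    finally:
        # Release the context (PyCUDA deletes it once its reference count
        # hits zero), even if compiling or launching the kernel fails
        ctx.detach()
    print("PyCUDA run worked!")

def run_jax():
    # Creating a JAX array initializes JAX's GPU backend
    x = jnp.array([1.0, 2.0, 3.0])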

if __name__ == "__main__":
    run_pycuda()  # Works fine
    run_jax()     # Also works
    run_pycuda()  # Crashes here
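A workaround that may be worth trying, untested here: let PyCUDA reuse the device's primary context (the one the CUDA runtime API uses) through `Device.retain_primary_context()` instead of creating a second context with `make_context()`. This assumes JAX/XLA runs in that same primary context, so PyCUDA and JAX would share one context rather than compete for two. `primary_context` and `run_pycuda_shared` are hypothetical helpers, and the context manager only relies on PyCUDA's `retain_primary_context()`, `push()` and `pop()`:

```python
from contextlib import contextmanager


@contextmanager
def primary_context(device):
    """Make `device`'s primary CUDA context current for the duration of a block.

    `device` is expected to be a pycuda.driver.Device. Unlike make_context(),
    which always creates a fresh context, retain_primary_context() returns the
    device's single primary context and does not make it current, so it is
    pushed explicitly here.
    Assumption: JAX/XLA also uses this primary context, so no second context
    is requested from the GPU.
    """
    ctx = device.retain_primary_context()
    ctx.push()
    try:
        yield ctx
    finally:
        # Always restore the previous context stack, even if the kernel fails.
        ctx.pop()


def run_pycuda_shared():
    """Hypothetical variant of run_pycuda() that reuses the primary context."""
    # Imported lazily so this module can be loaded without PyCUDA installed.
    import numpy as np
    from pycuda import compiler
    from pycuda import driver as cuda

    cuda.init()
    with primary_context(cuda.Device(0)):
        mod = compiler.SourceModule("""
        __global__ void test(float *x) { x[threadIdx.x] *= 2; }
        """)
        kernel = mod.get_function("test")
        arr = np.random.randn(100).astype(np.float32)
        kernel(cuda.InOut(arr), block=(100, 1, 1))
    print("PyCUDA run on the primary context worked!")
```

`run_pycuda_shared()` would replace both `run_pycuda()` calls in the `__main__` block. Sharing the context does not undo JAX's default preallocation of most of the GPU's memory, so larger PyCUDA allocations may additionally need `XLA_PYTHON_CLIENT_PREALLOCATE=false` to fit.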

I also tried calling .block_until_ready() on the JAX array, and putting @jax.block_until_ready above the run_jax function, but neither releases the GPU. I also added a time.sleep(3) after the call to see whether JAX would release the GPU after a while, but that didn't work either!

Any idea how I can make JAX release the GPU?
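If JAX really does keep its GPU context until the Python process exits (an assumption, though it would match the time.sleep(3) experiment), a blunt but reliable workaround is to never import JAX in the PyCUDA process and instead run the JAX part in a fresh interpreter. A minimal sketch using only the standard library; `run_isolated` and `JAX_SNIPPET` are hypothetical names, and results come back as JSON on stdout:

```python
import json
import subprocess
import sys

# Hypothetical JAX code to run in a child interpreter. It prints its result as
# JSON so the parent process never has to import JAX itself.
JAX_SNIPPET = """
import json
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
print(json.dumps((x * 2).tolist()))
"""


def run_isolated(snippet: str):
    """Run `snippet` in a fresh Python process and decode its JSON stdout.

    Everything the child allocates on the GPU, including its CUDA context,
    is released when the child process exits.
    """
    completed = subprocess.run(
        [sys.executable, "-c", snippet],
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the snippet fails
    )
    return json.loads(completed.stdout)
```

With this, the `__main__` block would call `run_pycuda()`, then `run_isolated(JAX_SNIPPET)`, then `run_pycuda()` again, with `import jax.numpy as jnp` removed from the top of the script. The price is interpreter start-up and JAX initialization on every call, so it suits coarse-grained steps better than tight loops.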

Thank you in advance !

System info (python version, jaxlib version, accelerator, etc.)

jax version: 0.6.1 (installed with pip install -U "jax[cuda12]")
CUDA version: 12.9 (reported by nvidia-smi 575.51.03)

jax: 0.6.1
jaxlib: 0.6.1
numpy: 1.26.4
python: 3.12.10 | packaged by conda-forge | (main, Apr 10 2025, 22:21:13) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 4090-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='vm-did', release='6.1.0-34-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.1.135-1 (2025-04-25)', machine='x86_64')

$ nvidia-smi
Mon Jun 2 09:18:29 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.51.03 Driver Version: 575.51.03 CUDA Version: 12.9 |
|-----------------------------------------+------------------------+----------------------+

Labels: bug
