### Description
Hello,
I get the following error using the code below:
E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference.
However, it only occurs if all of the following hold:
- I don't set `jax.config.update("jax_default_matmul_precision", "float32")`,
- the forward pass of the component is vmapped along a second (e.g. batch) dimension, and
- it is executed on an A100 GPU (the CPU backend does not seem to be affected).
This is a repost of issue #24909, which seems to have gotten lost.
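For reference, a minimal sketch of the workaround. Only the global config line is what I actually tested; the per-call `precision` argument shown in `forward_mlp_hp` (a hypothetical helper, not part of the repro below) is an untested alternative:

```
import jax
import jax.numpy as jnp
from jax.lax import Precision

# Workaround confirmed in this report: raise the default matmul precision globally.
jax.config.update("jax_default_matmul_precision", "float32")

# Untested alternative sketch: request full float32 precision per dot product.
def forward_mlp_hp(params, x):
    W1, b1, W2, b2 = params
    x = jnp.dot(x, W1, precision=Precision.HIGHEST) + b1  # per-call precision
    x = jax.nn.relu(x)
    return jnp.dot(x, W2, precision=Precision.HIGHEST) + b2
```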
CODE:
```
import jax
import jax.numpy as jnp

# setting matmul precision prevents the error
# jax.config.update("jax_default_matmul_precision", "float32")

batch_size = 32  # => increased size leads to error
num_mlps = 10    # => doesn't seem to matter


def init_mlp(rng, in_size=100, hidden_size=50, out_size=10):
    """Initialize parameters for a 2-layer MLP."""
    rng_w1, rng_b1, rng_w2, rng_b2 = jax.random.split(rng, 4)
    W1 = jax.random.normal(rng_w1, (in_size, hidden_size))
    b1 = jax.random.normal(rng_b1, (hidden_size,))
    W2 = jax.random.normal(rng_w2, (hidden_size, out_size))
    b2 = jax.random.normal(rng_b2, (out_size,))
    return (W1, b1, W2, b2)


def forward_mlp(params, x):
    """Single forward pass for the MLP."""
    W1, b1, W2, b2 = params
    x = jnp.dot(x, W1) + b1
    x = jax.nn.relu(x)
    x = jnp.dot(x, W2) + b2
    return x


# @jax.jit  # => jit-ing leads to error
def infer_mlp(mlps, x):
    """
    Apply each MLP in `mlps` to a corresponding data vector in `x`.

    mlps: (num_mlps,) pytree of parameters
    x: (num_mlps, in_size)
    """
    return jax.vmap(forward_mlp, in_axes=(0, 0))(mlps, x)


@jax.jit  # => jit-ing leads to error
def infer_batch_mlps(mlps, data):
    """
    For each batch element, call `infer_mlp(mlps, x)`.

    data: (batch_size, num_mlps, in_size)
    """
    return jax.vmap(lambda x: infer_mlp(mlps, x))(data)


# --- Main script ---
key = jax.random.PRNGKey(0)

# Create multiple MLP parameters
mlp_keys = jax.random.split(key, num_mlps)
mlps = jax.vmap(init_mlp)(mlp_keys)  # (num_mlps,) stack of parameter pytrees

# Create random input data
data_key = jax.random.PRNGKey(1)
data = jax.random.normal(data_key, (batch_size, num_mlps, 100))

# Run inference
output = infer_batch_mlps(mlps, data)
print("Output shape:", output.shape)
```
OUTPUT:
```
Output shape: (32, 10, 10)
E0208 13:42:49.742799 82190 buffer_comparator.cc:157] Difference at 16: 0, expected 31.792
E0208 13:42:49.742839 82190 buffer_comparator.cc:157] Difference at 17: 0, expected 32.434
E0208 13:42:49.742842 82190 buffer_comparator.cc:157] Difference at 18: 0, expected 31.5442
E0208 13:42:49.742845 82190 buffer_comparator.cc:157] Difference at 19: 0, expected 32.2899
E0208 13:42:49.742847 82190 buffer_comparator.cc:157] Difference at 20: 0, expected 31.9846
E0208 13:42:49.742850 82190 buffer_comparator.cc:157] Difference at 21: 0, expected 31.2843
E0208 13:42:49.742852 82190 buffer_comparator.cc:157] Difference at 22: 0, expected 31.597
E0208 13:42:49.742855 82190 buffer_comparator.cc:157] Difference at 23: 0, expected 31.9733
E0208 13:42:49.742857 82190 buffer_comparator.cc:157] Difference at 24: 0, expected 32.4919
E0208 13:42:49.742860 82190 buffer_comparator.cc:157] Difference at 25: 0, expected 28.697
2025-02-08 13:42:49.742874: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0208 13:42:49.743367 82190 buffer_comparator.cc:157] Difference at 16: 0, expected 31.792
E0208 13:42:49.743380 82190 buffer_comparator.cc:157] Difference at 17: 0, expected 32.434
E0208 13:42:49.743383 82190 buffer_comparator.cc:157] Difference at 18: 0, expected 31.5442
E0208 13:42:49.743385 82190 buffer_comparator.cc:157] Difference at 19: 0, expected 32.2899
E0208 13:42:49.743388 82190 buffer_comparator.cc:157] Difference at 20: 0, expected 31.9846
E0208 13:42:49.743390 82190 buffer_comparator.cc:157] Difference at 21: 0, expected 31.2843
E0208 13:42:49.743393 82190 buffer_comparator.cc:157] Difference at 22: 0, expected 31.597
E0208 13:42:49.743395 82190 buffer_comparator.cc:157] Difference at 23: 0, expected 31.9733
E0208 13:42:49.743397 82190 buffer_comparator.cc:157] Difference at 24: 0, expected 32.4919
E0208 13:42:49.743400 82190 buffer_comparator.cc:157] Difference at 25: 0, expected 28.697
2025-02-08 13:42:49.743404: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference. This is likely a bug/unexpected loss of precision.
```
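To quantify the mismatch, one could compare the default-precision result against a float32-precision reference. A minimal sketch (my own addition, reusing `infer_batch_mlps`, `mlps`, `data`, and `output` from the script above, and assuming the precision setting triggers a retrace of the jitted function):

```
# Recompute with full float32 matmul precision and compare against `output`.
with jax.default_matmul_precision("float32"):
    reference = infer_batch_mlps(mlps, data)
print("max abs diff:", jnp.max(jnp.abs(output - reference)))
```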
### System info (python version, jaxlib version, accelerator, etc.)
SYSTEM INFO:
```
OS: x86_64 GNU/Linux
GPU: NVIDIA A100-SXM4-40GB
NVIDIA-SMI 535.216.03
Driver Version: 535.216.03
CUDA Version: 12.2
```
CONDA ENV:
```
channels:
- nvidia
- pytorch
- conda-forge
dependencies:
- python=3.12
- pip=24.3
- setuptools=75.6
- numpy=2.2
- scipy=1.14
- scikit-learn=1.6
- pytorch=2.5.1=*cpu*
- torchinfo=1.8
- jax=0.4.35
- jaxlib=0.4.35=*cuda126*
- optax=0.2.3
- equinox=0.11.10
- jaxtyping=0.2
- chex=0.1
- hydra-core=1.3
- tqdm=4.67
- pandas=2.2
- matplotlib=3.9
- seaborn=0.13
- jupyterlab=4.3
- ipywidgets=8.1
- h5py=3.12
- urllib3=2.2
- pytest=8.3
- pre-commit=4.0
- wandb=0.19
- kaggle=1.6
```
PRINT ENV INFO:
```
jax: 0.4.35
jaxlib: 0.4.35
numpy: 2.2.0
python: 3.12.8 | packaged by conda-forge | (main, Dec 5 2024, 14:24:40) [GCC 13.3.0]
device info: NVIDIA A100-SXM4-40GB-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='ravg1001', release='6.4.0-150600.23.38-default', version='#1 SMP PREEMPT_DYNAMIC Thu Feb 6 08:53:28 UTC 2025 (cb92f8c)', machine='x86_64')
$ nvidia-smi
Mon Mar 17 10:19:38 2025
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.230.02             Driver Version: 535.230.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-40GB          On  | 00000000:31:00.0 Off |                    0 |
| N/A   27C    P0              62W / 400W |    426MiB / 40960MiB |      2%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     88869      C   ...jax_pytorch_env_linux_v3/bin/python       416MiB |
+---------------------------------------------------------------------------------------+
```