
Results do not match the reference. This is likely a bug/unexpected loss of precision. #27188

@AaronSpieler

Description

Hello,

I get the following error when running the code below:
E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference.

However, it only occurs if:

  1. I don't set `jax.config.update("jax_default_matmul_precision", "float32")` (a workaround sketch follows below),
  2. the forward pass is vmapped along a second (e.g. batch) dimension, and
  3. it is executed on an A100 GPU (the CPU backend does not seem to be affected).

This is a repost of issue #24909, which seems to have gotten lost.
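For completeness, a minimal sketch of the precision workaround (not part of the original repro; it assumes the standard `jax_default_matmul_precision` config option and the `precision=` argument of `jnp.dot`):

```
import jax
import jax.numpy as jnp

# Option 1: force float32 matmul precision globally; with this set,
# the autotuner error below does not appear.
jax.config.update("jax_default_matmul_precision", "float32")

# Option 2: request higher precision only for the dot products involved.
def dot_f32(x, W):
    return jnp.dot(x, W, precision=jax.lax.Precision.HIGHEST)
```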

CODE:

```
import jax
import jax.numpy as jnp

# setting matmul precision prevents the error
#jax.config.update("jax_default_matmul_precision", "float32")

batch_size = 32  # => increased size leads to error
num_mlps = 10    # => doesn't seem to matter

def init_mlp(rng, in_size=100, hidden_size=50, out_size=10):
    """Initialize parameters for a 2-layer MLP."""
    rng_w1, rng_b1, rng_w2, rng_b2 = jax.random.split(rng, 4)
    W1 = jax.random.normal(rng_w1, (in_size, hidden_size))
    b1 = jax.random.normal(rng_b1, (hidden_size,))
    W2 = jax.random.normal(rng_w2, (hidden_size, out_size))
    b2 = jax.random.normal(rng_b2, (out_size,))
    return (W1, b1, W2, b2)

def forward_mlp(params, x):
    """Single forward pass for the MLP."""
    W1, b1, W2, b2 = params
    x = jnp.dot(x, W1) + b1
    x = jax.nn.relu(x)
    x = jnp.dot(x, W2) + b2
    return x

#@jax.jit # => jit-ing leads to error
def infer_mlp(mlps, x):
    """
    Apply each MLP in `mlps` to a corresponding data vector in `x`.
    mlps: (num_mlps,) pytree of parameters
    x:    (num_mlps, in_size)
    """
    return jax.vmap(forward_mlp, in_axes=(0, 0))(mlps, x)

@jax.jit # => jit-ing leads to error
def infer_batch_mlps(mlps, data):
    """
    For each batch element, call `infer_mlp(mlps, x)`.
    data: (batch_size, num_mlps, in_size)
    """
    return jax.vmap(lambda x: infer_mlp(mlps, x))(data)

# --- Main script ---
key = jax.random.PRNGKey(0)

# Create multiple MLP parameters
mlp_keys = jax.random.split(key, num_mlps)
mlps = jax.vmap(init_mlp)(mlp_keys)  # shape: (num_mlps,) of parameter pytree

# Create random input data
data_key = jax.random.PRNGKey(1)
data = jax.random.normal(data_key, (batch_size, num_mlps, 100))

# Run inference
output = infer_batch_mlps(mlps, data)
print("Output shape:", output.shape)

OUTPUT:

Output shape: (32, 10, 10)
E0208 13:42:49.742799   82190 buffer_comparator.cc:157] Difference at 16: 0, expected 31.792
E0208 13:42:49.742839   82190 buffer_comparator.cc:157] Difference at 17: 0, expected 32.434
E0208 13:42:49.742842   82190 buffer_comparator.cc:157] Difference at 18: 0, expected 31.5442
E0208 13:42:49.742845   82190 buffer_comparator.cc:157] Difference at 19: 0, expected 32.2899
E0208 13:42:49.742847   82190 buffer_comparator.cc:157] Difference at 20: 0, expected 31.9846
E0208 13:42:49.742850   82190 buffer_comparator.cc:157] Difference at 21: 0, expected 31.2843
E0208 13:42:49.742852   82190 buffer_comparator.cc:157] Difference at 22: 0, expected 31.597
E0208 13:42:49.742855   82190 buffer_comparator.cc:157] Difference at 23: 0, expected 31.9733
E0208 13:42:49.742857   82190 buffer_comparator.cc:157] Difference at 24: 0, expected 32.4919
E0208 13:42:49.742860   82190 buffer_comparator.cc:157] Difference at 25: 0, expected 28.697
2025-02-08 13:42:49.742874: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0208 13:42:49.743367   82190 buffer_comparator.cc:157] Difference at 16: 0, expected 31.792
E0208 13:42:49.743380   82190 buffer_comparator.cc:157] Difference at 17: 0, expected 32.434
E0208 13:42:49.743383   82190 buffer_comparator.cc:157] Difference at 18: 0, expected 31.5442
E0208 13:42:49.743385   82190 buffer_comparator.cc:157] Difference at 19: 0, expected 32.2899
E0208 13:42:49.743388   82190 buffer_comparator.cc:157] Difference at 20: 0, expected 31.9846
E0208 13:42:49.743390   82190 buffer_comparator.cc:157] Difference at 21: 0, expected 31.2843
E0208 13:42:49.743393   82190 buffer_comparator.cc:157] Difference at 22: 0, expected 31.597
E0208 13:42:49.743395   82190 buffer_comparator.cc:157] Difference at 23: 0, expected 31.9733
E0208 13:42:49.743397   82190 buffer_comparator.cc:157] Difference at 24: 0, expected 32.4919
E0208 13:42:49.743400   82190 buffer_comparator.cc:157] Difference at 25: 0, expected 28.697
2025-02-08 13:42:49.743404: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference. This is likely a bug/unexpected loss of precision.
```
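To check whether the autotuner warning corresponds to an actual difference in the final results, here is a minimal sketch (not part of the original report; it assumes the repro above has already been run and uses the standard `jax.default_matmul_precision` context manager):

```
# Rerun the same jitted function with float32 matmul precision and compare
# against the default-precision result from the repro above.
output_default = infer_batch_mlps(mlps, data)

with jax.default_matmul_precision("float32"):
    output_f32 = infer_batch_mlps(mlps, data)

print("max abs difference:", jnp.abs(output_default - output_f32).max())
```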



### System info (python version, jaxlib version, accelerator, etc.)

SYSTEM INFO:
OS: x86_64 GNU/Linux
GPU: NVIDIA A100-SXM4-40GB
NVIDIA-SMI 535.216.03
Driver Version: 535.216.03
CUDA Version: 12.2 

CONDA ENV:
```
channels:
  - nvidia
  - pytorch
  - conda-forge
dependencies:
  - python=3.12
  - pip=24.3
  - setuptools=75.6
  - numpy=2.2
  - scipy=1.14
  - scikit-learn=1.6
  - pytorch=2.5.1=*cpu*
  - torchinfo=1.8
  - jax=0.4.35
  - jaxlib=0.4.35=*cuda126*
  - optax=0.2.3
  - equinox=0.11.10
  - jaxtyping=0.2
  - chex=0.1
  - hydra-core=1.3
  - tqdm=4.67
  - pandas=2.2
  - matplotlib=3.9
  - seaborn=0.13
  - jupyterlab=4.3
  - ipywidgets=8.1
  - h5py=3.12
  - urllib3=2.2
  - pytest=8.3
  - pre-commit=4.0
  - wandb=0.19
  - kaggle=1.6
```

PRINT ENV INFO:
```
jax:    0.4.35
jaxlib: 0.4.35
numpy:  2.2.0
python: 3.12.8 | packaged by conda-forge | (main, Dec  5 2024, 14:24:40) [GCC 13.3.0]
device info: NVIDIA A100-SXM4-40GB-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='ravg1001', release='6.4.0-150600.23.38-default', version='#1 SMP PREEMPT_DYNAMIC Thu Feb  6 08:53:28 UTC 2025 (cb92f8c)', machine='x86_64')


$ nvidia-smi
Mon Mar 17 10:19:38 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.230.02             Driver Version: 535.230.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-40GB          On  | 00000000:31:00.0 Off |                    0 |
| N/A   27C    P0              62W / 400W |    426MiB / 40960MiB |      2%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     88869      C   ...jax_pytorch_env_linux_v3/bin/python      416MiB |
+---------------------------------------------------------------------------------------+
```
