~2x perf improvement beating PyTorch (cublasLt, TF32, CUDA graphs, kernel fusion, etc…) #89

Open · wants to merge 8 commits into master
Conversation

ademeure · Contributor
This improves performance on my local RTX 4090 from ~65ms to ~34ms (while PyTorch takes ~36ms!)

ORIGINAL: step 1: train loss 4.406481 (took 64.890952 ms)
OPTIMISED: step 1: train loss 4.406351 (took 34.064025 ms)
PYTORCH: iteration 1, loss: 4.579084396362305, time: 36.545ms

Tested on PyTorch 2.2.2+cu121 with driver 550.54.14 (and nvcc 12.1 to keep it like-for-like; possibly faster with 12.4). TF32 is enabled for both my optimised code and PyTorch using the command from README.md: python train_gpt2.py --inference_only 1 --write_tensors 0 --sequence_length 1024 --batch_size 4 --compile 1 --tensorcores 1

With TF32 enabled, I had to increase the tolerance in test_gpt2.cu from 1e-2f to 1.0f, which results in the following output:

-43.431705 -43.351101
-39.836426 -39.763416
-43.066010 -42.994701
OK (LOGITS)
LOSS OK: 5.269499 5.270009

The biggest performance gains come from:

  1. Using cuBLASLt for matmul_forward with fused bias and GELU (selecting an algorithm via cublasLtMatmulAlgoGetHeuristic).
  2. Optional TF32 for cuBLASLt/cuBLAS to match PyTorch precision (requires the looser tolerance in test_gpt2.cu).
  3. Optimised softmax kernels: a first version with a fused scale kernel for attention and a hardcoded block size of 512 threads, and a second version for very large C (many loop iterations) using some advanced loop-unrolling tricks.
  4. CUDA graphs on a non-default stream to maximise GPU/CPU parallelism (including cudaMemcpyAsync).
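For anyone curious, item 1 can be sketched roughly as follows. This is a simplified illustration, not the PR's actual code: error checking and resource destruction are omitted, the function name is hypothetical, and the specific epilogue/compute-type choices (CUBLASLT_EPILOGUE_GELU_BIAS, CUBLAS_COMPUTE_32F_FAST_TF32 for TF32) are my assumptions about a reasonable setup:

```cuda
#include <cublasLt.h>

// Hypothetical helper: one C = A*B matmul with bias add and GELU fused into
// the cuBLASLt epilogue, TF32 tensor-core math, and a heuristic-picked algo.
void matmul_bias_gelu(cublasLtHandle_t handle, cudaStream_t stream,
                      float* C, const float* A, const float* B, const float* bias,
                      int m, int n, int k, void* workspace, size_t workspace_size) {
    cublasLtMatmulDesc_t op_desc;
    // TF32: FP32 storage, tensor-core math with reduced-precision mantissa
    cublasLtMatmulDescCreate(&op_desc, CUBLAS_COMPUTE_32F_FAST_TF32, CUDA_R_32F);

    // Fuse bias + GELU into the matmul epilogue
    cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
    cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                   &epilogue, sizeof(epilogue));
    cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                                   &bias, sizeof(bias));

    cublasLtMatrixLayout_t a_layout, b_layout, c_layout;
    cublasLtMatrixLayoutCreate(&a_layout, CUDA_R_32F, m, k, m);
    cublasLtMatrixLayoutCreate(&b_layout, CUDA_R_32F, k, n, k);
    cublasLtMatrixLayoutCreate(&c_layout, CUDA_R_32F, m, n, m);

    // Ask cuBLASLt for the best algorithm for this shape + epilogue
    cublasLtMatmulPreference_t pref;
    cublasLtMatmulPreferenceCreate(&pref);
    cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                         &workspace_size, sizeof(workspace_size));
    cublasLtMatmulHeuristicResult_t heuristic;
    int found = 0;
    cublasLtMatmulAlgoGetHeuristic(handle, op_desc, a_layout, b_layout, c_layout,
                                   c_layout, pref, 1, &heuristic, &found);

    float alpha = 1.0f, beta = 0.0f;
    cublasLtMatmul(handle, op_desc, &alpha, A, a_layout, B, b_layout, &beta,
                   C, c_layout, C, c_layout, &heuristic.algo,
                   workspace, workspace_size, stream);
    // (desc/layout/preference destruction omitted for brevity)
}
```

The win is that bias add and GELU no longer need their own kernel launches or extra trips through global memory; they run inside the matmul epilogue.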

One possible issue with this commit is the huge number of new global static variables for CUDA at the top of train_gpt2.cu. This avoids passing lots of new arguments all over the place: for example, every kernel launch now has to use a custom CUDA stream instead of the default one in order to work with CUDA graphs (the same reason the cuBLAS(Lt) handles can now only be created once).
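The stream requirement comes from how graph capture works: everything in the step has to be recorded on one non-default stream. A minimal sketch of the pattern (the step function and variable names here are placeholders, not the PR's actual code; the 3-argument cudaGraphInstantiate is the CUDA 12 signature):

```cuda
#include <cuda_runtime.h>

cudaStream_t stream;          // non-default stream: required for capture
cudaGraph_t graph;
cudaGraphExec_t graph_exec;

void run_step_with_graph(bool first_step) {
    if (first_step) {
        // Record every kernel launch / cudaMemcpyAsync issued on `stream`
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
        forward_backward_update(stream);   // placeholder for the training step
        cudaStreamEndCapture(stream, &graph);
        cudaGraphInstantiate(&graph_exec, graph, 0);
    }
    // Replay the whole step with a single launch; the CPU can run ahead
    cudaGraphLaunch(graph_exec, stream);
}
```

A launch on the default stream during capture would fail (or implicitly synchronise), which is why every kernel and async copy has to take the custom stream as an argument or see it as a global.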

Also, I didn't include the associated changes to the standalone .cu files for now, partly because it became a bit of a mess with the cuBLAS(Lt) handle problem above, and it depends on whether things need to be refactored. I'm happy to provide that tomorrow if needed.


Not tested on A100/H100 or with different CUDA/PyTorch versions yet, so there's a very strong chance it doesn't match PyTorch on other configurations; but that doesn't sound as cool as just saying it beats PyTorch, so that's what I'm going with... ;)
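To illustrate the fused scale + softmax idea from item 3 above, here is my own simplified sketch (not the PR's actual kernel): one block per row, the hardcoded 512-thread block size, and the 1/sqrt(d) attention scale folded into the max/exp passes so no separate scale kernel is needed:

```cuda
#include <cuda_runtime.h>
#include <math.h>

// Softmax over each row of an (N, C) matrix of pre-softmax attention scores,
// with the attention scale fused in. Launch: <<<N, 512>>>.
__global__ void scaled_softmax_forward(float* out, const float* inp,
                                       int C, float scale) {
    __shared__ float shared[512];               // hardcoded block size of 512
    const float* x = inp + (size_t)blockIdx.x * C;
    float* y = out + (size_t)blockIdx.x * C;
    int tid = threadIdx.x;

    // 1) max reduction (numerical stability), scale fused into the loads
    float maxval = -INFINITY;
    for (int i = tid; i < C; i += 512)
        maxval = fmaxf(maxval, x[i] * scale);
    shared[tid] = maxval;
    __syncthreads();
    for (int s = 256; s > 0; s >>= 1) {
        if (tid < s) shared[tid] = fmaxf(shared[tid], shared[tid + s]);
        __syncthreads();
    }
    maxval = shared[0];
    __syncthreads();

    // 2) exponentiate and sum
    float sum = 0.0f;
    for (int i = tid; i < C; i += 512) {
        float e = expf(x[i] * scale - maxval);
        y[i] = e;
        sum += e;
    }
    shared[tid] = sum;
    __syncthreads();
    for (int s = 256; s > 0; s >>= 1) {
        if (tid < s) shared[tid] += shared[tid + s];
        __syncthreads();
    }
    float inv_sum = 1.0f / shared[0];

    // 3) normalise
    for (int i = tid; i < C; i += 512)
        y[i] *= inv_sum;
}
```

The second optimised version mentioned above (for very large C) would additionally unroll the strided loops; that is not shown here.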

Future work ideas:

  • Benchmark on A100 and/or H100 using CUDA 12.4 and the latest PyTorch
  • Check whether there's a faster way to do cublasSgemmStridedBatched with cublasLt
  • Look into optimising away the permute/unpermute kernels (could it be "free" with the TMA on H100?)
  • Investigate H100/AD102 lossless memory compression and/or cache-residency controls (interesting info on this at GTC; extremely sensitive to access patterns, so it might not work in practice)

@karpathy · Owner
(sounds really great! processing through this now)

@karpathy · Owner
I think we'll want to break this up into chunks, a lot of really good stuff here.

@tantara commented Apr 13, 2024

One more data point for your PR: I ran it on my 4090 (torch==2.2.2, CUDA 12.1, NVIDIA driver 530.30.02).

  • llm.c (main branch): step 1: train loss 4.406586 (took 37.426351 ms) with ./train_gpt2cu
  • llm.c (PR #89): step 1: train loss 4.406351 (took 29.146125 ms) 🔥 with ./train_gpt2cu
  • PyTorch: iteration 1, loss: 4.579084396362305, time: 33.262ms with python train_gpt2.py as in the thread above
