~2x perf improvement beating PyTorch (cublasLt, TF32, CUDA graphs, kernel fusion, etc…) #89
This improves performance on my local RTX 4090 from ~65ms to ~34ms (while PyTorch takes ~36ms!):
```
ORIGINAL: step 1: train loss 4.406481 (took 64.890952 ms)
OPTIMISED: step 1: train loss 4.406351 (took 34.064025 ms)
PYTORCH: iteration 1, loss: 4.579084396362305, time: 36.545ms
```
Tested on PyTorch 2.2.2+cu121 with driver 550.54.14 (and nvcc 12.1 to keep it like-for-like; possibly faster with 12.4). TF32 is enabled both for my optimised code and for PyTorch, which is run with the command from README.md:
```
python train_gpt2.py --inference_only 1 --write_tensors 0 --sequence_length 1024 --batch_size 4 --compile 1 --tensorcores 1
```
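For context, on the CUDA side opting in to TF32 typically means switching the cuBLAS handle's math mode and selecting the TF32 compute type for cuBLASLt matmul descriptors. A minimal sketch of that pattern is below; the handle names and the `enable_tf32()` helper are illustrative placeholders, not the actual code in this PR:

```cuda
// Hedged sketch: how TF32 is typically requested from cuBLAS / cuBLASLt.
// Handle names and enable_tf32() are placeholders for illustration only.
#include <cublas_v2.h>
#include <cublasLt.h>

static cublasHandle_t cublas_handle;
static cublasLtHandle_t cublaslt_handle;
static cublasComputeType_t cublas_compute_type = CUBLAS_COMPUTE_32F;

void enable_tf32(int use_tf32) {
    cublasCreate(&cublas_handle);
    cublasLtCreate(&cublaslt_handle);
    if (use_tf32) {
        // classic cuBLAS: flip the handle's math mode to the TF32 tensor-op path
        cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH);
        // cuBLASLt: TF32 is requested via the compute type later passed to
        // cublasLtMatmulDescCreate() for each matmul descriptor
        cublas_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
}
```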
When TF32 is enabled, I had to increase the tolerance in test_gpt2.cu from 1e-2 to 1.0f, which results in the following output:
The biggest performance gains come from:
One possible issue with this commit is the large number of new global static CUDA variables at the top of train_gpt2.cu. They exist to avoid threading lots of new arguments through the code: for example, every kernel launch now has to use a custom CUDA stream instead of the default one so that CUDA graphs can be used (the same reason the cuBLAS(Lt) handles can now only be created once).
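For reference, the stream-capture pattern behind this looks roughly like the sketch below; `main_stream`, `run_with_graph()` and the commented-out `forward_pass()` call are illustrative placeholders rather than the PR's actual code:

```cuda
// Hedged sketch of CUDA graph capture/replay on a non-default stream.
// Every kernel and cuBLAS call inside the captured region must be issued on
// main_stream, which is why the stream and handles end up as globals.
#include <cuda_runtime.h>

static cudaStream_t main_stream;
static cudaGraph_t graph;
static cudaGraphExec_t graph_exec;

void run_with_graph(int num_steps) {
    cudaStreamCreate(&main_stream);

    // capture one iteration's worth of launches into a graph
    cudaStreamBeginCapture(main_stream, cudaStreamCaptureModeGlobal);
    // forward_pass(main_stream);  // placeholder for the real launches
    cudaStreamEndCapture(main_stream, &graph);

    // instantiate once, then replay cheaply every step (CUDA 12 signature)
    cudaGraphInstantiate(&graph_exec, graph, 0);
    for (int step = 0; step < num_steps; step++) {
        cudaGraphLaunch(graph_exec, main_stream);
    }
    cudaStreamSynchronize(main_stream);

    cudaGraphExecDestroy(graph_exec);
    cudaGraphDestroy(graph);
    cudaStreamDestroy(main_stream);
}
```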
I also didn't include the associated changes to the standalone .cu files for now, partly because they became a bit of a mess due to the cuBLAS(Lt) handle problem above and it depends on whether that needs to be refactored, but I'm happy to provide them tomorrow as well if needed.
This hasn't been tested on A100/H100 or with different CUDA/PyTorch versions yet, so there's a very strong chance it doesn't match PyTorch on other configurations, but that doesn't sound as cool as just saying it beats PyTorch, so that's what I'm going with... ;)
Future work ideas: