
~2x perf improvement beating PyTorch (cublasLt, TF32, CUDA graphs, kernel fusion, etc…) #89

Open
wants to merge 8 commits into master
Free memory properly, remove param compression because it's useless
ademeure committed Apr 12, 2024
commit 55889c6f228cd902ab98a15152f8bb11d499a33e
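
For context: the allocateCompressible / freeCompressible helpers this commit calls are defined earlier in train_gpt2.cu and do not appear in this diff. As a rough sketch of what the allocation side looks like, modelled on NVIDIA's cudaCompressibleMemory sample rather than the PR's exact code (error checking omitted), it routes the allocation through the CUDA virtual memory management driver API so the physical pages can be marked compressible:

```c
#include <cuda.h>          // CUDA driver API (cuMem*)
#include <cuda_runtime.h>

// Sketch only: assumes a CUDA context already exists (e.g. after cudaSetDevice)
// and that we are on device 0, matching the cudaSetDevice(0) added in main() below.
cudaError_t allocateCompressible(void **adr, size_t size, bool UseCompressibleMemory) {
    CUmemAllocationProp prop = {};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = 0;
    if (UseCompressibleMemory) {
        prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_GENERIC;
    }

    // physical allocations must be a multiple of the allocation granularity
    size_t granularity = 0;
    cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
    size = ((size + granularity - 1) / granularity) * granularity;

    CUdeviceptr dptr;
    CUmemGenericAllocationHandle handle;
    cuMemAddressReserve(&dptr, size, 0, 0, 0); // reserve a virtual address range
    cuMemCreate(&handle, size, &prop, 0);      // physical memory, possibly compressed
    cuMemMap(dptr, size, 0, handle, 0);        // map the physical pages into the range
    cuMemRelease(handle);                      // the mapping keeps the allocation alive

    CUmemAccessDesc accessDesc = {};
    accessDesc.location = prop.location;
    accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    cuMemSetAccess(dptr, size, &accessDesc, 1); // enable read/write access

    *adr = (void *)dptr;
    return cudaSuccess;
}
```

Memory obtained this way is mapped through cuMemMap, so a plain cudaFree cannot release it, and the unmap path needs the mapped size back; that is what the gpt2_free changes further down deal with.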
train_gpt2.cu: 24 changes (17 additions, 7 deletions)
@@ -16,12 +16,12 @@ GPT-2 Transformer Neural Net trained in raw CUDA
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>

-//#define ENABLE_PARAM_COMPRESSION
 #define ENABLE_ACTIVATION_COMPRESSION
 //#define MASK_ONE_BYTE_COMPRESSION
 //#define MASK_TWO_BYTES_COMPRESSION
 //#define MASK_THREE_BYTES_COMPRESSION
 //#define MASK_ALL_BYTES_COMPRESSION
+size_t activation_size = 0; // We need to keep track of the activation size for freeing the memory *sigh*

 // ----------------------------------------------------------------------------
 // CUDA utils & global variables
@@ -845,11 +845,8 @@ float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_sizes
     // malloc all parameters all at once on the device
     float* params_memory;
     if (on_device) {
-#if defined(ENABLE_PARAM_COMPRESSION)
-        allocateCompressible((void**)&params_memory, num_parameters * sizeof(float), true);
-#else
         cudaCheck(cudaMalloc((void**)&params_memory, num_parameters * sizeof(float)));
-#endif
+        cudaCheckErrors();
     } else {
         params_memory = (float*)malloc(num_parameters * sizeof(float));
     }
@@ -864,6 +861,12 @@ float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_sizes
         *(ptrs[i]) = params_memory_iterator;
         params_memory_iterator += param_sizes[i];
     }
+
+    // Get size of nv_bfloat16
+    size_t nv_bfloat16_size = sizeof(nv_bfloat16);
+    // printf it
+    printf("nv_bfloat16 size: %zu\n", nv_bfloat16_size);
+
     return params_memory;
 }

@@ -906,7 +909,8 @@ float* malloc_and_point_activations(ActivationTensors* acts, size_t* act_sizes)
     float* acts_memory;

 #if defined(ENABLE_ACTIVATION_COMPRESSION)
-    allocateCompressible((void**)&acts_memory, num_activations * sizeof(float), true);
+    activation_size = num_activations * sizeof(float);
+    allocateCompressible((void**)&acts_memory, activation_size, true);
 #else
     cudaCheck(cudaMalloc((void**)&acts_memory, num_activations * sizeof(float)));
 #endif
@@ -1237,11 +1241,16 @@ void gpt2_free(GPT2 *model) {
     cudaCheck(cudaFree(model->grads_memory));
     cudaCheck(cudaFree(model->m_memory));
     cudaCheck(cudaFree(model->v_memory));
-    cudaCheck(cudaFree(model->acts_memory));
     cudaCheck(cudaFree(model->grads_acts_memory));
     cudaCheck(cudaFree(model->inputs));
     cudaCheck(cudaFree(model->targets));

+#if defined(ENABLE_ACTIVATION_COMPRESSION)
+    freeCompressible((void**)&model->acts_memory, activation_size, true);
+#else
+    cudaCheck(cudaFree(model->acts_memory));
+#endif
+
     // free the cublas handles and CUDA stream
     if (cublaslt_handle != NULL) {
         cublasCheck(cublasLtDestroy(cublaslt_handle));
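
cuMemUnmap and cuMemAddressFree both need the size that was mapped, which plain cudaFree never did; that is why this commit records it in the activation_size global at allocation time. A matching sketch of the free path, under the same assumptions as the allocation sketch above and with the (void **) signature seen at the call site so the caller's pointer can be nulled:

```c
// Sketch only, mirroring the allocateCompressible sketch above;
// the PR's actual helper may differ in details.
cudaError_t freeCompressible(void **adr, size_t size, bool UseCompressibleMemory) {
    // round the size up exactly as allocateCompressible did
    CUmemAllocationProp prop = {};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = 0;
    if (UseCompressibleMemory) {
        prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_GENERIC;
    }
    size_t granularity = 0;
    cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
    size = ((size + granularity - 1) / granularity) * granularity;

    CUdeviceptr dptr = (CUdeviceptr)(*adr);
    cuMemUnmap(dptr, size);       // undo cuMemMap
    cuMemAddressFree(dptr, size); // release the reserved virtual address range
    *adr = NULL;                  // null out the caller's pointer, hence the void **
    return cudaSuccess;
}
```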
@@ -1369,6 +1378,7 @@ int sample_mult(float* probabilities, int n, float coin) {
 // ----------------------------------------------------------------------------
 // main training loop
 int main() {
+    cudaCheck(cudaSetDevice(0));

     // build the GPT-2 model from a checkpoint
     GPT2 model;