CUDA编程之GEMM优化

本文介绍CUDA编程中GEMM的优化方法,重点讲解层次化的GEMM结构,包括BlockTile、WarpTile和ThreadTile三个层次,并给出具体实现示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


前言

最近由于工作需要,研究了一下CUDA编程中的GEMM的优化,主要是学习了GEMM优化的常用方法。学习过程中主要参考了CUTLASS官方博客的优化思路,这篇文章清晰的讲述了GEMM的主要优化方法,网上也有中文翻译版本,里面有些地方翻译的可能不是很准确,在阅读中文版本的时候最好能对照原文看一下。



CUTLASS简介

CUTLASS是nv推出的一个线性代数模板库,CUTLASS将gemm的各个部分抽象成一个个可复用的组件,用户可以通过组合这些组件自定义自己的高性能kernel。CUTLASS使用分块结构来实现gemm,这种分块结构是通过block分块,warp分块和thread分块的这种层次结构来实现的。这个层次结构对应了cuda的编程模型(全局内存->共享内存->寄存器)。

所以cutlass的核心是层次化的分块结构。目的是提高数据复用度。
在这里插入图片描述


层次化的GEMM结构

CUTLASS将gemm的分块层次分为三级:block tile,warp tile,thread tile

Block Tile(Block分块)

在这里插入图片描述
block分块表示每个线程块计算C矩阵的一个部分,计算每个C子块的时候,沿着K维度循环加载A和B的子块到共享内存,然后对A和B的子块执行矩阵乘,并将计算好的结果累加到C子块中。

Warp Tile(Warp分块)

在这里插入图片描述
warp分块表示每个warp计算block子块中的一个子块,计算子块的时候,沿着block分块的K维度循环从block分块的共享内存中加载对应的子块到寄存器中,将计算结果累加到C的子矩阵中。

Thread Tile(Thread分块)

在这里插入图片描述
thread分块表示每个thread计算warp子块中的一个子块,每个线程通过计算寄存器中的子块的矩阵外积,将计算结果累加到C的子矩阵中。


层次化的GEMM结构的实现

矩阵类的设计

为了便于访问矩阵中的元素,设计了如下的矩阵类。

template<typename T>
class Matrix
{
public:
    __device__ __host__ Matrix() = default;
    __device__ __host__ Matrix(const Matrix &) = default;
    __device__ __host__ Matrix& operator=(const Matrix &) = default;
    __device__ __host__ Matrix(T *_data,int _rows,int _cols,int _strideOfRow,int _strideOfCol):
                                    data(_data),
                                    rows(_rows),
                                    cols(_cols),
                                    strideOfRow(_strideOfRow),
                                    strideOfCol(_strideOfCol){}



    // 返回该矩阵所有字节数
    constexpr __device__ __host__ int GetNumberOfBytes() const
    {
        return rows*cols*sizeof(T);
    }

    // 返回该矩阵元素个数
    constexpr __device__ __host__ int GetNumberOfElements() const
    {
        return rows*cols;
    }

    // 访问某个元素,该元素的索引为二维逻辑索引:(rowIndex,colIndex)
    __device__ __host__ float &operator()(int rowIndex,int colIndex)
    {
        // 计算内存索引
        int memoryIndex=rowIndex*strideOfRow+colIndex*strideOfCol;

        return data[memoryIndex];
    }

    // 访问某个元素,该元素的索引为一维逻辑索引:(Index)
    __device__ __host__ float &operator()(int index)
    {
        // 转换为二维逻辑索引
        int colIndex=index%cols;
        int rowIndex=index/cols;

        // 计算内存索引
        int memoryIndex=rowIndex*strideOfRow+colIndex*strideOfCol;

        return data[memoryIndex];
    }



public:
    T *data = nullptr;// 数据指针
    int rows = 0;// 矩阵的行数
    int cols = 0;// 矩阵的列数
    int strideOfRow = 0;// 行步长
    int strideOfCol = 0;// 列步长

};

该矩阵类可以作为一个view,与原数据共享内存,避免数据拷贝带来的性能问题,同时提供了一维和二维索引方式来访问view中的元素。通过该类,可以方便快捷的访问矩阵中的元素。比如在block tile层次,每次沿着K维度循环访问一个分块的时候,我们可以先为A和B矩阵的分块创建view:
在这里插入图片描述
创建矩阵A和B的view的代码如下(BM,BN,BK等参数的含义见下文):

Matrix<float> A_View(A.data+blockIdx.y*BM*A.strideOfRow+i*BK,BM,BK,A.strideOfRow,A.strideOfCol);
Matrix<float> B_View(B.data+i*BK*B.strideOfRow+blockIdx.x*BN,BK,BN,B.strideOfRow,B.strideOfCol);

A_View和B_View分别表示图中粉色和黄色的子块,创建好了之后,如果要访问A子块中的第i个元素,可以通过A_View(i)的方式访问,如果要访问索引为(i,j)的元素,可以通过A_View(i,j)的方式访问,简化了数据访问。通过这种方式可以方便的将global memory中的元素拷贝到shared memory。

层次化GEMM结构的实现

由于warp在kernel中不容易显式表达,所以本文实现过程中,为了简化kernel便于理解,只实现了block tile和thread tile两个层次。
在这里插入图片描述

// 矩阵参数
#define _M 1024 // A矩阵的行数
#define _N 512 // B矩阵的列数
#define _K 256 // A矩阵的列数,B矩阵的行数

// 分块参数
#define BM 128 // block子块大小
#define BN 128
#define BK 8
#define TM 8 // thread子块大小
#define TN 8

// 定义向量化数据类型
template <typename T,int N>
struct VecType
{
    T data[N];
};
using VecDataType = VecType<float, BM*BK/(BN/TN*BM/TM)>; // BM*BK/(BN/TN*BM/TM): 表示每个线程需要处理的数据量

__global__ void BlockGEMM_V2(Matrix<float> A,Matrix<float> B,Matrix<float> C)
{
    // 每个线程的计算结果
    float c[TM][TN]={0.0};
    float a[TM]={0.0};
    float b[TN]={0.0};

    // block tile
    // 沿着K维度循环加载一个block中对应的A和B的数据到共享内存
    for(int i=0;i<A.cols/BK;++i)
    {
        // 每个block对应的全局内存中的A,B子块,即创建全局内存中A,B的view
        Matrix<float> A_View(A.data+blockIdx.y*BM*A.strideOfRow+i*BK,BM,BK,A.strideOfRow,A.strideOfCol);
        Matrix<float> B_View(B.data+i*BK*B.strideOfRow+blockIdx.x*BN,BK,BN,B.strideOfRow,B.strideOfCol);

        // 将A_View,B_View加载到共享内存
        // 以BM=BN=128,BK=TM=TN=8为例:由于一个block有16x16=256个线程,而A_View和B_View中一共有1024个元素,所以每个线程加载4个元素
        __shared__ float A_Shared[BM][BK];
        __shared__ float B_Shared[BK][BN];
        int startIndex=(BM*BK/(BN/TN*BM/TM))*(threadIdx.y*blockDim.x+threadIdx.x); // 每个线程读取的起始索引,也是vec数据的起始索引,BM*BK/(BN/TN*BM/TM)表示每个线程需要读取BM*BK/(BN/TN*BM/TM)个数据
        ((VecDataType*)((float*)A_Shared+startIndex))[0]=((VecDataType*)((float*)&A_View(startIndex)))[0];
        ((VecDataType*)((float*)B_Shared+startIndex))[0]=((VecDataType*)((float*)&B_View(startIndex)))[0];
        __syncthreads();

        // 每个thread对应的共享内存中的A_Shared,B_Shared的子块,即创建A_Shared,B_Shared的view
        Matrix<float> A_View_Shared((float *)A_Shared+threadIdx.y*TM*BK,TM,BK,BK,1);// 每个线程对应的共享内存中A和B的子块
        Matrix<float> B_View_Shared((float *)B_Shared+threadIdx.x*TN,BK,TN,BN,1);

        // thread tile
        for(int k=0;k<BK;++k)
        {
            // 先将A的一列和B的一行加载到寄存器
            for(int m=0;m<TM;++m)
            {
                a[m]=A_View_Shared(m,k);
            }
            for(int n=0;n<TN;++n)
            {
                b[n]=B_View_Shared(k,n);
            }

            // 使用寄存器计算
            for(int m=0;m<TM;++m)
            {
                for(int n=0;n<TN;++n)
                {
                    c[m][n]+=a[m]*b[n];
                }
            }
        }
        __syncthreads();

    }

    // 将每个线程计算好的结果写回到C矩阵
    // C_View为每个线程对应的全局内存的C矩阵子块,创建C矩阵的view
    Matrix<float> C_View(C.data+((blockIdx.y*BM+threadIdx.y*TM)*C.strideOfRow+blockIdx.x*BN+threadIdx.x*TN),TM,TN,C.strideOfRow,C.strideOfCol);
    for(int m=0;m<TM;++m)
    {
        for(int n=0;n<TN;++n)
        {
            C_View(m,n)=c[m][n];
        }
    }

}

Bank冲突

上面的实现中,在thread tile层级中,每个线程会读取共享内存中A子块中的一列,这样就会产生共享内存的bank冲突,这里采用的解决bank冲突的方案是:将A子块读取到共享内存的时候进行转置,这样每个线程读取A矩阵的元素的时候就按照行来读取,就不会发生冲突了。

__global__ void BlockGEMM_V2_BankConflict(Matrix<float> A,Matrix<float> B,Matrix<float> C)
{
    // 每个线程的计算结果
    float c[TM][TN]={0.0};
    float a[TM]={0.0};
    float b[TN]={0.0};

    // 沿着K维度循环加载一个block中对应的A和B的数据到共享内存
    for(int i=0;i<A.cols/BK;++i)
    {
        // 每个block对应的全局内存中的A,B子块,即创建全局内存中A,B的view
        Matrix<float> A_View(A.data+blockIdx.y*BM*A.strideOfRow+i*BK,BM,BK,A.strideOfRow,A.strideOfCol);
        Matrix<float> B_View(B.data+i*BK*B.strideOfRow+blockIdx.x*BN,BK,BN,B.strideOfRow,B.strideOfCol);

        // 将A_View,B_View加载到共享内存
        // 以BM=BN=128,BK=TM=TN=8为例:由于一个block有16x16=256个线程,而A_View和B_View中一共有1024个元素,所以每个线程加载4个元素
        __shared__ float A_Shared[BK][BM];
        __shared__ float B_Shared[BK][BN];
        int startIndex=(BM*BK/(BN/TN*BM/TM))*(threadIdx.y*blockDim.x+threadIdx.x); // 每个线程读取的起始索引,也是vec数据的起始索引,BM*BK/(BN/TN*BM/TM)表示每个线程需要读取BM*BK/(BN/TN*BM/TM)个数据
        VecDataType aElement=((VecDataType*)((float*)&A_View(startIndex)))[0];
        #pragma unroll 
        for(int j=0;j<(BM*BK/(BN/TN*BM/TM));++j)
        {
            A_Shared[startIndex%BK+j][startIndex/BK]=aElement.data[j];
        }
        ((VecDataType*)((float*)B_Shared+startIndex))[0]=((VecDataType*)((float*)&B_View(startIndex)))[0];
        __syncthreads();

        // 每个thread对应的共享内存中的A_Shared,B_Shared的子块,即创建A_Shared,B_Shared的view
        Matrix<float> A_View_Shared((float *)A_Shared+threadIdx.y*TM,BK,TM,BM,1);// 每个线程对应的共享内存中A和B的子块
        Matrix<float> B_View_Shared((float *)B_Shared+threadIdx.x*TN,BK,TN,BN,1);

        // 每个线程执行计算
        for(int k=0;k<BK;++k)
        {
            // 先将A的一行和B的一行加载到寄存器
            for(int m=0;m<TM;++m)
            {
                a[m]=A_View_Shared(k,m);
            }
            for(int n=0;n<TN;++n)
            {
                b[n]=B_View_Shared(k,n);
            }

            // 使用寄存器计算
            for(int m=0;m<TM;++m)
            {
                for(int n=0;n<TN;++n)
                {
                    c[m][n]+=a[m]*b[n];
                }
            }
        }
        __syncthreads();

    }

    // 将每个线程计算好的结果写回到C矩阵
    // C_View为每个线程对应的全局内存的C矩阵子块,创建C矩阵的view
    Matrix<float> C_View(C.data+((blockIdx.y*BM+threadIdx.y*TM)*C.strideOfRow+blockIdx.x*BN+threadIdx.x*TN),TM,TN,C.strideOfRow,C.strideOfCol);
    for(int m=0;m<TM;++m)
    {
        for(int n=0;n<TN;++n)
        {
            C_View(m,n)=c[m][n];
        }
    }

}

双缓冲

上面的实现中,沿着K维度循环的时候,先加载数据到共享内存,然后执行计算,计算单元需要等待所有数据加载到共享内存后才能执行计算,计算效率很低。采用双缓冲技术,可以将数据加载和计算同时执行,可以有效隐藏数据访问的延迟。
在这里插入图片描述

__global__ void BlockGEMM_V3(Matrix<float> A,Matrix<float> B,Matrix<float> C)
{
    // 每个线程的计算结果
    float c[TM][TN]={0.0};
    float a[TM]={0.0};
    float b[TN]={0.0};

    // 此时需要的共享内存是原来的2倍
    // 注意:读取和写入的时候第一个维度的索引是交错进行的
    __shared__ float A_Shared[2][BM][BK];
    __shared__ float B_Shared[2][BK][BN];

    // 预取(先读取第一个BK)
    Matrix<float> A_View(A.data+blockIdx.y*BM*A.strideOfRow+0*BK,BM,BK,A.strideOfRow,A.strideOfCol);
    Matrix<float> B_View(B.data+0*BK*B.strideOfRow+blockIdx.x*BN,BK,BN,B.strideOfRow,B.strideOfCol);
    int startIndex=(BM*BK/(BN/TN*BM/TM))*(threadIdx.y*blockDim.x+threadIdx.x); // 每个线程读取的起始索引,也是vec数据的起始索引,BM*BK/(BN/TN*BM/TM)表示每个线程需要读取BM*BK/(BN/TN*BM/TM)个数据
    ((VecDataType*)((float*)A_Shared+startIndex))[0]=((VecDataType*)((float*)&A_View(startIndex)))[0];
    ((VecDataType*)((float*)B_Shared+startIndex))[0]=((VecDataType*)((float*)&B_View(startIndex)))[0];
    __syncthreads();

    // 沿着K维度循环加载剩下的数据
    int indexOfRead,indexOfWrite;
    bool indexFlag=false;// 辅助变量,用来计算索引
    for(int i=1;i<A.cols/BK;++i)
    {
        // 计算索引,indexOfRead和indexOfWrite每次循环会交替变换,i=1时为indexOfRead=0,indexOfWrite=1,i=2时为indexOfRead=1,indexOfWrite=0
        indexOfRead = (int)indexFlag; // 读索引,即本次循环读取A_Shared[indexOfRead,:,:]和B_Shared[indexOfRead,:,:]中的数据执行计算
        indexOfWrite = 1-indexOfRead; // 写索引,即预取下一次计算需要的数据到A_Shared[indexOfWrite,:,:]和B_Shared[indexOfWrite,:,:]中

        // 每个线程执行计算
        Matrix<float> A_View_Shared(((float *)A_Shared+indexOfRead*BM*BK)+threadIdx.y*TM*BK,TM,BK,BK,1);// 每个线程对应的共享内存中A和B的子块
        Matrix<float> B_View_Shared(((float *)B_Shared+indexOfRead*BK*BN)+threadIdx.x*TN,BK,TN,BN,1);
        for(int k=0;k<BK;++k)
        {
            // 先将A的一列和B的一行加载到寄存器
            for(int m=0;m<TM;++m)
            {
                a[m]=A_View_Shared(m,k);
            }
            for(int n=0;n<TN;++n)
            {
                b[n]=B_View_Shared(k,n);
            }

            // 使用寄存器计算
            for(int m=0;m<TM;++m)
            {
                for(int n=0;n<TN;++n)
                {
                    c[m][n]+=a[m]*b[n];
                }
            }
        }

        // 预取下个循环的数据
        Matrix<float> A_View(A.data+blockIdx.y*BM*A.strideOfRow+i*BK,BM,BK,A.strideOfRow,A.strideOfCol);
        Matrix<float> B_View(B.data+i*BK*B.strideOfRow+blockIdx.x*BN,BK,BN,B.strideOfRow,B.strideOfCol);
        int startIndex=(BM*BK/(BN/TN*BM/TM))*(threadIdx.y*blockDim.x+threadIdx.x); // 每个线程读取的起始索引,也是vec数据的起始索引,BM*BK/(BN/TN*BM/TM)表示每个线程需要读取BM*BK/(BN/TN*BM/TM)个数据
        ((VecDataType*)((float*)A_Shared+indexOfWrite*BM*BK+startIndex))[0]=((VecDataType*)((float*)&A_View(startIndex)))[0];
        ((VecDataType*)((float*)B_Shared+indexOfWrite*BN*BK+startIndex))[0]=((VecDataType*)((float*)&B_View(startIndex)))[0];
        __syncthreads();

        // 设置flag
        indexFlag=!indexFlag;
    }

    // 计算最后一个BK
    {
        Matrix<float> A_View_Shared(((float *)A_Shared+indexOfWrite*BM*BK)+threadIdx.y*TM*BK,TM,BK,BK,1);// 每个线程对应的共享内存中A和B的子块
        Matrix<float> B_View_Shared(((float *)B_Shared+indexOfWrite*BK*BN)+threadIdx.x*TN,BK,TN,BN,1);
        for(int k=0;k<BK;++k)
        {
            // 先将A的一列和B的一行加载到寄存器
            for(int m=0;m<TM;++m)
            {
                a[m]=A_View_Shared(m,k);
            }
            for(int n=0;n<TN;++n)
            {
                b[n]=B_View_Shared(k,n);
            }

            // 使用寄存器计算
            for(int m=0;m<TM;++m)
            {
                for(int n=0;n<TN;++n)
                {
                    c[m][n]+=a[m]*b[n];
                }
            }
        }
    }

    // 将每个线程计算好的结果写回到C矩阵
    // C_View为每个线程对应的全局内存的C矩阵子块,创建C矩阵的view
    Matrix<float> C_View(C.data+((blockIdx.y*BM+threadIdx.y*TM)*C.strideOfRow+blockIdx.x*BN+threadIdx.x*TN),TM,TN,C.strideOfRow,C.strideOfCol);
    for(int m=0;m<TM;++m)
    {
        for(int n=0;n<TN;++n)
        {
            C_View(m,n)=c[m][n];
        }
    }
}

完整示例代码

完整示例代码如下:

/* GEMM优化示例

C=αA∗B+βC

示例程序中:alpha=1,beta=0
*/

#include <stdio.h>
#include <sys/time.h>
#include <cuda.h>
#include <cublas.h>

// 计时
static double seconds()
{
    struct timeval tp;
    struct timezone tzp;
    int i = gettimeofday(&tp, &tzp);
    return ((double)tp.tv_sec + (double)tp.tv_usec * 1.e-6);
}

// 矩阵参数
#define _M 1024 // A矩阵的行数
#define _N 512 // B矩阵的列数
#define _K 256 // A矩阵的列数,B矩阵的行数

//////////////////////// 分块参数 ////////////////////////
// 注意:修改了BM,BN,TM,TN后需要修改每个block中的线程数

// #define BM 128 // block子块大小
// #define BN 128
// #define BK 16
// #define TM 8 // thread子块大小
// #define TN 8

#define BM 128 // block子块大小
#define BN 128
#define BK 8
#define TM 8 // thread子块大小
#define TN 8

// #define BM 128 // block子块大小
// #define BN 128
// #define BK 16
// #define TM 4 // thread子块大小
// #define TN 4

// #define BM 128 // block子块大小
// #define BN 128
// #define BK 8
// #define TM 4 // thread子块大小
// #define TN 4

// #define BM 64 // block子块大小
// #define BN 64
// #define BK 16
// #define TM 8 // thread子块大小
// #define TN 8

// #define BM 64 // block子块大小
// #define BN 64
// #define BK 16
// #define TM 4 // thread子块大小
// #define TN 4


template<typename T>
class Matrix
{
public:
    __device__ __host__ Matrix() = default;
    __device__ __host__ Matrix(const Matrix &) = default;
    __device__ __host__ Matrix& operator=(const Matrix &) = default;
    __device__ __host__ Matrix(T *_data,int _rows,int _cols,int _strideOfRow,int _strideOfCol):
                                    data(_data),
                                    rows(_rows),
                                    cols(_cols),
                                    strideOfRow(_strideOfRow),
                                    strideOfCol(_strideOfCol){}

    // 返回该矩阵所有字节数
    constexpr __device__ __host__ int GetNumberOfBytes() const
    {
        return rows*cols*sizeof(T);
    }

    // 返回该矩阵元素个数
    constexpr __device__ __host__ int GetNumberOfElements() const
    {
        return rows*cols;
    }

    // 访问某个元素,该元素的索引为二维逻辑索引:(rowIndex,colIndex)
    __device__ __host__ float &operator()(int rowIndex,int colIndex)
    {
        // 计算内存索引
        int memoryIndex=rowIndex*strideOfRow+colIndex*strideOfCol;

        return data[memoryIndex];
    }

    // 访问某个元素,该元素的索引为一维逻辑索引:(Index)
    __device__ __host__ float &operator()(int index)
    {
        // 转换为二维逻辑索引
        int colIndex=index%cols;
        int rowIndex=index/cols;

        // 计算内存索引
        int memoryIndex=rowIndex*strideOfRow+colIndex*strideOfCol;

        return data[memoryIndex];
    }



public:
    T *data = nullptr;// 数据指针
    int rows = 0;// 矩阵的行数
    int cols = 0;// 矩阵的列数
    int strideOfRow = 0;// 行步长
    int strideOfCol = 0;// 列步长

};


__global__ void NaiveGEMM(Matrix<float> A,Matrix<float> B,Matrix<float> C) // 注意,这里传参的时候不能传引用,而是要传值
{
    
    // 获取线程在网格内的索引
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    // 每个线程计算矩阵C的一个元素
    if(row<C.rows&&col<C.cols)
    {
        float c = 0;
        for (int i = 0; i < A.cols; ++i)
        {
            c += A(row,i)*B(i,col);// 使用A的第row行乘以B的第col列
        }
        C(row,col) = c;
    }
    
}
// Block Tile: Block分块,每个线程计算C矩阵的一个数据
__global__ void BlockGEMM_V1(Matrix<float> A,Matrix<float> B,Matrix<float> C)
{
    // 注意命名不要与前面的宏定义重名
    const int BLOCK_N=16;// 每个block的x方向线程数
    const int BLOCK_M=16;// 每个block的y方向线程数
    const int BLOCK_K=16;

    // 沿着K维度循环加载一个block中对应的A和B的数据到共享内存
    float c=0.0;
    for(int i=0;i<A.cols/BLOCK_K;++i)
    {
        // 每个block对应的全局内存中的A,B子块,即创建全局内存中A,B的view
        Matrix<float> A_View(A.data+blockIdx.y*BLOCK_M*A.strideOfRow+i*BLOCK_K,BLOCK_M,BLOCK_K,A.strideOfRow,A.strideOfCol);
        Matrix<float> B_View(B.data+i*BLOCK_K*B.strideOfRow+blockIdx.x*BLOCK_N,BLOCK_K,BLOCK_N,B.strideOfRow,B.strideOfCol);

        // 将A_View,B_View加载到共享内存
        // 注意:这里需要将一维逻辑索引转换为多维逻辑索引:startIndex->(startIndex/cols,startIndex%cols)
        __shared__ float A_Shared[BLOCK_M][BLOCK_K];
        __shared__ float B_Shared[BLOCK_K][BLOCK_N];
        int index=threadIdx.y*blockDim.x+threadIdx.x;// index为每个线程读取的数据索引,这里每个线程读取一个元素
        ((float*)A_Shared+index)[0]=A_View(index);
        ((float*)B_Shared+index)[0]=B_View(index);
        __syncthreads();

        // 每个thread计算A的一行和B的一列
        for(int k=0;k<BLOCK_K;++k)
        {
            c+=A_Shared[threadIdx.y][k]*B_Shared[k][threadIdx.x];
        }
        __syncthreads();

    }

    // 将每个线程计算好的结果写回到C矩阵
    // C_View为每个线程对应的全局内存的C矩阵子块,创建C矩阵的view
    Matrix<float> C_View(C.data+(blockIdx.y*BLOCK_M*C.strideOfRow+blockIdx.x*BLOCK_N),BLOCK_M,BLOCK_N,C.strideOfRow,C.strideOfCol);
    C_View(threadIdx.y,threadIdx.x)=c;

}


// 定义向量化数据类型
template <typename T,int N>
struct VecType
{
    T data[N];
};
using VecDataType = VecType<float, BM*BK/(BN/TN*BM/TM)>; // BM*BK/(BN/TN*BM/TM): 表示每个线程需要处理的数据量
// Block Tile+Thread Tile: 每个Block计算一个子块,同时每个线程计算Block中的一个子块
__global__ void BlockGEMM_V2(Matrix<float> A,Matrix<float> B,Matrix<float> C)
{
    // 每个线程的计算结果
    float c[TM][TN]={0.0};
    float a[TM]={0.0};
    float b[TN]={0.0};

    // block tile
    // 沿着K维度循环加载一个block中对应的A和B的数据到共享内存
    for(int i=0;i<A.cols/BK;++i)
    {
        // 每个block对应的全局内存中的A,B子块,即创建全局内存中A,B的view
        Matrix<float> A_View(A.data+blockIdx.y*BM*A.strideOfRow+i*BK,BM,BK,A.strideOfRow,A.strideOfCol);
        Matrix<float> B_View(B.data+i*BK*B.strideOfRow+blockIdx.x*BN,BK,BN,B.strideOfRow,B.strideOfCol);

        // 将A_View,B_View加载到共享内存
        // 以BM=BN=128,BK=TM=TN=8为例:由于一个block有16x16=256个线程,而A_View和B_View中一共有1024个元素,所以每个线程加载4个元素
        __shared__ float A_Shared[BM][BK];
        __shared__ float B_Shared[BK][BN];
        int startIndex=(BM*BK/(BN/TN*BM/TM))*(threadIdx.y*blockDim.x+threadIdx.x); // 每个线程读取的起始索引,也是vec数据的起始索引,BM*BK/(BN/TN*BM/TM)表示每个线程需要读取BM*BK/(BN/TN*BM/TM)个数据
        ((VecDataType*)((float*)A_Shared+startIndex))[0]=((VecDataType*)((float*)&A_View(startIndex)))[0];
        ((VecDataType*)((float*)B_Shared+startIndex))[0]=((VecDataType*)((float*)&B_View(startIndex)))[0];
        __syncthreads();

        // 每个thread对应的共享内存中的A_Shared,B_Shared的子块,即创建A_Shared,B_Shared的view
        Matrix<float> A_View_Shared((float *)A_Shared+threadIdx.y*TM*BK,TM,BK,BK,1);// 每个线程对应的共享内存中A和B的子块
        Matrix<float> B_View_Shared((float *)B_Shared+threadIdx.x*TN,BK,TN,BN,1);

        // thread tile
        for(int k=0;k<BK;++k)
        {
            // 先将A的一列和B的一行加载到寄存器
            for(int m=0;m<TM;++m)
            {
                a[m]=A_View_Shared(m,k);
            }
            for(int n=0;n<TN;++n)
            {
                b[n]=B_View_Shared(k,n);
            }

            // 使用寄存器计算
            for(int m=0;m<TM;++m)
            {
                for(int n=0;n<TN;++n)
                {
                    c[m][n]+=a[m]*b[n];
                }
            }
        }
        __syncthreads();

    }

    // 将每个线程计算好的结果写回到C矩阵
    // C_View为每个线程对应的全局内存的C矩阵子块,创建C矩阵的view
    Matrix<float> C_View(C.data+((blockIdx.y*BM+threadIdx.y*TM)*C.strideOfRow+blockIdx.x*BN+threadIdx.x*TN),TM,TN,C.strideOfRow,C.strideOfCol);
    for(int m=0;m<TM;++m)
    {
        for(int n=0;n<TN;++n)
        {
            C_View(m,n)=c[m][n];
        }
    }

}

// 解决共享内存bank冲突的问题:共享内存读到寄存器的时候,每次读取A的一列,这样存在bank冲突,将A_Shared进行转置,这样就可以读取一行了
__global__ void BlockGEMM_V2_BankConflict(Matrix<float> A,Matrix<float> B,Matrix<float> C)
{
    // 每个线程的计算结果
    float c[TM][TN]={0.0};
    float a[TM]={0.0};
    float b[TN]={0.0};

    // 沿着K维度循环加载一个block中对应的A和B的数据到共享内存
    for(int i=0;i<A.cols/BK;++i)
    {
        // 每个block对应的全局内存中的A,B子块,即创建全局内存中A,B的view
        Matrix<float> A_View(A.data+blockIdx.y*BM*A.strideOfRow+i*BK,BM,BK,A.strideOfRow,A.strideOfCol);
        Matrix<float> B_View(B.data+i*BK*B.strideOfRow+blockIdx.x*BN,BK,BN,B.strideOfRow,B.strideOfCol);

        // 将A_View,B_View加载到共享内存
        // 以BM=BN=128,BK=TM=TN=8为例:由于一个block有16x16=256个线程,而A_View和B_View中一共有1024个元素,所以每个线程加载4个元素
        __shared__ float A_Shared[BK][BM];
        __shared__ float B_Shared[BK][BN];
        int startIndex=(BM*BK/(BN/TN*BM/TM))*(threadIdx.y*blockDim.x+threadIdx.x); // 每个线程读取的起始索引,也是vec数据的起始索引,BM*BK/(BN/TN*BM/TM)表示每个线程需要读取BM*BK/(BN/TN*BM/TM)个数据
        VecDataType aElement=((VecDataType*)((float*)&A_View(startIndex)))[0];
        #pragma unroll 
        for(int j=0;j<(BM*BK/(BN/TN*BM/TM));++j)
        {
            A_Shared[startIndex%BK+j][startIndex/BK]=aElement.data[j];
        }
        ((VecDataType*)((float*)B_Shared+startIndex))[0]=((VecDataType*)((float*)&B_View(startIndex)))[0];
        __syncthreads();

        // 每个thread对应的共享内存中的A_Shared,B_Shared的子块,即创建A_Shared,B_Shared的view
        Matrix<float> A_View_Shared((float *)A_Shared+threadIdx.y*TM,BK,TM,BM,1);// 每个线程对应的共享内存中A和B的子块
        Matrix<float> B_View_Shared((float *)B_Shared+threadIdx.x*TN,BK,TN,BN,1);

        // 每个线程执行计算
        for(int k=0;k<BK;++k)
        {
            // 先将A的一行和B的一行加载到寄存器
            for(int m=0;m<TM;++m)
            {
                a[m]=A_View_Shared(k,m);
            }
            for(int n=0;n<TN;++n)
            {
                b[n]=B_View_Shared(k,n);
            }

            // 使用寄存器计算
            for(int m=0;m<TM;++m)
            {
                for(int n=0;n<TN;++n)
                {
                    c[m][n]+=a[m]*b[n];
                }
            }
        }
        __syncthreads();

    }

    // 将每个线程计算好的结果写回到C矩阵
    // C_View为每个线程对应的全局内存的C矩阵子块,创建C矩阵的view
    Matrix<float> C_View(C.data+((blockIdx.y*BM+threadIdx.y*TM)*C.strideOfRow+blockIdx.x*BN+threadIdx.x*TN),TM,TN,C.strideOfRow,C.strideOfCol);
    for(int m=0;m<TM;++m)
    {
        for(int n=0;n<TN;++n)
        {
            C_View(m,n)=c[m][n];
        }
    }

}
// 在V2的基础上加上数据预取,通过计算和访存并发执行来隐藏访存的延迟
__global__ void BlockGEMM_V3(Matrix<float> A,Matrix<float> B,Matrix<float> C)
{
    // 每个线程的计算结果
    float c[TM][TN]={0.0};
    float a[TM]={0.0};
    float b[TN]={0.0};

    // 此时需要的共享内存是原来的2倍
    // 注意:读取和写入的时候第一个维度的索引是交错进行的
    __shared__ float A_Shared[2][BM][BK];
    __shared__ float B_Shared[2][BK][BN];

    // 预取(先读取第一个BK)
    Matrix<float> A_View(A.data+blockIdx.y*BM*A.strideOfRow+0*BK,BM,BK,A.strideOfRow,A.strideOfCol);
    Matrix<float> B_View(B.data+0*BK*B.strideOfRow+blockIdx.x*BN,BK,BN,B.strideOfRow,B.strideOfCol);
    int startIndex=(BM*BK/(BN/TN*BM/TM))*(threadIdx.y*blockDim.x+threadIdx.x); // 每个线程读取的起始索引,也是vec数据的起始索引,BM*BK/(BN/TN*BM/TM)表示每个线程需要读取BM*BK/(BN/TN*BM/TM)个数据
    ((VecDataType*)((float*)A_Shared+startIndex))[0]=((VecDataType*)((float*)&A_View(startIndex)))[0];
    ((VecDataType*)((float*)B_Shared+startIndex))[0]=((VecDataType*)((float*)&B_View(startIndex)))[0];
    __syncthreads();

    // 沿着K维度循环加载剩下的数据
    int indexOfRead,indexOfWrite;
    bool indexFlag=false;// 辅助变量,用来计算索引
    for(int i=1;i<A.cols/BK;++i)
    {
        // 计算索引,indexOfRead和indexOfWrite每次循环会交替变换,i=1时为indexOfRead=0,indexOfWrite=1,i=2时为indexOfRead=1,indexOfWrite=0
        indexOfRead = (int)indexFlag; // 读索引,即本次循环读取A_Shared[indexOfRead,:,:]和B_Shared[indexOfRead,:,:]中的数据执行计算
        indexOfWrite = 1-indexOfRead; // 写索引,即预取下一次计算需要的数据到A_Shared[indexOfWrite,:,:]和B_Shared[indexOfWrite,:,:]中

        // 每个线程执行计算
        Matrix<float> A_View_Shared(((float *)A_Shared+indexOfRead*BM*BK)+threadIdx.y*TM*BK,TM,BK,BK,1);// 每个线程对应的共享内存中A和B的子块
        Matrix<float> B_View_Shared(((float *)B_Shared+indexOfRead*BK*BN)+threadIdx.x*TN,BK,TN,BN,1);
        for(int k=0;k<BK;++k)
        {
            // 先将A的一列和B的一行加载到寄存器
            for(int m=0;m<TM;++m)
            {
                a[m]=A_View_Shared(m,k);
            }
            for(int n=0;n<TN;++n)
            {
                b[n]=B_View_Shared(k,n);
            }

            // 使用寄存器计算
            for(int m=0;m<TM;++m)
            {
                for(int n=0;n<TN;++n)
                {
                    c[m][n]+=a[m]*b[n];
                }
            }
        }

        // 预取下个循环的数据
        Matrix<float> A_View(A.data+blockIdx.y*BM*A.strideOfRow+i*BK,BM,BK,A.strideOfRow,A.strideOfCol);
        Matrix<float> B_View(B.data+i*BK*B.strideOfRow+blockIdx.x*BN,BK,BN,B.strideOfRow,B.strideOfCol);
        int startIndex=(BM*BK/(BN/TN*BM/TM))*(threadIdx.y*blockDim.x+threadIdx.x); // 每个线程读取的起始索引,也是vec数据的起始索引,BM*BK/(BN/TN*BM/TM)表示每个线程需要读取BM*BK/(BN/TN*BM/TM)个数据
        ((VecDataType*)((float*)A_Shared+indexOfWrite*BM*BK+startIndex))[0]=((VecDataType*)((float*)&A_View(startIndex)))[0];
        ((VecDataType*)((float*)B_Shared+indexOfWrite*BN*BK+startIndex))[0]=((VecDataType*)((float*)&B_View(startIndex)))[0];
        __syncthreads();

        // 设置flag
        indexFlag=!indexFlag;
    }

    // 计算最后一个BK
    {
        Matrix<float> A_View_Shared(((float *)A_Shared+indexOfWrite*BM*BK)+threadIdx.y*TM*BK,TM,BK,BK,1);// 每个线程对应的共享内存中A和B的子块
        Matrix<float> B_View_Shared(((float *)B_Shared+indexOfWrite*BK*BN)+threadIdx.x*TN,BK,TN,BN,1);
        for(int k=0;k<BK;++k)
        {
            // 先将A的一列和B的一行加载到寄存器
            for(int m=0;m<TM;++m)
            {
                a[m]=A_View_Shared(m,k);
            }
            for(int n=0;n<TN;++n)
            {
                b[n]=B_View_Shared(k,n);
            }

            // 使用寄存器计算
            for(int m=0;m<TM;++m)
            {
                for(int n=0;n<TN;++n)
                {
                    c[m][n]+=a[m]*b[n];
                }
            }
        }
    }

    // 将每个线程计算好的结果写回到C矩阵
    // C_View为每个线程对应的全局内存的C矩阵子块,创建C矩阵的view
    Matrix<float> C_View(C.data+((blockIdx.y*BM+threadIdx.y*TM)*C.strideOfRow+blockIdx.x*BN+threadIdx.x*TN),TM,TN,C.strideOfRow,C.strideOfCol);
    for(int m=0;m<TM;++m)
    {
        for(int n=0;n<TN;++n)
        {
            C_View(m,n)=c[m][n];
        }
    }
}
// CPU矩阵乘法,用于验证正确性
void MatMul_CPU(float *a,float *b,float *c,int M,int N,int K)
{
    for(int i=0;i<M;++i)
    {
        for(int j=0;j<N;++j)
        {
            float sum=0.0f;
            for(int k=0;k<K;++k)
            {
                sum+=a[i*K+k]*b[k*N+j];
            }
            c[i*N+j]=sum;
        }
    }

}
int main(int argc,char *argv[])
{
    // 创建Host A,B矩阵,这里不考虑边界处理,创建的矩阵的行和列可以被分块大小整除
    float *A_Host;
    float *B_Host;
    float *C_Host;
    A_Host = (float *)malloc(_M*_K*sizeof(float));
    B_Host = (float *)malloc(_K*_N*sizeof(float));
    C_Host = (float *)malloc(_M*_N*sizeof(float));
    printf("A size:%d x %d\n",_M,_K);
    printf("B size:%d x %d\n",_K,_N);
    for(int i=0;i<_M;++i)
    {
        for(int j=0;j<_K;++j)
        {
            A_Host[i*_K+j]=rand()%10;
        }
    }
    for(int i=0;i<_K;++i)
    {
        for(int j=0;j<_N;++j)
        {
            B_Host[i*_N+j]=rand()%10;
        }
    }
    // 计算Host C矩阵
    double time1,time2;
    time1=seconds();
    MatMul_CPU((float*)A_Host,(float *)B_Host,(float *)C_Host,_M,_N,_K);
    time2=seconds();
    printf("cpu elapsed:%f ms\n",(time2-time1)*1000);

    // 创建Device A矩阵
    float *dataOfA_Device=nullptr;
    cudaMalloc((void **)&dataOfA_Device, _M*_K*sizeof(float));
    cudaMemcpy(dataOfA_Device, (float *)A_Host, _M*_K*sizeof(float), cudaMemcpyHostToDevice);
    Matrix<float> A_Device(dataOfA_Device,_M,_K,_K,1);

    // 创建Device B矩阵
    float *dataOfB_Device=nullptr;
    cudaMalloc((void **)&dataOfB_Device, _K*_N*sizeof(float));
    cudaMemcpy(dataOfB_Device, (float *)B_Host, _K*_N*sizeof(float), cudaMemcpyHostToDevice);
    Matrix<float> B_Device(dataOfB_Device,_K,_N,_N,1);

    // 创建Device C矩阵
    float *dataOfC_Device=nullptr;
    cudaMalloc((void **)&dataOfC_Device, _M*_N*sizeof(float));
    cudaMemset(dataOfC_Device, 0, _M*_N*sizeof(float));
    Matrix<float> C_Device(dataOfC_Device,_M,_N,_N,1);
    
    ////////////////////////////// NaiveGEMM /////////////////////////////////////////////
    // {
    //     int BLOCKX = 16;// 每个block的x方向线程数
    //     int BLOCKY = 16;// 每个block的y方向线程数
    //     dim3 block(BLOCKX,BLOCKY);
    //     dim3 grid((C_Device.cols+BLOCKX-1) / BLOCKX,(C_Device.rows+BLOCKY-1) / BLOCKY);
    //     for(int i=0;i<10;++i)
    //     {
    //         time1=seconds();
    //         NaiveGEMM<<<grid, block>>>(A_Device,B_Device,C_Device);
    //         cudaDeviceSynchronize();
    //         time2=seconds();
    //         printf("NaiveGEMM elapsed:%f ms\n",(time2-time1)*1000);
    //     }
        
    // }

    //////////////////////////////// BlockGEMM_V1 /////////////////////////////////////////////
    // {
    //     int BLOCKX = 16;// 每个block的x方向线程数
    //     int BLOCKY = 16;// 每个block的y方向线程数
    //     dim3 block(BLOCKX,BLOCKY);
    //     dim3 grid(C_Device.cols/BLOCKX,C_Device.rows/BLOCKY);
    //     for(int i=0;i<10;++i)
    //     {
    //         time1=seconds();
    //         BlockGEMM_V1<<<grid, block>>>(A_Device,B_Device,C_Device);
    //         cudaDeviceSynchronize();
    //         time2=seconds();
    //         printf("BlockGEMM_V1 elapsed:%f ms\n",(time2-time1)*1000);
    //     }

    // }

    //////////////////////////////// BlockGEMM_V2 /////////////////////////////////////////////
    // {
    //     int BLOCKX = BN/TN;// 每个block的x方向线程数
    //     int BLOCKY = BM/TM;// 每个block的y方向线程数
    //     dim3 block(BLOCKX,BLOCKY);
    //     dim3 grid(C_Device.cols/BN,C_Device.rows/BM);
    //     for(int i=0;i<10;++i)
    //     {
    //         time1=seconds();
    //         BlockGEMM_V2<<<grid, block>>>(A_Device,B_Device,C_Device);
    //         // BlockGEMM_V2_BankConflit<<<grid, block>>>(A_Device,B_Device,C_Device);
    //         cudaDeviceSynchronize();
    //         time2=seconds();
    //         printf("BlockGEMM_V2 elapsed:%f ms\n",(time2-time1)*1000);
    //     }

    // }

    ////////////////////////////// BlockGEMM_V2_BankConflict /////////////////////////////////////////////
    {
        int BLOCKX = BN/TN;// 每个block的x方向线程数
        int BLOCKY = BM/TM;// 每个block的y方向线程数
        dim3 block(BLOCKX,BLOCKY);
        dim3 grid(C_Device.cols/BN,C_Device.rows/BM);
        for(int i=0;i<10;++i)
        {
            time1=seconds();
            BlockGEMM_V2_BankConflict<<<grid, block>>>(A_Device,B_Device,C_Device);
            cudaDeviceSynchronize();
            time2=seconds();
            printf("BlockGEMM_V2_BankConflict elapsed:%f ms\n",(time2-time1)*1000);
        }

    }

    //////////////////////////////// BlockGEMM_V3 /////////////////////////////////////////////
    // {
    //     int BLOCKX = BN/TN;// 每个block的x方向线程数
    //     int BLOCKY = BM/TM;// 每个block的y方向线程数
    //     dim3 block(BLOCKX,BLOCKY);
    //     dim3 grid(C_Device.cols/BN,C_Device.rows/BM);
    //     for(int i=0;i<10;++i)
    //     {
    //         time1=seconds();
    //         BlockGEMM_V3<<<grid, block>>>(A_Device,B_Device,C_Device);
    //         cudaDeviceSynchronize();
    //         time2=seconds();
    //         printf("BlockGEMM_V3 elapsed:%f ms\n",(time2-time1)*1000);
    //     }

    // }

    //////////////////////////////// cublas /////////////////////////////////////////////
    // {
    //     cublasHandle_t handle;
    //     cublasCreate_v2(&handle);
	
	//     cublasOperation_t  transA;
    //     cublasOperation_t  transB;
	
	//     transA=CUBLAS_OP_N;
	//     transB=CUBLAS_OP_N;// CUBLAS_OP_N,CUBLAS_OP_T

    //     float alpha=1.0;
	//     float beta=0.0;

        
    //     for(int i=0;i<10;++i)
    //     {
    //         time1=seconds();
    //         cublasSgemm_v2(handle,transA,transB,_N,_M,_K,&alpha,B_Device.data,_N,A_Device.data,_K,&beta,C_Device.data,_N);
    //         cudaDeviceSynchronize();
    //         time2=seconds();
    //         printf("cublas elapsed:%f ms\n",(time2-time1)*1000);
    //     }

    // }

    // 拷贝GPU结果
    float *dataOfC_DeviceToHost=nullptr;
    dataOfC_DeviceToHost=(float *)malloc(C_Device.GetNumberOfBytes());
    cudaMemcpy(dataOfC_DeviceToHost, C_Device.data, C_Device.GetNumberOfBytes(), cudaMemcpyDeviceToHost);

    // 验证结果的正确性
    float *resultOfHost=(float*)C_Host;
    float *resultOfDevice=(float*)dataOfC_DeviceToHost;
    int numberOfError=0;
    for(int i=0;i<C_Device.GetNumberOfElements();++i)
    {
        if(std::isnan(resultOfDevice[i]) || fabs(resultOfHost[i]-resultOfDevice[i])>1e-3)
        {
            ++numberOfError;
        }
    }
    if(numberOfError==0)
    {
        printf("OK!\n");
    }
    else
    {
        printf("ERROR!\n");
    }

    // free
    cudaFree(dataOfA_Device);
    cudaFree(dataOfB_Device);
    cudaFree(dataOfC_Device);
    free(dataOfC_DeviceToHost);
    free(A_Host);
    free(B_Host);
    free(C_Host);

    return 0;
}

将上述示例代码保存为GEMM.cu后,使用如下命令可以运行示例代码:

nvcc -std=c++17 -I=./,/usr/local/cuda/include/ -L=/usr/local/cuda/lib64/ -l=cublas GEMM.cu

性能测试

bank冲突和双缓冲对性能的影响

测试环境:A800
为了简化测试,本次实验选取固定分块参数:BM=BN=128,BK=TM=TN=8,矩阵size设置为M=N=K
表格数据单位:ms
在这里插入图片描述
由于本文的bank冲突解决方案并不能完全避免bank冲突,所以bank冲突带来的性能提升较小,双缓冲可以带来明显的性能提升。

不同参数对性能的影响

对BlockGEMM_V2进行测试,实验中选取不同的MNK和分块参数,观察不同问题规模下不同分块参数对性能的影响

表格数据单位:ms

在这里插入图片描述

可以看出不同的MNK,对应的最佳分块大小是不同的,同时对于同一组MNK,不同分块参数的性能差异可能很大。


参考文献

  1. https://developer.nvidia.com/blog/cutlass-linear-algebra-cuda/
  2. https://siboehm.com/articles/22/CUDA-MMM
  3. https://github.com/NVIDIA/cutlass/blob/main/media/docs/efficient_gemm.md
  4. https://zhuanlan.zhihu.com/p/461060382
  5. https://zhuanlan.zhihu.com/p/410278370

结束语

由于自己也是刚接触GEMM的优化,所以对GEMM的优化研究并不是很深入,本文只是实现了一些最基本的GEMM优化方法,目的主要是为了了解在GPU上优化GEMM的常用方法。文中有什么不对的地方,欢迎批评指正。


评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值