Nvidia Tensor Core: An Introduction to WMMA API Programming

2023-04-12 21:01:23

1 WMMA (Warp-level Matrix Multiply Accumulate) API

On CUDA devices with compute capability 7.0 or higher, Tensor Cores can be invoked from CUDA C++ through the WMMA API (the nvcuda::wmma namespace in mma.h), which supports mixed-precision matrix multiplication of the form D = AB + C.
template<typename Use, int m, int n, int k, typename T, typename Layout=void> class fragment;

void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm);
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm, layout_t layout);
void store_matrix_sync(T* mptr, const fragment<...> &a, unsigned ldm, layout_t layout);
void fill_fragment(fragment<...> &a, const T& v);
void mma_sync(fragment<...> &d, const fragment<...> &a, const fragment<...> &b, const fragment<...> &c, bool satf=false);
  • fragment: the Tensor Core data container class; the Use parameter selects matrix_a, matrix_b, or accumulator
  • load_matrix_sync: loads matrix data from global memory or shared memory into a fragment
  • store_matrix_sync: stores a computed result from a fragment back to global memory or shared memory
  • fill_fragment: fills a fragment with a constant value
  • mma_sync: the Tensor Core matrix multiply-accumulate, computing D = AB + C or, in place, C = AB + C (a single-tile sketch follows this list)
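
To make the call sequence concrete before the full example, here is a minimal sketch (my own, not from a CUDA sample; the kernel name and the one-warp launch are illustrative assumptions) in which a single warp multiplies one 16 x 16 FP16 tile:

#include <mma.h>
#include <cuda_fp16.h>

using namespace nvcuda;

// One warp computes c = a * b for a single 16 x 16 FP16 tile.
// a is row major, b is col major, both with leading dimension 16.
__global__ void wmmaSingleTile(const half *a, const half *b, half *c) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);               // C = 0
    wmma::load_matrix_sync(a_frag, a, 16);           // ldm = 16
    wmma::load_matrix_sync(b_frag, b, 16);           // ldm = 16
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);  // C = AB + C
    wmma::store_matrix_sync(c, c_frag, 16, wmma::mem_row_major);
}

// Launch with exactly one warp: wmmaSingleTile<<<1, 32>>>(dA, dB, dC);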

2 Example

Taking m16n16k16 as the example shape, we implement HGEMM: C = AB, where A (M * K, row major), B (K * N, col major), and C (M * N, row major) are all FP16. First, let's look at a naive HGEMM written with CUDA Cores.

2.1 CUDA Core

The naive kernel assigns one thread per element of C: each thread first determines the coordinates of its element, then walks along K, loading the required elements of A and B straight from global memory into registers, and finally writes the accumulated result from registers back to C. Once all blocks finish, matrix C is complete. This example is not so much simple as technically unambitious, but it gives us a baseline for comparison.
#include <cuda_fp16.h>

#define DIV_CEIL(x, y) (((x) + (y) - 1) / (y))

__global__ void naiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M,
                                size_t N, size_t K) {
    // Each thread computes one element of C.
    size_t row = threadIdx.x + blockDim.x * blockIdx.x;
    size_t col = threadIdx.y + blockDim.y * blockIdx.y;
    if (row < M && col < N) {
        half tmp = 0.0;
        // A is row major (M x K), B is col major (K x N).
        for (size_t i = 0; i < K; ++i) {
            tmp += A[row * K + i] * B[i + col * K];
        }
        C[row * N + col] = tmp;
    }
}

void hgemmNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
    dim3 block(16, 16);
    dim3 grid(DIV_CEIL(M, block.x), DIV_CEIL(N, block.y));

    naiveKernel<<<grid, block>>>(A, B, C, M, N, K);
}

2.2 Tensor Core

Now let's build the same naive kernel with the WMMA API, following the CUDA sample. Unlike the CUDA Core version, WMMA assigns each warp one WMMA_M * WMMA_N tile of C, because Tensor Cores compute at warp granularity and operate on two-dimensional tiles. From there the structure mirrors the CUDA Core kernel: first determine which tile of C the current warp owns and declare the fragments needed to compute that tile, then walk along K in steps of WMMA_K, loading the required tiles of A and B straight from global memory into fragments, and finally write the accumulated result from the fragment back to C. Once all blocks finish, matrix C is complete.
Note that load_matrix_sync and store_matrix_sync both access matrix elements by stride: the ldm argument is the leading dimension in elements (the distance between consecutive rows of a row-major matrix, or consecutive columns of a col-major one), so the kernel below passes K for A and B, and N for C.
#include <mma.h>
#include <cuda_fp16.h>

#define WARP_SIZE 32

#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

using namespace nvcuda;

__global__ void wmmaNaiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M,
                                size_t N, size_t K) {
    // Each warp computes one WMMA_M x WMMA_N tile of C.
    // Note: this naive kernel assumes M, N and K are multiples of 16.
    size_t warpM = (blockIdx.x * blockDim.x + threadIdx.x) / WARP_SIZE;
    size_t warpN = (blockIdx.y * blockDim.y + threadIdx.y);

    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);

    // March along K in steps of WMMA_K, accumulating into c_frag.
    for (size_t i = 0; i < K; i += WMMA_K) {
        size_t aCol = i;
        size_t aRow = warpM * WMMA_M;
        size_t bCol = warpN * WMMA_N;
        size_t bRow = i;

        if (aRow < M && aCol < K && bRow < K && bCol < N) {
            // ldm = K for both: A is row major (M x K), B is col major (K x N).
            wmma::load_matrix_sync(a_frag, A + aCol + aRow * K, K);
            wmma::load_matrix_sync(b_frag, B + bRow + bCol * K, K);

            wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        }
    }

    size_t cCol = warpN * WMMA_N;
    size_t cRow = warpM * WMMA_M;

    if (cRow < M && cCol < N) {
        // C is row major (M x N), so ldm = N.
        wmma::store_matrix_sync(C + cCol + cRow * N, c_frag, N, wmma::mem_row_major);
    }
}

void hgemmWmmaNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
    // 128 x 4 threads = 16 warps per block: 4 warps along M and 4 along N,
    // so each block covers a 64 x 64 tile of C.
    dim3 block(128, 4);
    dim3 grid((M - 1) / (WMMA_M * block.x / WARP_SIZE) + 1, (N - 1) / (WMMA_N * block.y) + 1);

    wmmaNaiveKernel<<<grid, block>>>(A, B, C, M, N, K);
}
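
To sanity-check the two kernels against each other, a host-side driver along the following lines can be used (a sketch of mine, not part of the original code; error checking omitted). The two results won't match bit for bit, since the FP16 rounding order differs, but the gap should be small:

#include <cuda_fp16.h>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>

void hgemmNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K);
void hgemmWmmaNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K);

int main() {
    const size_t M = 512, N = 512, K = 512;  // multiples of 16
    std::vector<half> hA(M * K), hB(K * N), hC1(M * N), hC2(M * N);
    for (auto &x : hA) x = __float2half(rand() / float(RAND_MAX) - 0.5f);
    for (auto &x : hB) x = __float2half(rand() / float(RAND_MAX) - 0.5f);

    half *dA, *dB, *dC;
    cudaMalloc(&dA, M * K * sizeof(half));
    cudaMalloc(&dB, K * N * sizeof(half));
    cudaMalloc(&dC, M * N * sizeof(half));
    cudaMemcpy(dA, hA.data(), M * K * sizeof(half), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), K * N * sizeof(half), cudaMemcpyHostToDevice);

    // Run both kernels on the same inputs and compare on the host.
    hgemmNaive(dA, dB, dC, M, N, K);
    cudaMemcpy(hC1.data(), dC, M * N * sizeof(half), cudaMemcpyDeviceToHost);
    hgemmWmmaNaive(dA, dB, dC, M, N, K);
    cudaMemcpy(hC2.data(), dC, M * N * sizeof(half), cudaMemcpyDeviceToHost);

    float maxDiff = 0.0f;
    for (size_t i = 0; i < M * N; ++i)
        maxDiff = fmaxf(maxDiff, fabsf(__half2float(hC1[i]) - __half2float(hC2[i])));
    printf("max |naive - wmma| = %f\n", maxDiff);

    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}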

2.3 Differences

Comparing the two naive kernels, the differences between invoking CUDA Cores and Tensor Cores are:
  • Granularity: CUDA Cores compute per thread; Tensor Cores compute per warp
  • Dimensionality: CUDA Cores compute element by element in one dimension; Tensor Cores compute tile by tile in two dimensions
  • Data path: invoking Tensor Cores through WMMA requires the fragment container class; CUDA Cores need no such intermediary

3 Low-Level Code

Let's explore the WMMA naive kernel a bit further and look at the PTX and SASS it produces on an RTX A6000 (sm_86, CUDA 11.3).

3.1 PTX

Dumping the corresponding PTX (e.g., with cuobjdump -ptx) gives the code below; it no longer looks quite so simple.
.visible .entry _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm(
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_0,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_1,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_2,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_3,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_4,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_5
)
{
.reg .pred %p<8>;
.reg .b16 %rs<2>;
.reg .f32 %f<2>;
.reg .b32 %r<58>;
.reg .b64 %rd<28>;

ld.param.u64 %rd9, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_0];
ld.param.u64 %rd10, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_1];
ld.param.u64 %rd11, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_2];
ld.param.u64 %rd14, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_3];
ld.param.u64 %rd12, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_4];
ld.param.u64 %rd13, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_5];
mov.u32 %r19, %ntid.x;
mov.u32 %r20, %ctaid.x;
mov.u32 %r21, %tid.x;
mad.lo.s32 %r22, %r20, %r19, %r21;
mov.u32 %r23, %ntid.y;
mov.u32 %r24, %ctaid.y;
mov.u32 %r25, %tid.y;
mad.lo.s32 %r26, %r24, %r23, %r25;
mov.f32 %f1, 0f00000000;

	{ cvt.rn.f16.f32 %rs1, %f1;}

	mov.b32 %r50, {%rs1, %rs1};
mul.wide.u32 %rd1, %r26, 16;
shr.u32 %r27, %r22, 1;
and.b32 %r28, %r27, 2147483632;
cvt.u64.u32 %rd2, %r28;
setp.lt.u64 %p2, %rd2, %rd14;
setp.lt.u64 %p3, %rd1, %rd12;
and.pred %p1, %p2, %p3;
setp.eq.s64 %p4, %rd13, 0;
mov.u32 %r51, %r50;
mov.u32 %r52, %r50;
mov.u32 %r53, %r50;
@%p4 bra $L__BB0_5;

mul.lo.s64 %rd3, %rd2, %rd13;
cvt.u32.u64 %r2, %rd13;
mul.lo.s64 %rd4, %rd1, %rd13;
cvta.to.global.u64 %rd5, %rd10;
cvta.to.global.u64 %rd6, %rd9;
mov.u64 %rd27, 0;
not.pred %p5, %p1;
mov.u32 %r51, %r50;
mov.u32 %r52, %r50;
mov.u32 %r53, %r50;

$L__BB0_2:
@%p5 bra $L__BB0_4;

add.s64 %rd16, %rd27, %rd3;
shl.b64 %rd17, %rd16, 1;
add.s64 %rd18, %rd6, %rd17;
wmma.load.a.sync.aligned.row.m16n16k16.global.f16 {%r29, %r30, %r31, %r32, %r33, %r34, %r35, %r36}, [%rd18], %r2;
add.s64 %rd19, %rd27, %rd4;
shl.b64 %rd20, %rd19, 1;
add.s64 %rd21, %rd5, %rd20;
wmma.load.b.sync.aligned.col.m16n16k16.global.f16 {%r37, %r38, %r39, %r40, %r41, %r42, %r43, %r44}, [%rd21], %r2;
wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 {%r53, %r52, %r51, %r50}, {%r29, %r30, %r31, %r32, %r33, %r34, %r35, %r36}, {%r37, %r38, %r39, %r40, %r41, %r42, %r43, %r44}, {%r53, %r52, %r51, %r50};

$L__BB0_4:
add.s64 %rd27, %rd27, 16;
setp.lt.u64 %p6, %rd27, %rd13;
@%p6 bra $L__BB0_2;

$L__BB0_5:
not.pred %p7, %p1;
@%p7 bra $L__BB0_7;

mul.lo.s64 %rd22, %rd2, %rd12;
add.s64 %rd23, %rd22, %rd1;
cvta.to.global.u64 %rd24, %rd11;
shl.b64 %rd25, %rd23, 1;
add.s64 %rd26, %rd24, %rd25;
cvt.u32.u64 %r45, %rd12;
wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%rd26], {%r53, %r52, %r51, %r50}, %r45;

$L__BB0_7:
ret;

}

What we mainly care about, though, are the WMMA-related PTX instructions, listed below. These are exactly the WMMA PTX instructions Nvidia provides for invoking Tensor Cores, so whether you program against the WMMA API or write WMMA PTX directly, not much changes underneath.

wmma.load.a.sync.aligned.row.m16n16k16.global.f16
wmma.load.b.sync.aligned.col.m16n16k16.global.f16
wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16
wmma.store.d.sync.aligned.row.m16n16k16.global.f16
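
To illustrate that correspondence, the A-tile load above can also be issued by hand through inline PTX. Below is a sketch of mine (the function name is made up) that uses the generic-address form of wmma.load, i.e., without the .global qualifier, so a plain device pointer can be passed; the eight 32-bit registers each pack two FP16 values of the thread's fragment:

#include <cuda_fp16.h>

// Sketch: load a row-major m16n16k16 FP16 A tile via inline WMMA PTX,
// equivalent to what wmma::load_matrix_sync compiles to.
// ldm is the leading dimension in elements.
__device__ void wmmaLoadAPtx(unsigned (&r)[8], const half *src, unsigned ldm) {
    asm volatile(
        "wmma.load.a.sync.aligned.row.m16n16k16.f16 "
        "{%0, %1, %2, %3, %4, %5, %6, %7}, [%8], %9;\n"
        : "=r"(r[0]), "=r"(r[1]), "=r"(r[2]), "=r"(r[3]),
          "=r"(r[4]), "=r"(r[5]), "=r"(r[6]), "=r"(r[7])
        : "l"(src), "r"(ldm));
}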

3.2 SASS

Going one step further and dumping the corresponding SASS (e.g., with cuobjdump -sass) shows it is not simple either.
      IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] 
      S2R R0, SR_CTAID.X 
      ISETP.NE.U32.AND P2, PT, RZ, c[0x0][0x188], PT 
      ULDC.64 UR4, c[0x0][0x118] 
      CS2R R8, SRZ 
      S2R R10, SR_CTAID.Y 
      ISETP.NE.AND.EX P2, PT, RZ, c[0x0][0x18c], PT, P2 
      S2R R5, SR_TID.Y 
      S2R R3, SR_TID.X 
      IMAD R10, R10, c[0x0][0x4], R5 
      IMAD R0, R0, c[0x0][0x0], R3 
      IMAD.WIDE.U32 R10, R10, 0x10, RZ 
      CS2R R2, SRZ 
      SHF.R.U32.HI R0, RZ, 0x1, R0 
      ISETP.GE.U32.AND P0, PT, R10, c[0x0][0x180], PT 
      LOP3.LUT R13, R0, 0x7ffffff0, RZ, 0xc0, !PT 
      ISETP.GE.U32.AND.EX P0, PT, R11, c[0x0][0x184], PT, P0 
      ISETP.LT.U32.AND P1, PT, R13, c[0x0][0x178], PT 
      ISETP.LT.U32.AND.EX P0, PT, RZ, c[0x0][0x17c], !P0, P1 
@!P2  BRA 0x7f1eaefc0160 
      BSSY B0, 0x7f1eaefc0160 
      IMAD.MOV.U32 R0, RZ, RZ, RZ 
      CS2R R8, SRZ 
      IMAD.MOV.U32 R15, RZ, RZ, RZ 
      IMAD.MOV.U32 R2, RZ, RZ, RZ 
      BSSY B1, 0x7f1eaefc0100 
@!P0  BRA 0x7f1eaefc00f0 
      S2R R16, SR_LANEID 
      IMAD R17, R11, c[0x0][0x188], RZ 
      IMAD.MOV.U32 R14, RZ, RZ, R0 
      IMAD.MOV.U32 R23, RZ, RZ, c[0x0][0x188] 
      IMAD.WIDE.U32 R6, R10, c[0x0][0x188], R14 
      SHF.R.U32.HI R12, RZ, 0x1, R23 
      IMAD R17, R10, c[0x0][0x18c], R17 
      LEA R21, P2, R6, c[0x0][0x168], 0x1 
      IMAD.WIDE.U32 R4, R13, c[0x0][0x188], R14 
      IMAD.IADD R7, R7, 0x1, R17 
      IMAD.MOV.U32 R17, RZ, RZ, RZ 
      IMAD R5, R13, c[0x0][0x18c], R5 
      LEA.HI.X R7, R6, c[0x0][0x16c], R7, 0x1, P2 
      SHF.R.U32.HI R19, RZ, 0x2, R16 
      LOP3.LUT R16, R16, 0x3, RZ, 0xc0, !PT 
      IMAD.WIDE.U32 R16, R19, R12, R16 
      LEA R19, P1, R4, c[0x0][0x160], 0x1 
      LEA.HI.X R5, R4, c[0x0][0x164], R5, 0x1, P1 
      LEA R18, P1, R16, R19, 0x2 
      LEA R20, P2, R16, R21, 0x2 
      LEA.HI.X R19, R16, R5, R17, 0x2, P1 
      LEA.HI.X R21, R16, R7, R17, 0x2, P2 
      IMAD.WIDE.U32 R16, R23, 0x10, R18 
      LDG.E R4, [R18.64] 
      IMAD.WIDE.U32 R22, R23, 0x10, R20 
      LDG.E R24, [R20.64] 
      LDG.E R25, [R20.64+0x10] 
      LDG.E R6, [R18.64+0x10] 
      LDG.E R5, [R16.64] 
      LDG.E R7, [R16.64+0x10] 
      LDG.E R26, [R22.64] 
      LDG.E R27, [R22.64+0x10] 
      WARPSYNC 0xffffffff 
      HMMA.16816.F16 R8, R4, R24, R8 
      HMMA.16816.F16 R2, R4, R26, R2 
      NOP 
      BSYNC B1 
      IADD3 R0, P1, R0, 0x10, RZ 
      IMAD.X R15, RZ, RZ, R15, P1 
      ISETP.GE.U32.AND P1, PT, R0, c[0x0][0x188], PT 
      ISETP.GE.U32.AND.EX P1, PT, R15, c[0x0][0x18c], PT, P1 
@!P1  BRA 0x7f1eaefbfe90 
      BSYNC B0 
@!P0  EXIT 
      S2R R4, SR_LANEID 
      IMAD.MOV.U32 R15, RZ, RZ, c[0x0][0x180] 
      WARPSYNC 0xffffffff 
      IMAD.WIDE.U32 R10, R13, c[0x0][0x180], R10 
      SHF.R.U32.HI R15, RZ, 0x1, R15 
      IMAD.MOV.U32 R5, RZ, RZ, RZ 
      LEA R7, P0, R10, c[0x0][0x170], 0x1 
      IMAD R11, R13, c[0x0][0x184], R11 
      LEA.HI.X R11, R10, c[0x0][0x174], R11, 0x1, P0 
      SHF.R.U32.HI R0, RZ, 0x2, R4 
      LOP3.LUT R4, R4, 0x3, RZ, 0xc0, !PT 
      IMAD.WIDE.U32 R4, R0, R15, R4 
      LEA R6, P0, R4, R7, 0x2 
      LEA.HI.X R7, R4, R11, R5, 0x2, P0 
      IMAD.WIDE.U32 R4, R15, 0x20, R6 
      STG.E [R6.64], R8 
      STG.E [R4.64], R9 
      STG.E [R6.64+0x10], R2 
      STG.E [R4.64+0x10], R3 
      EXIT 
      BRA 0x7f1eaefc02b0
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP

Again we focus on the WMMA-related SASS instruction, shown below. The m16n16k16 WMMA operation is implemented underneath by two HMMA.16816 instructions (each computes an m16n8k16 product, so two cover n = 16). SASS is thus yet another programming route Nvidia provides to the Tensor Cores.

HMMA.16816.F16
In "Nvidia Tensor Core初探" (A First Look at Nvidia Tensor Cores) we noted that Nvidia offers four ways to program Tensor Cores; three of them have appeared here. The fourth is MMA PTX instructions: the MMA16816 PTX instruction, for example, lowers to the very same HMMA16816 instruction. We'll come back to it in a later article on MMA PTX.

4 Miscellaneous

4.1 HGEMM Optimization

The point of learning the WMMA API is to use Tensor Cores to optimize HGEMM. So, compared with cuBLAS, how fast is WMMA really?
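
For reference, the cuBLAS baseline for this exact layout (A row major, B col major, C row major) can be written as follows; this is a sketch of mine (the name hgemmCublas is made up), not a benchmark, and on recent CUDA versions cuBLAS will pick Tensor Core kernels for HGEMM on its own. The performance comparison itself is left for a follow-up article.

#include <cublas_v2.h>
#include <cuda_fp16.h>

// Sketch of the cuBLAS baseline (link with -lcublas; error checks omitted).
// cuBLAS assumes column-major storage. The column-major view of our
// row-major C (M x N) is exactly C^T (N x M), so we compute C^T = B^T * A^T:
//   - op(B) = CUBLAS_OP_T: B is col major K x N, transposed to N x K
//   - op(A) = CUBLAS_OP_N: row-major A (M x K) reinterpreted as col major
//     is already A^T (K x M)
void hgemmCublas(cublasHandle_t handle, half *A, half *B, half *C,
                 size_t M, size_t N, size_t K) {
    half alpha = __float2half(1.0f), beta = __float2half(0.0f);
    cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, (int)N, (int)M, (int)K,
                &alpha, B, (int)K, A, (int)K, &beta, C, (int)N);
}

// Usage: cublasCreate(&handle) once, then
//        hgemmCublas(handle, dA, dB, dC, M, N, K);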