]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: int8 tensor cores for MMQ (q4_K, q5_K, q6_K) (llama/7860)
authorJohannes Gäßler <redacted>
Tue, 11 Jun 2024 06:26:07 +0000 (08:26 +0200)
committerGeorgi Gerganov <redacted>
Sat, 15 Jun 2024 19:05:47 +0000 (22:05 +0300)
src/ggml-cuda/mma.cuh
src/ggml-cuda/mmq.cuh

index 71e8e342918aa16b1fa7ee41feefe2dfe1fc8cd6..63e07fbc21291013fac4bab8d188fb8bb6c58a5d 100644 (file)
@@ -1,5 +1,27 @@
 #include "common.cuh"
 
+struct mma_int_A_I16K4 {
+    static constexpr int I  = 16;
+    static constexpr int K  = 4;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+};
+
 struct mma_int_A_I16K8 {
     static constexpr int I  = 16;
     static constexpr int K  = 8;
@@ -22,6 +44,28 @@ struct mma_int_A_I16K8 {
     }
 };
 
+struct mma_int_B_J8K4 {
+    static constexpr int J  = 8;
+    static constexpr int K  = 4;
+    static constexpr int ne = 1;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+};
+
 struct mma_int_B_J8K8 {
     static constexpr int J  = 8;
     static constexpr int K  = 8;
@@ -65,6 +109,28 @@ struct mma_int_C_I16J8 {
         return ret;
     }
 
+    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+
     __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
 #ifdef INT8_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= CC_AMPERE
index 62111f376ec815c6cd80ff01dc4f8e6ee6d75ccc..01e2086b41646936a836ffdbc1645b2f7132ba0d 100644 (file)
@@ -1089,7 +1089,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -1115,6 +1115,97 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+    mma_A   A[2];
+    int   scA[mma_C::ne/2][2];
+    int    mA[mma_C::ne/2][2];
+    half2 dmA[mma_C::ne/2];
+#pragma unroll
+    for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i = i0 + mma_A::get_i(l);
+            const int k = k0 + mma_A::get_k(l);
+
+            A[kvdr/4].x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + mma_C::get_i(2*l);
+
+            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
+            const uint8_t *  m = sc + 8;
+
+            scA[l][kvdr/4] = sc[kvdr/4];
+            mA[l][kvdr/4]  =  m[kvdr/4];
+        }
+    }
+
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        float tmpd[mma_C::ne] = {0.0f};
+        float tmpm[mma_C::ne] = {0.0f};
+
+#pragma unroll
+        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+            mma_C   C;
+            mma_B   B;
+            half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+            for (int l = 0; l < mma_B::ne; ++l) {
+                const int j = j0 + mma_B::get_j(l);
+                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+
+                B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+            }
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+            }
+
+            C.mma_K8(A[kvdr/4], B);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) *  __low2float(dsB[l%2]);
+                tmpm[l] += mA[l/2][kvdr/4]           * __high2float(dsB[l%2]);
+            }
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -1188,7 +1279,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -1214,6 +1305,97 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+    mma_A   A[2];
+    int   scA[mma_C::ne/2][2];
+    int    mA[mma_C::ne/2][2];
+    half2 dmA[mma_C::ne/2];
+#pragma unroll
+    for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i = i0 + mma_A::get_i(l);
+            const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
+
+            A[kvdr/4].x[l] = x_ql[i*(QR5_K*WARP_SIZE + 1) + k];
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + mma_C::get_i(2*l);
+
+            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
+            const uint8_t *  m = sc + 8;
+
+            scA[l][kvdr/4] = sc[kvdr/4];
+            mA[l][kvdr/4]  =  m[kvdr/4];
+        }
+    }
+
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        float tmpd[mma_C::ne] = {0.0f};
+        float tmpm[mma_C::ne] = {0.0f};
+
+#pragma unroll
+        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+            mma_C   C;
+            mma_B   B;
+            half2 dsB[mma_C::ne/2];
+
+#pragma unroll
+            for (int l = 0; l < mma_B::ne; ++l) {
+                const int j = j0 + mma_B::get_j(l);
+                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+
+                B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
+            }
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+            }
+
+            C.mma_K8(A[kvdr/4], B);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) *  __low2float(dsB[l%2]);
+                tmpm[l] += mA[l/2][kvdr/4]           * __high2float(dsB[l%2]);
+            }
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -1280,7 +1462,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
@@ -1307,6 +1489,97 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
     }
 }
 
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_B_J8K4  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    const float * x_df = (const float *) x_dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = threadIdx.y*mma_A::I;
+    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+
+    mma_A   A[4];
+    int   scA[mma_C::ne/2][4];
+    float  dA[mma_C::ne/2];
+#pragma unroll
+    for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i = i0 + mma_A::get_i(l);
+            const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
+
+            A[kvdr/2 + 0].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + 0];
+            A[kvdr/2 + 1].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + mma_C::get_i(2*l);
+
+            const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
+
+            scA[l][kvdr/2 + 0] = sc[kvdr/2 + 0];
+            scA[l][kvdr/2 + 1] = sc[kvdr/2 + 1];
+        }
+    }
+
+#pragma unroll
+    for (int l = 0; l < mma_C::ne/2; ++l) {
+        const int i = i0 + mma_C::get_i(2*l);
+
+        dA[l] = x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + k0/QI6_K];
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
+        float tmp[mma_C::ne] = {0.0f};
+
+#pragma unroll
+        for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
+            mma_C C[2];
+            mma_B B[2];
+            float dB[mma_C::ne/2];
+
+#pragma unroll
+            for (int l = 0; l < mma_B::ne; ++l) {
+                const int j = j0 + mma_B::get_j(l);
+                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+
+                B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
+                B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
+            }
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
+            }
+
+            C[0].mma_K4(A[kvdr/2 + 0], B[0]);
+            C[1].mma_K4(A[kvdr/2 + 1], B[1]);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                tmp[l] += (C[0].x[l]*scA[l/2][kvdr/2 + 0] + C[1].x[l]*scA[l/2][kvdr/2 + 1])*dB[l%2];
+            }
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne; ++l) {
+            sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
+        }
+    }
+}
+
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
 #pragma unroll
@@ -1448,24 +1721,39 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
     static constexpr int              vdr        = VDR_Q4_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
     static constexpr int              vdr        = VDR_Q5_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
     static constexpr int              vdr        = VDR_Q6_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
-    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+#ifdef INT8_MMA_AVAILABLE
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
     static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
 };
 
 static int mmq_need_sum(const ggml_type type_x) {