]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: launch_bounds, small q4_K, q5_K mmq refactor (#2596)
authorJohannes Gäßler <redacted>
Mon, 14 Aug 2023 08:41:22 +0000 (10:41 +0200)
committerGitHub <redacted>
Mon, 14 Aug 2023 08:41:22 +0000 (10:41 +0200)
ggml-cuda.cu

index 11f67aec8e47bc48986822fb770d81c348f3e2ee..df0cbe18f96ba39b5edde2c97b412b0da406dee0 100644 (file)
@@ -1753,7 +1753,6 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
 }
 
 // contiguous u/y values
-// also used for q5_K
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
     const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@@ -1763,19 +1762,18 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
     float sumf_m = 0.0f;
 
 #pragma unroll
-    for (int i0 = 0; i0 < VDR_Q4_K_Q8_1_MMQ; i0 += (QI8_1/QR4_K)) {
+    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
         int sumi_d = 0;
 
 #pragma unroll
-        for (int i = i0; i < i0 + (QI8_1/QR4_K); ++i) {
-            sumi_d = __dp4a(v[2*i+0], u[2*i+0], sumi_d); // SIMD dot product
-            sumi_d = __dp4a(v[2*i+1], u[2*i+1], sumi_d); // SIMD dot product
+        for (int j = 0; j < QI8_1; ++j) {
+            sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
         }
 
-        const float2 ds8f = __half22float2(ds8[i0 / 4]);
+        const float2 ds8f = __half22float2(ds8[i]);
 
-        sumf_d += ds8f.x * (sc[i0/4] * sumi_d);
-        sumf_m += ds8f.y *   m[i0/4]; // sum of q8_1 block * q4_K min val
+        sumf_d += ds8f.x * (sc[i] * sumi_d);
+        sumf_m += ds8f.y *   m[i]; // sum of q8_1 block * q4_K min val
     }
 
     const float2 dm4f = __half22float2(dm4);
@@ -1792,7 +1790,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
 #define VDR_Q5_K_Q8_1_MMQ  8
 
 // contiguous v/x values
-static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl(
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
     const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
     const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
 
@@ -1829,6 +1827,40 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl(
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
+        int sumi_d = 0;
+
+#pragma unroll
+        for (int j = 0; j < QI8_1; ++j) {
+            sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
+        }
+
+        const float2 ds8f = __half22float2(ds8[i]);
+
+        sumf_d += ds8f.x * (sc[i] * sumi_d);
+        sumf_m += ds8f.y *   m[i]; // sum of q8_1 block * q4_K min val
+    }
+
+    const float2 dm4f = __half22float2(dm4);
+
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
+
+#else
+    assert(false);
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
 #define VDR_Q6_K_Q8_1_MMVQ 1
 #define VDR_Q6_K_Q8_1_MMQ  8
 
@@ -2824,18 +2856,11 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
 
-    int v[QR4_K*VDR_Q4_K_Q8_1_MMQ];
-
-#pragma unroll
-    for (int l = 0; l < VDR_Q4_K_Q8_1_MMQ; ++l) {
-        v[l + 0]         = (x_ql[i * (WARP_SIZE + 1) + k + l] >> 0) & 0x0F0F0F0F;
-        v[l + (QI4_K/4)] = (x_ql[i * (WARP_SIZE + 1) + k + l] >> 4) & 0x0F0F0F0F;
-    }
-
     const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
 
     const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
-    return vec_dot_q4_K_q8_1_impl_mmq(v, &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
+    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
+                                      x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
@@ -2882,7 +2907,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
         u[2*i+1] = q8[4];
     }
 
-    return vec_dot_q5_K_q8_1_impl(vl, vh, u, sc, m, bq5_K->dm, d8);
+    return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
 
 #else
 
@@ -3025,7 +3050,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
 
     const int index_x = i * (QR5_K*WARP_SIZE + 1) +  QR5_K*k;
     const int index_y = j * WARP_SIZE             + (QR5_K*k) % WARP_SIZE;
-    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
+    return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
+                                      x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
@@ -3301,7 +3327,11 @@ template <bool need_check> static __global__ void mul_mat_q4_0(
 #define  MMQ_Y_Q4_1_PASCAL 64
 #define NWARPS_Q4_1_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q4_1(
+template <bool need_check> static __global__ void
+#if __CUDA_ARCH__ < CC_TURING
+    __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2)
+#endif // __CUDA_ARCH__ < CC_TURING
+    mul_mat_q4_1(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
@@ -3471,7 +3501,11 @@ template <bool need_check> static __global__ void mul_mat_q2_K(
 #define  MMQ_Y_Q3_K_PASCAL 64
 #define NWARPS_Q3_K_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q3_K(
+template <bool need_check> static __global__ void
+#if __CUDA_ARCH__ < CC_TURING
+    __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2)
+#endif // __CUDA_ARCH__ < CC_TURING
+    mul_mat_q3_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
@@ -3501,11 +3535,15 @@ template <bool need_check> static __global__ void mul_mat_q3_K(
 #define  MMQ_X_Q4_K_AMPERE 64
 #define  MMQ_Y_Q4_K_AMPERE 128
 #define NWARPS_Q4_K_AMPERE 4
-#define  MMQ_X_Q4_K_PASCAL 32
+#define  MMQ_X_Q4_K_PASCAL 64
 #define  MMQ_Y_Q4_K_PASCAL 64
 #define NWARPS_Q4_K_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q4_K(
+template <bool need_check> static __global__ void
+#if __CUDA_ARCH__ < CC_TURING
+    __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2)
+#endif // __CUDA_ARCH__ < CC_TURING
+    mul_mat_q4_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
@@ -3569,11 +3607,15 @@ template <bool need_check> static __global__ void mul_mat_q5_K(
 #define  MMQ_X_Q6_K_AMPERE 64
 #define  MMQ_Y_Q6_K_AMPERE 64
 #define NWARPS_Q6_K_AMPERE 4
-#define  MMQ_X_Q6_K_PASCAL 32
+#define  MMQ_X_Q6_K_PASCAL 64
 #define  MMQ_Y_Q6_K_PASCAL 64
 #define NWARPS_Q6_K_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q6_K(
+template <bool need_check> static __global__ void
+#if __CUDA_ARCH__ < CC_TURING
+    __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2)
+#endif // __CUDA_ARCH__ < CC_TURING
+    mul_mat_q6_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {