]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: mul_mat_vec_q tiling, refactor mul mat logic (llama/5434)
authorJohannes Gäßler <redacted>
Sun, 11 Feb 2024 18:08:39 +0000 (19:08 +0100)
committerGeorgi Gerganov <redacted>
Mon, 12 Feb 2024 07:31:12 +0000 (09:31 +0200)
* CUDA: mul_mat_vec_q tiling, refactor mul mat logic

Co-authored-by: slaren <redacted>
---------

Co-authored-by: slaren <redacted>
ggml-cuda.cu

index 5053757e6d41ef33702b972abc782a93ae6d4bce..96976f2487294d974b66a780ded82f2b8af58126 100644 (file)
 #define CUDA_USE_TENSOR_CORES
 #endif
 
-// max batch size to use MMQ kernels when tensor cores are available
-#define MMQ_MAX_BATCH_SIZE 32
+#define MMVQ_MAX_BATCH_SIZE  8 // max batch size to use MMVQ kernels
+#define  MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available
 
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
@@ -5310,51 +5310,59 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
-#define MMVQ_NWARPS_NVIDIA    4
-#define MMVQ_NWARPS_AMD_RDNA2 1
-#define MMVQ_NWARPS_AMD_OLD   4
-
-template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+template <int ncols_y, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants
+// tell the compiler to use as many registers as it wants, see nwarps definition below
+__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
-    const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+    constexpr int nwarps              = 1;
+    constexpr int rows_per_cuda_block = 1;
+#else
+    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
+    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
 
-    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
-    const int row = blockIdx.x;
-
-    const int blocks_per_row_x = ncols_x / qk;
-    const int blocks_per_col_y = nrows_y / QK8_1;
-    const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+    const     int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const     int row0 = rows_per_cuda_block*blockIdx.x;
+    const     int blocks_per_row_x = ncols_x / qk;
+    const     int blocks_per_col_y = nrows_y / QK8_1;
+    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
 
 // partial sum for each thread
-    float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};
+    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
 
     const block_q_t  * x = (const block_q_t  *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) {
-        const int ibx = row*blocks_per_row_x + i; // x block index
-
-        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
+        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
 
-        const int iqs  = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int
+        // x block quant index when casting the quants to int
+        const int kqs = vdr * (tid % (qi/vdr));
 
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
-            tmp[j] += vec_dot_q_cuda(&x[ibx], &y[j*blocks_per_col_y + iby], iqs);
+#pragma unroll
+            for (int i = 0; i < rows_per_cuda_block; ++i) {
+                tmp[j][i] += vec_dot_q_cuda(
+                    &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
+            }
         }
     }
 
-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
     if (threadIdx.y > 0) {
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
-            tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j];
+#pragma unroll
+            for (int i = 0; i < rows_per_cuda_block; ++i) {
+                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
+            }
         }
     }
     __syncthreads();
@@ -5366,13 +5374,16 @@ static __global__ void mul_mat_vec_q(
 #pragma unroll
     for (int j = 0; j < ncols_y; ++j) {
 #pragma unroll
-        for (int i = 0; i < nwarps-1; ++i) {
-            tmp[j] += tmp_shared[i][j][threadIdx.x];
+        for (int i = 0; i < rows_per_cuda_block; ++i) {
+#pragma unroll
+            for (int l = 0; l < nwarps-1; ++l) {
+                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
+            }
+            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
         }
-        tmp[j] = warp_reduce_sum(tmp[j]);
 
-        if (threadIdx.x == 0) {
-            dst[j*nrows_dst + row] = tmp[j];
+        if (threadIdx.x < rows_per_cuda_block) {
+            dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }
 }
@@ -6851,65 +6862,75 @@ static void mul_mat_vec_q_cuda(
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
     GGML_ASSERT(ncols_x % qk == 0);
-    GGML_ASSERT(ncols_y <= 4);
+    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
 
-    int nwarps;
-    if (g_device_caps[id].cc >= CC_OFFSET_AMD) {
-        nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD;
-    } else {
-        nwarps = MMVQ_NWARPS_NVIDIA;
-    }
+    int64_t nwarps = 1;
+    int64_t rows_per_cuda_block = 1;
 
-    const dim3 block_nums(nrows_x, 1, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-    switch (nwarps) {
-        case 1: switch(ncols_y) {
+    if (g_device_caps[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
+        switch(ncols_y) {
             case 1:
-                mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                nwarps = 4;
+                rows_per_cuda_block = 1;
                 break;
             case 2:
-                mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
             case 3:
-                mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
             case 4:
-                mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        } break;
-        case 4: switch(ncols_y) {
-            case 1:
-                mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                nwarps = 4;
+                rows_per_cuda_block = 2;
                 break;
-            case 2:
-                mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
-            case 3:
-                mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-                break;
-            case 4:
-                mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot>
-                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                nwarps = 2;
+                rows_per_cuda_block = 2;
                 break;
             default:
                 GGML_ASSERT(false);
                 break;
-        } break;
+        }
+    }
+    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
 
+    switch (ncols_y) {
+        case 1:
+            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 2:
+            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 3:
+            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 4:
+            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 5:
+            mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 6:
+            mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 7:
+            mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 8:
+            mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -9735,7 +9756,7 @@ static __global__ void k_compute_batched_ptrs(
     ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
 }
 
-static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
 
@@ -9893,39 +9914,69 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 
     int64_t min_compute_capability = INT_MAX;
 
+    bool any_pascal_with_slow_fp16 = false;
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
         for (int id = 0; id < g_device_count; ++id) {
-            if (min_compute_capability > g_device_caps[id].cc && tensor_split[id] < (id + 1 < g_device_count ? tensor_split[id + 1] : 1.0f)) {
+            // skip devices that are not going to do any work:
+            if (tensor_split[id] >= (id + 1 < g_device_count ? tensor_split[id + 1] : 1.0f)) {
+                continue;
+            }
+
+            if (min_compute_capability > g_device_caps[id].cc) {
                 min_compute_capability = g_device_caps[id].cc;
             }
+            if (g_device_caps[id].cc == 610) {
+                any_pascal_with_slow_fp16 = true;
+            }
         }
     } else {
-        min_compute_capability = g_device_caps[g_main_device].cc;
+        min_compute_capability    = g_device_caps[g_main_device].cc;
+        any_pascal_with_slow_fp16 = g_device_caps[g_main_device].cc == 610;
     }
 
+    // check data types and tensor shapes for custom matrix multiplication kernels:
+    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
+
+    bool          use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+    bool              use_mul_mat_q =  ggml_cuda_supports_mmq(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 
     const bool fp16_performance_good = min_compute_capability >= CC_RDNA1;
-    bool               use_mul_mat_q = ggml_is_quantized(src0->type);
+
 #ifdef CUDA_USE_TENSOR_CORES
     use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3;
 #endif // CUDA_USE_TENSOR_CORES
 
 #else
 
-    const bool fp16_performance_good = min_compute_capability >= CC_VOLTA;
-    bool               use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+    // fp16 performance is good on Volta or newer and on P100 (compute capability 6.0)
+    const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16;
+
+    // mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1
+    use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A;
+    use_mul_mat_q     = use_mul_mat_q     && min_compute_capability >= MIN_CC_DP4A;
+
 #ifdef CUDA_USE_TENSOR_CORES
     // when tensor cores are available, use them for large batch size
     // ref: https://github.com/ggerganov/llama.cpp/pull/3776
-    use_mul_mat_q = use_mul_mat_q && !(fp16_performance_good && src1->ne[1] > MMQ_MAX_BATCH_SIZE);
+    use_mul_mat_q     = use_mul_mat_q     && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
 #endif // CUDA_USE_TENSOR_CORES
 
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 
-    use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
+    // if mmvq is available it's a better choice than dmmv:
+#ifndef GGML_CUDA_FORCE_DMMV
+    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+#endif // GGML_CUDA_FORCE_DMMV
 
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
@@ -9943,33 +9994,15 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
     } else if (!split && all_on_device && fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
-        ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
-    } else if (src0->type == GGML_TYPE_F32) {
-        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
-    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->type == GGML_TYPE_F32) {
-#ifdef GGML_CUDA_FORCE_DMMV
-            const bool use_mul_mat_vec_q = false;
-#else
-            const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
-#endif // GGML_CUDA_FORCE_DMMV
-
-            if (use_mul_mat_vec_q) {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
-            } else {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
-            }
-        } else {
-            if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32) {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
-            } else if (use_mul_mat_q) {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
-            } else {
-                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
-            }
-        }
+        ggml_cuda_mul_mat_batched_cublas(src0, src1, dst);
+    } else if (use_dequantize_mul_mat_vec) {
+        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
+    } else if (use_mul_mat_vec_q) {
+        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
+    } else if (use_mul_mat_q) {
+        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
     } else {
-        GGML_ASSERT(false);
+        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     }
 }