]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: more warps for mmvq on NVIDIA (llama/5394)
authorJohannes Gäßler <redacted>
Thu, 8 Feb 2024 20:56:40 +0000 (21:56 +0100)
committerGeorgi Gerganov <redacted>
Sat, 10 Feb 2024 07:55:47 +0000 (09:55 +0200)
ggml-cuda.cu

index db9da24594cb249c0dc8a38e88d7e74a379870f1..5053757e6d41ef33702b972abc782a93ae6d4bce 100644 (file)
@@ -5310,22 +5310,26 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
-template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+#define MMVQ_NWARPS_NVIDIA    4
+#define MMVQ_NWARPS_AMD_RDNA2 1
+#define MMVQ_NWARPS_AMD_OLD   4
+
+template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
 
     const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
 
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-
-    if (row >= nrows_x) {
-        return;
-    }
+    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const int row = blockIdx.x;
 
     const int blocks_per_row_x = ncols_x / qk;
     const int blocks_per_col_y = nrows_y / QK8_1;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
 
 // partial sum for each thread
     float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};
@@ -5333,12 +5337,12 @@ static __global__ void mul_mat_vec_q(
     const block_q_t  * x = (const block_q_t  *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp) {
+    for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) {
         const int ibx = row*blocks_per_row_x + i; // x block index
 
         const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
 
-        const int iqs  = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
+        const int iqs  = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int
 
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
@@ -5346,9 +5350,25 @@ static __global__ void mul_mat_vec_q(
         }
     }
 
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE];
+    if (threadIdx.y > 0) {
+#pragma unroll
+        for (int j = 0; j < ncols_y; ++j) {
+            tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j];
+        }
+    }
+    __syncthreads();
+    if (threadIdx.y > 0) {
+        return;
+    }
+
     // sum up partial sums and write back result
 #pragma unroll
     for (int j = 0; j < ncols_y; ++j) {
+#pragma unroll
+        for (int i = 0; i < nwarps-1; ++i) {
+            tmp[j] += tmp_shared[i][j][threadIdx.x];
+        }
         tmp[j] = warp_reduce_sum(tmp[j]);
 
         if (threadIdx.x == 0) {
@@ -6833,46 +6853,65 @@ static void mul_mat_vec_q_cuda(
     GGML_ASSERT(ncols_x % qk == 0);
     GGML_ASSERT(ncols_y <= 4);
 
-    const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    switch (ncols_y) {
-        case 1:
-            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-            break;
-        case 2:
-            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-            break;
-        case 3:
-            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-            break;
-        case 4:
-            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-            break;
-        // case 5:
-        //     mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-        //     break;
-        // case 6:
-        //     mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-        //     break;
-        // case 7:
-        //     mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-        //     break;
-        // case 8:
-        //     mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-        //     break;
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    int nwarps;
+    if (g_device_caps[id].cc >= CC_OFFSET_AMD) {
+        nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD;
+    } else {
+        nwarps = MMVQ_NWARPS_NVIDIA;
+    }
+
+    const dim3 block_nums(nrows_x, 1, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    switch (nwarps) {
+        case 1: switch(ncols_y) {
+            case 1:
+                mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot>
+                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                break;
+            case 2:
+                mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot>
+                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                break;
+            case 3:
+                mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot>
+                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                break;
+            case 4:
+                mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot>
+                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        } break;
+        case 4: switch(ncols_y) {
+            case 1:
+                mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot>
+                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                break;
+            case 2:
+                mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot>
+                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                break;
+            case 3:
+                mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot>
+                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                break;
+            case 4:
+                mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot>
+                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        } break;
+
         default:
             GGML_ASSERT(false);
-            // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
-            //     <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
             break;
     }
 }