]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA/HIP: refractor mmqv to unify the calculation of nwarps and rows per block betwee...
authoruvos <redacted>
Tue, 11 Mar 2025 19:16:03 +0000 (20:16 +0100)
committerGeorgi Gerganov <redacted>
Thu, 27 Mar 2025 09:06:03 +0000 (11:06 +0200)
refactor mmqv to unify the calculation of nwarps and rows per block between host and device code.

---------

Co-authored-by: Johannes Gäßler <redacted>
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/mmvq.cu

index 1832314ec133be218ff2c048f34b949762a99de4..4d4ac47c034e1d26b86a10ced3f24b7a169e99f5 100644 (file)
@@ -395,11 +395,11 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
 
 static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
+#if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
 #elif defined(RDNA3)
     c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
-#elif defined(__gfx1010__) || defined(__gfx900__)
+#elif defined(RDNA1) || defined(__gfx900__)
     int tmp1;
     int tmp2;
     asm("\n \
index 4fb466ca00627f995461920cf2a74d45d1c327a7..a7d518a574ddc0bbf632e39fc65d91614c47201b 100644 (file)
@@ -47,11 +47,89 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         1;
 }
 
+enum mmvq_parameter_table_id {
+    MMVQ_PARAMETERS_GENERIC = 0,
+    MMVQ_PARAMETERS_GCN,
+    MMVQ_PARAMETERS_RDNA2
+};
+
+static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
+#if defined(RDNA2) || defined(RDNA3)
+    return MMVQ_PARAMETERS_RDNA2;
+#elif defined(GCN) || defined(CDNA)
+    return MMVQ_PARAMETERS_GCN;
+#else
+    return MMVQ_PARAMETERS_GENERIC;
+#endif
+}
+
+static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+        return MMVQ_PARAMETERS_RDNA2;
+    }
+    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+        return MMVQ_PARAMETERS_GCN;
+    }
+    return MMVQ_PARAMETERS_GENERIC;
+}
+
+static constexpr __host__ __device__ int calc_nwarps(int ncols_y,  mmvq_parameter_table_id table_id) {
+    if (table_id == MMVQ_PARAMETERS_GENERIC) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 4;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    } else if (table_id == MMVQ_PARAMETERS_GCN) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 2;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
+static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
+    if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
+        switch (ncols_y) {
+            case 1:
+                return 1;
+            case 2:
+            case 3:
+            case 4:
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
 template <ggml_type type, int ncols_y>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -59,24 +137,20 @@ static __global__ void mul_mat_vec_q(
     constexpr int qk  = ggml_cuda_type_traits<type>::qk;
     constexpr int qi  = ggml_cuda_type_traits<type>::qi;
     constexpr int vdr = get_vdr_mmvq(type);
+    constexpr mmvq_parameter_table_id table_id = get_device_table_id();
+    constexpr int nwarps = calc_nwarps(ncols_y, table_id);
+    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
-    constexpr int nwarps              = 1;
-    constexpr int rows_per_cuda_block = 1;
-#else
-    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
-    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
-
-    const     int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const     int tid = warp_size*threadIdx.y + threadIdx.x;
     const     int row0 = rows_per_cuda_block*blockIdx.x;
     const     int blocks_per_row_x = ncols_x / qk;
     const     int blocks_per_col_y = nrows_y / QK8_1;
-    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
-// partial sum for each thread
+    // partial sum for each thread
     float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
 
     const block_q8_1 * y = (const block_q8_1 *) vy;
@@ -96,7 +170,7 @@ static __global__ void mul_mat_vec_q(
         }
     }
 
-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
     if (threadIdx.y > 0) {
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
@@ -120,7 +194,7 @@ static __global__ void mul_mat_vec_q(
             for (int l = 0; l < nwarps-1; ++l) {
                 tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
             }
-            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
         }
 
         if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
@@ -129,6 +203,13 @@ static __global__ void mul_mat_vec_q(
     }
 }
 
+static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
+    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
+    return {block_nums, block_dims};
+}
+
 template <ggml_type type>
 static void mul_mat_vec_q_cuda(
     const void * vx, const void * vy, float * dst,
@@ -137,65 +218,67 @@ static void mul_mat_vec_q_cuda(
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 
-    int id = ggml_cuda_get_device();
-
-    int64_t nwarps = 1;
-    int64_t rows_per_cuda_block = 1;
-
-    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
-        switch(ncols_y) {
-            case 1:
-                nwarps = 4;
-                rows_per_cuda_block = 1;
-                break;
-            case 2:
-            case 3:
-            case 4:
-                nwarps = 4;
-                rows_per_cuda_block = 2;
-                break;
-            case 5:
-            case 6:
-            case 7:
-            case 8:
-                nwarps = 2;
-                rows_per_cuda_block = 2;
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
-    }
-
-    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
-    const dim3 block_nums(nblocks, 1, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    const int device = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+    const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
 
     switch (ncols_y) {
         case 1:
-            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 1;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 2:
-            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 2;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 3:
-            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 3;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 4:
-            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 4;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 5:
-            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 5;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 6:
-            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 6;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 7:
-            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 7;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 8:
-            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 8;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         default:
             GGML_ABORT("fatal error");
             break;