]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : add RDNA4-specific MMVQ parameter table for bs=1 decode (#19478)
authorPikaPikachu <redacted>
Sun, 15 Mar 2026 07:33:39 +0000 (15:33 +0800)
committerGitHub <redacted>
Sun, 15 Mar 2026 07:33:39 +0000 (08:33 +0100)
* mmvq: add RDNA3/RDNA4-specific parameter table (nwarps=8, rows=1)

* mmvq: add dedicated RDNA3 parameter table

* mmvq: exclude RDNA3.5 (gfx1150/1151) from RDNA3 table

ggml/src/ggml-cuda/mmvq.cu
ggml/src/ggml-cuda/vendors/hip.h

index ce25ccf427cf0fa897a9ad892b2e0481e505eeb4..632246e43fd9fb62488a19a792ae708348b5cb46 100644 (file)
@@ -60,11 +60,17 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
 enum mmvq_parameter_table_id {
     MMVQ_PARAMETERS_GENERIC = 0,
     MMVQ_PARAMETERS_GCN,
-    MMVQ_PARAMETERS_RDNA2
+    MMVQ_PARAMETERS_RDNA2,
+    MMVQ_PARAMETERS_RDNA3_0,
+    MMVQ_PARAMETERS_RDNA4
 };
 
 static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
-#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
+#if defined(RDNA4)
+    return MMVQ_PARAMETERS_RDNA4;
+#elif defined(RDNA3_0)
+    return MMVQ_PARAMETERS_RDNA3_0;
+#elif defined(RDNA2) || defined(RDNA3_5)
     return MMVQ_PARAMETERS_RDNA2;
 #elif defined(GCN) || defined(CDNA)
     return MMVQ_PARAMETERS_GCN;
@@ -74,7 +80,13 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
 }
 
 static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
-    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+    if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+        return MMVQ_PARAMETERS_RDNA4;
+    }
+    if (GGML_CUDA_CC_IS_RDNA3_0(cc)) {
+        return MMVQ_PARAMETERS_RDNA3_0;
+    }
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc)) {
         return MMVQ_PARAMETERS_RDNA2;
     }
     if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
@@ -83,7 +95,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
     return MMVQ_PARAMETERS_GENERIC;
 }
 
-static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
+static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
     if (table_id == MMVQ_PARAMETERS_GENERIC) {
         switch (ncols_dst) {
             case 1:
@@ -114,6 +126,50 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet
                 return 1;
         }
     }
+    if (table_id == MMVQ_PARAMETERS_RDNA4) {
+        // nwarps=8 benefits types with simple vec_dot on RDNA4 (ncols_dst=1).
+        // Types with complex vec_dot (Q3_K, IQ2_*, IQ3_*) regress due to register
+        // pressure and lookup table contention at higher thread counts.
+        if (ncols_dst == 1) {
+            switch (type) {
+                case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q4_1:
+                case GGML_TYPE_Q5_0:
+                case GGML_TYPE_Q5_1:
+                case GGML_TYPE_Q8_0:
+                case GGML_TYPE_Q2_K:
+                case GGML_TYPE_Q4_K:
+                case GGML_TYPE_Q5_K:
+                case GGML_TYPE_Q6_K:
+                case GGML_TYPE_IQ4_NL:
+                case GGML_TYPE_IQ4_XS:
+                    return 8;
+                default:
+                    return 1;
+            }
+        }
+        return 1;
+    }
+    if (table_id == MMVQ_PARAMETERS_RDNA3_0) {
+        // RDNA3 (W7900): stricter whitelist than RDNA4.
+        // Q2_K / Q5_K / IQ4_XS regress in full quant sweeps.
+        if (ncols_dst == 1) {
+            switch (type) {
+                case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q4_1:
+                case GGML_TYPE_Q5_0:
+                case GGML_TYPE_Q5_1:
+                case GGML_TYPE_Q8_0:
+                case GGML_TYPE_Q4_K:
+                case GGML_TYPE_Q6_K:
+                case GGML_TYPE_IQ4_NL:
+                    return 8;
+                default:
+                    return 1;
+            }
+        }
+        return 1;
+    }
     return 1;
 }
 
@@ -138,7 +194,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
 }
 
 template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
-__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
+__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
         const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
         const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@@ -151,7 +207,7 @@ static __global__ void mul_mat_vec_q(
     constexpr int qi  = ggml_cuda_type_traits<type>::qi;
     constexpr int vdr = get_vdr_mmvq(type);
     constexpr mmvq_parameter_table_id table_id = get_device_table_id();
-    constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
+    constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
     constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
@@ -355,12 +411,13 @@ static __global__ void mul_mat_vec_q(
     }
 }
 
+template<ggml_type type>
 static std::pair<dim3, dim3> calc_launch_params(
         const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
         const int warp_size, const mmvq_parameter_table_id table_id) {
     const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
     const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
-    const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
+    const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1);
     return {block_nums, block_dims};
 }
 
@@ -420,7 +477,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
     if (has_ids && ncols_dst > 1) {
         // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
         constexpr int c_ncols_dst = 1;
-        std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
+        std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
         mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
              channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
              sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -431,7 +488,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
     switch (ncols_dst) {
         case 1: {
             constexpr int c_ncols_dst = 1;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -439,7 +496,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
         } break;
         case 2: {
             constexpr int c_ncols_dst = 2;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -447,7 +504,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
         } break;
         case 3: {
             constexpr int c_ncols_dst = 3;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -455,7 +512,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
         } break;
         case 4: {
             constexpr int c_ncols_dst = 4;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -463,7 +520,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
         } break;
         case 5: {
             constexpr int c_ncols_dst = 5;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -471,7 +528,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
         } break;
         case 6: {
             constexpr int c_ncols_dst = 6;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -479,7 +536,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
         } break;
         case 7: {
             constexpr int c_ncols_dst = 7;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -487,7 +544,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
         } break;
         case 8: {
             constexpr int c_ncols_dst = 8;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
index 5cc1b54319c97dfdf8323793241520f77c422ef9..35d1e1a06398b42616a1791605813101ae745ece 100644 (file)
 #define RDNA3
 #endif // defined(__GFX11__)
 
+#if defined(__gfx1150__) || defined(__gfx1151__)
+#define RDNA3_5
+#endif // defined(__gfx1150__) || defined(__gfx1151__)
+
+#if defined(RDNA3) && !defined(RDNA3_5)
+#define RDNA3_0
+#endif // defined(RDNA3) && !defined(RDNA3_5)
+
 #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
     defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
 #define RDNA2