]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
HIP: Use mmq on MFMA devices for MUL_MAT_ID in cases where a lot of splits would...
authoruvos <redacted>
Sun, 28 Dec 2025 19:12:55 +0000 (20:12 +0100)
committerGeorgi Gerganov <redacted>
Wed, 31 Dec 2025 10:39:43 +0000 (12:39 +0200)
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/mmq.cu
src/ggml-cuda/mmq.cuh

index 55fa2e6a7cc3473f2172ef2bae13ca9536c95af0..40ffe92c575b7cb1918b60b1ae49a0ace1a24592 100644 (file)
@@ -2211,7 +2211,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
 
             const int cc            = ggml_cuda_info().devices[id].cc;
             const int warp_size     = ggml_cuda_info().devices[id].warp_size;
-            use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
             use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
             use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
@@ -2219,7 +2219,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     } else {
         const int cc            = ggml_cuda_info().devices[ctx.device].cc;
         const int warp_size     = ggml_cuda_info().devices[ctx.device].warp_size;
-        use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
         use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
         use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
@@ -2287,7 +2287,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             return;
         }
 
-        if (ggml_cuda_should_use_mmq(src0->type, cc, ne12)) {
+        if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
             ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
             return;
         }
index 6156dcdae74ac4fd5070ea2f9b17c591724578a8..85692d454300011b8794d5da082247ff2f514b28 100644 (file)
@@ -259,7 +259,7 @@ void ggml_cuda_op_mul_mat_q(
     GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_padded_row_size);
 }
 
-bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts) {
 #ifdef GGML_CUDA_FORCE_CUBLAS
     return false;
 #endif // GGML_CUDA_FORCE_CUBLAS
@@ -320,7 +320,10 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         if (GGML_CUDA_CC_IS_CDNA3(cc)) {
             return true;
         }
-        if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
+        if (n_experts > 64 || ne11 <= 128) {
+            return true;
+        }
+        if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
             return true;
         }
         if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
index 63451ffab7f32678d0ed079a8997e54aff321871..a382e6a6979f4da0c7681370171158eafb1cc92a 100644 (file)
@@ -4082,4 +4082,4 @@ void ggml_cuda_op_mul_mat_q(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
 
-bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);