]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fix MMQ nwarps for AMD with warp_size==32 (llama/15014)
authorJohannes Gäßler <redacted>
Fri, 1 Aug 2025 18:47:32 +0000 (20:47 +0200)
committerGeorgi Gerganov <redacted>
Sat, 2 Aug 2025 14:51:21 +0000 (17:51 +0300)
src/ggml-cuda/mmq.cuh

index dd60529faa41d58dab99800ae0b19b6fc85fa047..04a8d80e1211b1d22ca1f356cab40bb21b635ee2 100644 (file)
@@ -251,25 +251,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)
 #endif // AMD_MFMA_AVAILABLE
 
 #if defined(GGML_USE_HIP)
-static int mmq_get_nwarps_host(const int cc) {
-    return amd_mfma_available(cc) ? 8 : 4;
+static int mmq_get_nwarps_host(const int cc, const int warp_size) {
+    return amd_mfma_available(cc) ? 8 : 256/warp_size;
 }
 #else
-static int mmq_get_nwarps_host(const int /*cc*/) {
-    return 8;
+static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
+    return 256/warp_size;
 }
 #endif // (GGML_USE_HIP)
 
 static constexpr __device__ int mmq_get_nwarps_device() {
-#if defined(GGML_USE_HIP)
 #if defined(AMD_MFMA_AVAILABLE)
     return 8;
 #else
-    return 4;
+    return 256/ggml_cuda_get_physical_warp_size();
 #endif // AMD_MFMA_AVAILABLE
-#else
-    return 8;
-#endif // defined(GGML_USE_HIP)
 }
 
 // ------------------------------------------------------------
@@ -3472,7 +3468,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
     const int cc = ggml_cuda_info().devices[id].cc;
     const int nsm = ggml_cuda_info().devices[id].nsm;
     const int warp_size = ggml_cuda_info().devices[id].warp_size;
-    const int nwarps = mmq_get_nwarps_host(cc);
+    const int nwarps = mmq_get_nwarps_host(cc, warp_size);
     const int mmq_y = get_mmq_y_host(cc);
 
     const dim3 block_dims(warp_size, nwarps, 1);
@@ -3559,7 +3555,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
     const int    cc     = ggml_cuda_info().devices[id].cc;
     const size_t smpbo  = ggml_cuda_info().devices[id].smpbo;
     const int warp_size = ggml_cuda_info().devices[id].warp_size;
-    const int nwarps    = mmq_get_nwarps_host(cc);
+    const int nwarps    = mmq_get_nwarps_host(cc, warp_size);
 
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);