]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
HIP: Patch failed testcase in WMMA-MMQ kernels for RDNA 4 (llama/17502)
authorJiacheng (Jason) Chen <redacted>
Wed, 26 Nov 2025 10:18:48 +0000 (05:18 -0500)
committerGeorgi Gerganov <redacted>
Thu, 11 Dec 2025 13:32:45 +0000 (15:32 +0200)
* patch failed test case MUL_MAT(type_a=q4_0,type_b=f32,m=576,n=512,k=576,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) for enabling WMMA on RDNA4

* Quick clean up on mma.cuh to add ggml_cuda_memcpy_1 back in for half2 and bfloat162

src/ggml-cuda/mma.cuh
src/ggml-cuda/mmq.cuh

index caa08b360b566a71d2b900e618923cf02459ec5b..c0a9c2c08a9ed3dac0f4a9dce22b9d3bf533fcc9 100644 (file)
@@ -437,18 +437,27 @@ namespace ggml_cuda_mma {
             xi[0] = xs[0];
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-        if constexpr (I == 16 && J == 4) {
-            int64_t * xi = (int64_t *) t.x;
-            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-            xi[0] = xs[0];
-        }else if constexpr (I == 16 && J == 8) {
-            int64_t * xi = (int64_t *) t.x;
-            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
-            xi[0] = xs[0];
+        if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
+
+        } else if constexpr (std::is_same_v<T, int>) {
+            if constexpr (I == 16 && J == 4) {
+                int64_t * xi = (int64_t *) t.x;
+                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
+                xi[0] = xs[0];
 
-            const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
-            xi[1] = xs1[0];
-        }else{
+            }else if constexpr (I == 16 && J == 8) {
+                int64_t * xi = (int64_t *) t.x;
+                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
+                xi[0] = xs[0];
+
+                const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
+                xi[1] = xs1[0];
+
+            }else{
+                NO_DEVICE_CODE;
+            }
+        } else {
             NO_DEVICE_CODE;
         }
 #else
index 99760d56c7294a03c224536ac1ec105ccb6f9fa9..82468b384e26face836241610686d80f38888b24 100644 (file)
@@ -3701,7 +3701,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
     const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
     const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
     const size_t nbs_ids = mmq_x*sizeof(int);
-    const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+    const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
     const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
     return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
 }