]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fix MMQ stream-k rounding if ne00 % 128 != 0 (llama/8311)
authorJohannes Gäßler <redacted>
Fri, 5 Jul 2024 07:05:34 +0000 (09:05 +0200)
committerGeorgi Gerganov <redacted>
Mon, 8 Jul 2024 10:03:28 +0000 (13:03 +0300)
src/ggml-cuda/mmq.cuh

index deaed066f7c908c6e088e247068f0d55df28d84b..a97afc7ac80aa81b56a525811e7c7152661ccaf7 100644 (file)
@@ -2305,8 +2305,11 @@ static __global__ void mul_mat_q(
     const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
 
     // kbc == k block continuous, current index in continuous ijk space.
-    int64_t       kbc      = GGML_PAD((int64_t) blockIdx.x     *blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
-    const int64_t kbc_stop = GGML_PAD((int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
+    int64_t kbc      = (int64_t) blockIdx.x     *blocks_per_ne00*ntx*nty / gridDim.x;
+    int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;
+
+    kbc      -= (kbc      % blocks_per_ne00) % blocks_per_warp;
+    kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp;
 
     // kb0 == k index when doing the matrix multiplication for an output tile.
     int kb0_start = kbc % blocks_per_ne00;
@@ -2362,8 +2365,11 @@ static __global__ void mul_mat_q_stream_k_fixup(
     const int bidx_stop  = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;
 
     for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
-        const int64_t kbc      = GGML_PAD((int64_t) bidx     *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
-        const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
+        int64_t kbc      = (int64_t) bidx     *blocks_per_ne00*ntx*nty / block_num_mmq;
+        int64_t kbc_stop = (int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
+
+        kbc      -= (kbc      % blocks_per_ne00) % blocks_per_warp;
+        kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp;
 
         // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
         if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {