]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Minor arithmetic improvement to mmvq wrapper kernel (llama/7172)
authorOuadie EL FAROUKI <redacted>
Fri, 10 May 2024 00:32:15 +0000 (01:32 +0100)
committerGeorgi Gerganov <redacted>
Mon, 13 May 2024 08:02:26 +0000 (11:02 +0300)
ggml-sycl.cpp

index 57fe4ea3d4ac25d2cf65511bffcc03b6c6f58901..79aec4d9f02e1ba48bfdf775c3acf9bfdaaac1a1 100644 (file)
@@ -8330,24 +8330,26 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
     const int blocks_per_row = ncols / qk;
     const int blocks_per_warp = vdr * WARP_SIZE / qi;
 
-// partial sum for each thread
+    const int qi_vdr = (qi / vdr); // N_threads processing 1 qk block
+
+    // partial sum for each thread
     float tmp = 0.0f;
 
     const block_q_t  * x = (const block_q_t  *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+    for (int i = item_ct1.get_local_id(2) / qi_vdr; i < blocks_per_row;
          i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i; // x block index
+      const int ibx = row * blocks_per_row + i; // x block index
 
-        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+      const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
 
-        const int iqs =
-            vdr *
-            (item_ct1.get_local_id(2) %
-             (qi / vdr)); // x block quant index when casting the quants to int
+      const int iqs =
+          vdr *
+          (item_ct1.get_local_id(2) -
+           i * qi_vdr); // x block quant index when casting the quants to int
 
-        tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
+      tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
     }
 
     // sum up partial sums and write back result