]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Fix DMMV dequantization (llama/9279)
authorOuadie EL FAROUKI <redacted>
Wed, 4 Sep 2024 15:26:33 +0000 (16:26 +0100)
committerGeorgi Gerganov <redacted>
Sun, 8 Sep 2024 11:43:07 +0000 (14:43 +0300)
Fixed dmmv dequant for ncols== GGML_SYCL_DMMV_X

src/ggml-sycl/dmmv.cpp

index 5c343822f390f7fcb5e334ee36be21a33bfe2995..0c3dfaa37eb02df2f5bb942ad194d2533abaa331 100644 (file)
@@ -76,8 +76,8 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
     }
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
+    for (int mask = mask_start; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }