From: Ouadie EL FAROUKI Date: Wed, 4 Sep 2024 15:26:33 +0000 (+0100) Subject: [SYCL] Fix DMMV dequantization (#9279) X-Git-Tag: upstream/0.0.4488~821 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=5910ea942772ab6cbc21d0ad2d1208750ba39e1d;p=pkg%2Fggml%2Fsources%2Fllama.cpp [SYCL] Fix DMMV dequantization (#9279) Fixed dmmv dequant for ncols== GGML_SYCL_DMMV_X --- diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 5c343822..0c3dfaa3 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -76,8 +76,8 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * } // sum up partial sums and write back result -#pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2; + for (int mask = mask_start; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); }