]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
[SYCL] fix scratch size of softmax (#8642)
authorluoyu-intel <redacted>
Tue, 23 Jul 2024 07:43:28 +0000 (07:43 +0000)
committerGitHub <redacted>
Tue, 23 Jul 2024 07:43:28 +0000 (15:43 +0800)
ggml/src/ggml-sycl/softmax.cpp

index c5d9a837eb79412b8090bd571871a39a7146b3ec..17a542e49036278c0e650ba6661ae96af2deef12 100644 (file)
@@ -152,7 +152,8 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
 
     const sycl::range<3> block_dims(1, 1, nth);
     const sycl::range<3> block_nums(1, 1, nrows_x);
-    const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
+    const size_t n_val_tmp = nth / WARP_SIZE;
+    const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp);
 
     const uint32_t n_head_kv   = nrows_x/nrows_y;
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));