]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
use max work group size for device to replace the magic number (#14732)
authorNeo Zhang Jianyu <redacted>
Fri, 18 Jul 2025 02:23:14 +0000 (10:23 +0800)
committerGitHub <redacted>
Fri, 18 Jul 2025 02:23:14 +0000 (10:23 +0800)
ggml/src/ggml-sycl/ggml-sycl.cpp

index a6f9af0c86e11611043038f804dd697d8b9311a9..872eb4b052db9dafc2bcb1db23cd158d6ce5ba44 100644 (file)
@@ -3530,8 +3530,11 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
             SYCL_CHECK(CHECK_TRY_ERROR(
                 stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
 
+            const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
+            assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
+
             {
-                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
+                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
                 sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
                 sycl_launch(stream, [&](sycl::handler & cgh) {
                     sycl::local_accessor<int, 0> src1_row_acc(cgh);
@@ -3575,7 +3578,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
             ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
 
             {
-                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
+                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
                 sycl::range<3> grid_dims(1, 1, num_src1_rows);
                 sycl_launch(stream, [&](sycl::handler & cgh) {
                     const char *__restrict dst_contiguous_get =