ggml: fix cuda kernel launch configuration for k_compute_batched_ptrs to support...
author    leejet <redacted>
Sun, 26 Oct 2025 18:13:31 +0000 (02:13 +0800)
committer GitHub <redacted>
Sun, 26 Oct 2025 18:13:31 +0000 (19:13 +0100)
* fix k_compute_batched_ptrs

* add backend ops test

* Update ggml/src/ggml-cuda/ggml-cuda.cu

Co-authored-by: Johannes Gäßler <redacted>
* reduce the batch size

---------

Co-authored-by: Johannes Gäßler <redacted>
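
Background for the change below: the previous code launched k_compute_batched_ptrs with a single thread block of dim3(ne13, ne12), which fails once the combined batch size exceeds CUDA's per-block limit of 1024 threads. The fix switches to fixed 16x16 blocks and rounds the grid up to cover the whole ne13 x ne12 range. A minimal sketch of the same pattern follows; the kernel name fill_batch_ptrs, its body, and the launch wrapper are illustrative stand-ins, not the actual ggml-cuda code:

    #include <cstdint>
    #include <cuda_runtime.h>

    // Illustrative stand-in: each thread handles one (i13, i12) batch index.
    // The real k_compute_batched_ptrs fills per-batch pointer arrays for batched cuBLAS GEMM.
    __global__ void fill_batch_ptrs(int64_t ne12, int64_t ne13 /*, ... */) {
        const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
        const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
        if (i13 >= ne13 || i12 >= ne12) {
            return; // the grid is rounded up, so out-of-range threads must exit
        }
        // ... compute src/dst pointers for batch (i12, i13) ...
    }

    void launch_fill_batch_ptrs(int64_t ne12, int64_t ne13, cudaStream_t stream) {
        const int threads_x = 16;
        const int threads_y = 16;
        dim3 block_dims(threads_x, threads_y);                 // 256 threads per block, independent of batch size
        dim3 grid_dims((ne13 + threads_x - 1) / threads_x,     // ceil(ne13 / 16)
                       (ne12 + threads_y - 1) / threads_y);    // ceil(ne12 / 16)
        fill_batch_ptrs<<<grid_dims, block_dims, 0, stream>>>(ne12, ne13);
    }
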
ggml/src/ggml-cuda/ggml-cuda.cu
tests/test-backend-ops.cpp

index 19f72975c0ee46605ca96e4a9d07c1184a1bb6f4..6b688bfecdedd1bdde8680ce4ecb3adf0db82374 100644 (file)
@@ -1957,8 +1957,15 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
 
         size_t src1_stride_size = sizeof(cuda_t);
 
-        dim3 block_dims(ne13, ne12);
-        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+        const int threads_x = 16;
+        const int threads_y = 16;
+        dim3 block_dims(threads_x, threads_y);
+
+        dim3 grid_dims(
+            (ne13 + threads_x - 1) / threads_x,
+            (ne12 + threads_y - 1) / threads_y
+        );
+        k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
                 src0_ptr, src1_ptr, dst_t,
                 ptrs_src.get(), ptrs_dst.get(),
                 ne12, ne13,
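
To see the limit concretely, here is a small standalone check of the old versus new launch shapes (plain C++; the numbers are chosen to match the new backend-ops test below, assuming its batch dims map to ne12 = 1536, ne13 = 1, and 1024 is the CUDA per-block thread maximum):

    #include <cstdio>

    int main() {
        const int ne12 = 1536, ne13 = 1;          // large batch, as in the new test case
        const int threads_x = 16, threads_y = 16;

        // Old launch: a single block of (ne13, ne12) threads.
        std::printf("old: 1 block of %d threads (per-block limit is 1024)\n", ne13 * ne12);

        // New launch: fixed 16x16 blocks, grid rounded up to cover the batch.
        const int grid_x = (ne13 + threads_x - 1) / threads_x;
        const int grid_y = (ne12 + threads_y - 1) / threads_y;
        std::printf("new: %d x %d blocks of %d threads each\n",
                    grid_x, grid_y, threads_x * threads_y);
        return 0;
    }

For this case the old configuration requests 1536 threads in one block and the launch fails, while the new configuration uses 1 x 96 blocks of 256 threads.
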
index 33ac27ff5ca00df96147686516608e8e4e5e1924..0ad73e944edaccde534a5299d77fcdd27f7784cb 100644 (file)
@@ -6697,6 +6697,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 1024, {3, 2}, {1, 1}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 1024, {3, 2}, {1, 1}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1}));
+
+            // test cases with large batch size
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {1536, 1}, {1, 1}));
         }
     }
     for (ggml_type type_a : other_types) {
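
Note on the added test: the batch dims {1536, 1} with repeat ratios {1, 1} presumably give src1 a batch range of ne12 = 1536, ne13 = 1, which the old single-block launch could not cover; the "reduce the batch size" step in the commit message likely refers to keeping this regression test small enough to run quickly while still exceeding the 1024-thread limit.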