]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda, sycl : fix batched gemm when ne02 == 1 && ne03 > 1 (#15038)
authorGeorgi Gerganov <redacted>
Sat, 2 Aug 2025 14:13:05 +0000 (17:13 +0300)
committerGitHub <redacted>
Sat, 2 Aug 2025 14:13:05 +0000 (17:13 +0300)
* cuda, sycl : fix batched gemm when ne02 == 1 && ne03 > 1

ggml-ci

* cont : fix cont types

ggml-ci

* cont : adopt variable names and comment from the other branch

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-sycl/ggml-sycl.cpp

index 51792794673bbe0216c3522f18864ef7663d6d93..8885fb7fbdd2f03b55b8dc84d239182fbddd48a6 100644 (file)
@@ -1852,6 +1852,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
     ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
 
+    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
+
     // Handle src0
     src0_ptr = (const cuda_t *) src0->data;
 
@@ -1870,6 +1873,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
+
+        is_src1_cont_2 = true;
     }
 
     // Setup destination buffer
@@ -1918,15 +1923,19 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+    if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+        // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
+        const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+        const int64_t smb = ne12 == 1 ? s13       : s12;
+
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
-                       src1_ptr, cu_data_type_b, s11,       s12,       // strideB
-                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0,   // strideC
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma,     // strideA
+                       src1_ptr, cu_data_type_b, s11,       smb,     // strideB
+                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0, // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
index 2acdef98a6a0b42ed03f915599dc8dfa4ec6fba0..f68f1739a9fa894a60e4096378cd1aa1adfb4216 100644 (file)
@@ -2688,6 +2688,9 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     const size_t       type_size_src0 = ggml_type_size(src0->type);
     const size_t       type_size_src1 = ggml_type_size(src1->type);
 
+    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
+
     // SRC1 strides
     int64_t                          s11 = nb11 / type_size_src1;
     int64_t                          s12 = nb12 / type_size_src1;
@@ -2737,6 +2740,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
         s11      = ne10;
         s12      = ne11 * s11;
         s13      = ne12 * s12;
+
+        is_src1_cont_2 = true;
     }
 
     ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
@@ -2852,12 +2857,16 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     else
 #endif
     {
-        if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+        if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+            // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
+            const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+            const int64_t smb = ne12 == 1 ? s13       : s12;
+
             // there is no broadcast and src0, src1 are contiguous across dims 2, 3
             SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
                                                         oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-                                                        src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
-                                                        src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
+                                                        src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
+                                                        src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
                                                         mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
         } else {
             const int ne23 = ne12 * ne13;