]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sycl: addressing non-contiguous src1 mul_mats (nc and batched) (#13343)
authorAlberto Cabrera Pérez <redacted>
Thu, 8 May 2025 09:08:01 +0000 (10:08 +0100)
committerGitHub <redacted>
Thu, 8 May 2025 09:08:01 +0000 (10:08 +0100)
* sycl: fixed non-contiguous src1 mul_mats (nc and batched)

* Fixed wrong static_cast inside kernel

ggml/src/ggml-sycl/common.hpp
ggml/src/ggml-sycl/convert.cpp
ggml/src/ggml-sycl/convert.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp

index c71cc89c09eac55786590576e649acaa4b136653..69aad938e88dad442c89699db54b3a65023feca4 100644 (file)
@@ -114,17 +114,12 @@ static void crash() {
   GGML_ABORT("SYCL error");
 }
 
-#define SYCL_CHECK(err)                     \
-  do {                                      \
-    auto err_ = (err);                      \
-    if (err_ != 0)                          \
-      ggml_sycl_error(                      \
-          #err,                             \
-          __func__,                         \
-          __FILE__,                         \
-          __LINE__,                         \
-          "Meet error in this line code!"); \
-  } while (0)
+#define SYCL_CHECK(err)                                                                                    \
+    do {                                                                                                   \
+        auto err_ = (err);                                                                                 \
+        if (err_ != 0)                                                                                     \
+            ggml_sycl_error(#err, __func__, __FILE__, __LINE__, "Exception caught in this line of code."); \
+    } while (0)
 
 #if DPCT_COMPAT_RT_VERSION >= 11100
 #define GGML_SYCL_ASSUME(x) __builtin_assume(x)
index 76ac6a4dd1f7bc0c06ec3afb0fbe1e9cae59966f..b2f8a65693363a96aa6fb755deb0fcee03624eb1 100644 (file)
@@ -437,41 +437,52 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
 }
 
 template <typename src_t, typename dst_t>
-static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
-                          const sycl::nd_item<3> &item_ct1) {
+static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
+                          const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
+                          const sycl::nd_item<3> & item_ct1) {
+
     const int64_t work_group_size = item_ct1.get_local_range(2);
-    const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
+    const int64_t global_id       = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
+
+    const int64_t i01 = item_ct1.get_group(1);
+    const int64_t i02 = item_ct1.get_group(0) % ne02;
+    const int64_t i03 = item_ct1.get_group(0) / ne02;
 
     // make each work-item deal with more elements since sycl global range can not exceed max int
-    const src_t * x = (const src_t *) vx;
-    for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
-        y[i] = x[i];
+    const src_t * x = static_cast<const src_t *>(vx);
+    const int64_t ix = i03 * s03 + i02 * s02 + i01 * s01;
+    const int64_t iy = ((i03 * ne02 + i02) * ne01 + i01) * ne00;
+
+#pragma unroll
+    for (int64_t i00 = global_id; i00 < ne00; i00 += work_group_size * item_ct1.get_group_range(2)) {
+        y[iy + i00] = static_cast<dst_t>(x[ix + i00]);
     }
 }
 
 template <typename src_t, typename dst_t>
-static void convert_unary_sycl(const void *__restrict__ vx,
-                               dst_t *__restrict__ y, const int64_t k,
-                               dpct::queue_ptr stream) {
-    const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
+static void convert_unary_nc_sycl(const void * __restrict__ vx, dst_t * __restrict__ y,
+                                  const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+                                  const int64_t s01, const int64_t s02, const int64_t s03, dpct::queue_ptr queue) {
+    dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
+
+    sycl::range<3> global_size(ne02 * ne03, ne01, ceil_div(ne00, SYCL_DEQUANTIZE_BLOCK_SIZE));
 
     // decrease global range when it exceeds the max int
-    int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
-    sycl::range<3> block_nums(1, 1, num_blocks);
-    sycl::range<3> local_range(1, 1, local_size);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
+    // TODO: Downsample logic is separated from the kernel, a rewrite is desirable
+    int64_t        downsized_workgroup = downsample_sycl_global_range(global_size[0], SYCL_DEQUANTIZE_BLOCK_SIZE);
+    sycl::range<3> workgroup_size(1, 1, downsized_workgroup);
 
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * local_range, local_range),
-            [=](sycl::nd_item<3> item_ct1) {
-                convert_unary<src_t>(vx, y, k, item_ct1);
-            });
-    }
+    queue->parallel_for(sycl::nd_range<3>(global_size * workgroup_size, workgroup_size), [=](sycl::nd_item<3> item_ct1) {
+        convert_unary_nc<src_t>(vx, y, ne00, ne01, ne02, s01, s02, s03, item_ct1);
+    });
 }
 
-to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) {
+template <typename src_t, typename dst_t>
+static void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr queue) {
+    convert_unary_nc_sycl<src_t>(vx, y, k, 1, 1, 1, k, k, k, queue);
+}
+
+to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
     switch (type) {
         case GGML_TYPE_Q4_0:
             if (dst->src[0]->extra &&
@@ -574,3 +585,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
             return nullptr;
     }
 }
+
+to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_nc_sycl<float>;
+        default:
+            return nullptr;
+    }
+}
index 355dae22b40758d6695dcd5519d8ad640e83d7d0..f8cb573e3688bc470aecac9f5bbddf232a1c028b 100644 (file)
@@ -1,6 +1,6 @@
 //
 // MIT license
-// Copyright (C) 2024 Intel Corporation
+// Copyright (C) 2025 Intel Corporation
 // SPDX-License-Identifier: MIT
 //
 
 #include "common.hpp"
 
 template <typename T>
-using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
-                             int64_t k, dpct::queue_ptr stream);
-typedef to_t_sycl_t<float> to_fp32_sycl_t;
+using to_t_sycl_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, dpct::queue_ptr stream);
+typedef to_t_sycl_t<float>      to_fp32_sycl_t;
 typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
 
-to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst);
-to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst);
+to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst);
+to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst);
 
-#endif // GGML_SYCL_CONVERT_HPP
+// Nc = Non-contiguous
+template <typename T>
+using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
+                                   int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue);
+
+typedef to_t_nc_sycl_t<sycl::half> to_fp16_nc_sycl_t;
+to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type);
+
+#endif  // GGML_SYCL_CONVERT_HPP
index ea5d10f40ee38904d6c61b22c1ba14ab21ab6229..68a26fa481ddb53f005edd5811b0eda99bb5dded 100644 (file)
@@ -2694,35 +2694,31 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
-                                   const sycl::half *src1_as_f16, char *dst,
-                                   const void **ptrs_src, void **ptrs_dst,
-                                   int64_t ne12, int64_t ne13, int64_t ne23,
-                                   size_t nb02, size_t nb03, size_t nb12,
-                                   size_t nb13, size_t nbd2, size_t nbd3,
-                                   int64_t r2, int64_t r3,
-                                   const sycl::nd_item<3> &item_ct1) {
-    int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                  item_ct1.get_local_id(2);
-    int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
-                  item_ct1.get_local_id(1);
+static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
+                                   const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
+                                   size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
+                                   int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
+    const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
+    const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
 
     if (i13 >= ne13 || i12 >= ne12) {
         return;
     }
 
-    int64_t i03 = i13 / r3;
-    int64_t i02 = i12 / r2;
+    const int64_t i03 = i13 / r3;
+    const int64_t i02 = i12 / r2;
+
+    const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
+    const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
+    uint8_t *       dst_bytes  = reinterpret_cast<uint8_t *>(dst);
 
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
+    ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
+    ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
+    ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
 }
 
-static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
-                                             const ggml_tensor *src0,
-                                             const ggml_tensor *src1,
-                                             ggml_tensor *dst) try {
+static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
+                                           const ggml_tensor * src1, ggml_tensor * dst) try {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
@@ -2730,102 +2726,100 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
+    // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
+    // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
+    GGML_ASSERT(ggml_is_contiguous(dst));
 
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();;
+    queue_ptr queue = ctx.stream();
 
-    void * src0_ddq = src0->data;
-    sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf = (float *) dst->data;
+    dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
 
-    // convert src1 to fp16
+    const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
+    float *            dst_ddf  = static_cast<float *>(dst->data);
+
+    const sycl::half * src1_f16       = static_cast<const sycl::half *>(src1->data);
+    const size_t       type_size_src1 = ggml_type_size(src1->type);
+    GGML_ASSERT(nb10 == type_size_src1);
+
+    // SRC1 strides
+    int64_t                          s11 = nb11 / type_size_src1;
+    int64_t                          s12 = nb12 / type_size_src1;
+    int64_t                          s13 = nb13 / type_size_src1;
     ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
+
+    // convert src1 to fp16
     if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
+        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
+        GGML_ASSERT(to_fp16_nc_sycl != nullptr);
         const int64_t ne_src1 = ggml_nelements(src1);
         src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_sycl != nullptr);
-        to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+        to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+
+        src1_f16 = src1_f16_alloc.get();
+        s11      = ne10;
+        s12      = ne11 * s11;
+        s13      = ne12 * s12;
     }
-    sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
-                                                       : src1_f16_alloc.get();
 
-    char * dst_t;
+    ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
+    char *                           dst_t = reinterpret_cast<char *>(dst_ddf);
 
-    dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
-    dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
+    dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
+    dpct::library_data_t mkl_data_type    = dpct::library_data_t::real_float;
 
     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
 
     const float alpha_f32 = 1.0f;
-    const float beta_f32 = 0.0f;
+    const float beta_f32  = 0.0f;
 
     const void * alpha = &alpha_f32;
     const void * beta  = &beta_f32;
 
-    dst_t = (char *) dst_ddf;
-
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
 
     // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
 
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
-            (const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
-            cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
+        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
+                                                    oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                                                    src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
+                                                    src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
+                                                    mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
     } else {
-        const int ne23 = ne12*ne13;
+        const int ne23 = ne12 * ne13;
 
-        ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
-        ggml_sycl_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+        ggml_sycl_pool_alloc<const void *>         ptrs_src(ctx.pool(), 2 * ne23);
+        ggml_sycl_pool_alloc<void *>               ptrs_dst(ctx.pool(), 1 * ne23);
         ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
 
         sycl::range<3> block_dims(1, ne12, ne13);
-        /*
-        DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(main_stream->get_device(),
-                                         {sycl::aspect::fp16});
-
-            main_stream->submit([&](sycl::handler &cgh) {
-                const void **ptrs_src_get = ptrs_src.get();
-                void **ptrs_dst_get = ptrs_dst.get();
-                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
-                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
-                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
-                                 [=](sycl::nd_item<3> item_ct1) {
-                                     k_compute_batched_ptrs(
-                                         src0_as_f16, src1_f16,
-                                         dst_t, ptrs_src_get,
-                                         ptrs_dst_get, ne12, ne13, ne23,
-                                         nb02, nb03, nb12_scaled, nb13_scaled,
-                                         nbd2, nbd3, r2, r3, item_ct1);
-                                 });
+        queue->submit([&](sycl::handler & cgh) {
+            const void ** ptrs_src_get = ptrs_src.get();
+            void **       ptrs_dst_get = ptrs_dst.get();
+            size_t        nb12_scaled  = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
+            size_t        nb13_scaled  = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
+            cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+                                       nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
             });
-        }
+        });
+
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+            *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
             (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
-            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
-            (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
+            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
+            (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
     }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
+} catch (const sycl::exception & exc) {
+    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
 }
 
 inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
@@ -2966,7 +2960,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             // The kernel from the if path is faster for that specific case, but does not support all mul mats.
             ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
         }
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
@@ -3873,9 +3867,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 if (a->ne[3] != b->ne[3]) {
                     return false;
                 }
-                if (!ggml_is_contiguous(b)) {
-                    return false;
-                }
                 ggml_type a_type = a->type;
                 if (a_type == GGML_TYPE_IQ4_NL  || a_type == GGML_TYPE_IQ4_XS ||
                     a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S  ||