]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
sycl: use oneDNN for matrices multiplication (llama/12972)
authorŁukasz Ślusarczyk <redacted>
Thu, 15 May 2025 14:53:41 +0000 (16:53 +0200)
committerGeorgi Gerganov <redacted>
Mon, 19 May 2025 11:58:39 +0000 (14:58 +0300)
ggml/CMakeLists.txt
ggml/src/ggml-sycl/CMakeLists.txt
ggml/src/ggml-sycl/gemm.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp

index a8300e16d87fe19e72d4d6fe8d8de00c06baa316..4746d5cb76c08efbc06aaeed6f593ddc76b4e4bb 100644 (file)
@@ -193,6 +193,7 @@ option(GGML_RPC                             "ggml: use RPC"
 option(GGML_SYCL                            "ggml: use SYCL"                                  OFF)
 option(GGML_SYCL_F16                        "ggml: use 16 bit floats for sycl calculations"   OFF)
 option(GGML_SYCL_GRAPH                      "ggml: enable graphs in the SYCL backend"         ON)
+option(GGML_SYCL_DNN                        "ggml: enable oneDNN in the SYCL backend"         ON)
 set   (GGML_SYCL_TARGET "INTEL" CACHE STRING
                                             "ggml: sycl target device")
 set   (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
index 231fb71dab5dab27c320df600b2318db3e918d52..a2e26124802b21391be2ac468ce2d31361c98060 100644 (file)
@@ -49,34 +49,38 @@ endif()
 target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
 
 # Link against oneDNN
-find_package(DNNL)
 set(GGML_SYCL_DNNL 0)
-if(DNNL_FOUND)
-    if (NOT DEFINED DNNL_GPU_VENDOR)
-        # default to intel target
-        set(DNNL_GPU_VENDOR "INTEL")
-        if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
-            message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
+if(GGML_SYCL_DNN)
+    find_package(DNNL)
+    if(DNNL_FOUND)
+        if (NOT DEFINED DNNL_GPU_VENDOR)
+            # default to intel target
+            set(DNNL_GPU_VENDOR "INTEL")
+            if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
+                message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
+            endif()
         endif()
-    endif()
 
-    # Verify oneDNN was compiled for the same target as llama
-    if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
-        target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
-        set(GGML_SYCL_DNNL 1)
-        get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
-        foreach(CONFIG ${CONFIGS})
-            get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
-            message(STATUS "Found oneDNN: ${DNNL_LIB}")
-        endforeach()
+        # Verify oneDNN was compiled for the same target as llama
+        if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
+            target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
+            set(GGML_SYCL_DNNL 1)
+            get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
+            foreach(CONFIG ${CONFIGS})
+                get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
+                message(STATUS "Found oneDNN: ${DNNL_LIB}")
+            endforeach()
+        else()
+            message(WARNING
+                "oneDNN must be compiled for the same target as llama.cpp.
+                 llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
+                 Disabling oneDNN support.")
+        endif()
     else()
-        message(WARNING
-            "oneDNN must be compiled for the same target as llama.cpp.
-             llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
-             Disabling oneDNN support.")
+        message(STATUS "oneDNN not found, disabling oneDNN support")
     endif()
 else()
-    message(STATUS "oneDNN not found, disabling oneDNN support")
+    message(STATUS "oneDNN support disabled by the user")
 endif()
 target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
 
index 4ebbb5b66fb47a4134d0f56380caa281ac2489ed..6cbc7e0f6938cc8767444944c39f2748b3b0c7fd 100644 (file)
@@ -32,16 +32,36 @@ public:
         else static_assert(0);
     }
 
-    static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
-                                const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+    // matrix A has m rows, k columns
+    // matrix B has k rows, n columns
+    // nra - number of elements to skip when moving into next row in A
+    // nrb - number of elements to skip when moving into next row in B
+    // nca - number of elements to skip when moving into next column in A
+    // ncb - number of elements to skip when moving into next column in B
+    // stride_a - number of elements to skip when moving to next A matrix
+    // stride_b - number of elements to skip when moving to next B matrix
+    // batches_a - number of A matrices
+    // batches_b - number of B matrices
+    static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+        const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
+        const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
+        void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
+
         auto stream = ctx.stream_dnnl(q);
         auto eng = ctx.engine_dnnl(q);
-        dnnl::memory::dims a_dims = { m, k };
-        dnnl::memory::dims b_dims = { k, n };
-        dnnl::memory::dims c_dims = { m, n };
-        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
-        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
-        const auto c_md    = dnnl::memory::desc(c_dims, ct, tag::ab);
+
+        // { # strides, # rows, # columns }
+        dnnl::memory::dims a_dims = { batches_a, m, k };
+        dnnl::memory::dims b_dims = { batches_b, k, n };
+        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
+
+        // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
+        dnnl::memory::dims a_strides = { stride_a, nra, nca };
+        dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
+
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
+        const auto c_md    = dnnl::memory::desc(c_dims, ct, tag::abc);
 
         dnnl::primitive_attr primitive_attr;
         primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
@@ -63,6 +83,15 @@ public:
 
         matmul_prim.execute(stream, matmul_args);
     }
+
+    // matrices A and B are column major, both having k rows
+    // matrix A has m column, matrix B has n columns
+    // output: column major matrix C = A transposed * B
+    static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+        const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+
+        gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
+    }
 };
 
 #endif
index 0ea729948ec7a952a506f0df9da1119661918e5f..1205fce0e7c714585e7ae0f313f8139fd595f936 100644 (file)
@@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
 int g_ggml_sycl_debug = 0;
 int g_ggml_sycl_disable_optimize = 0;
 int g_ggml_sycl_disable_graph = 0;
+int g_ggml_sycl_disable_dnn = 0;
 int g_ggml_sycl_prioritize_dmmv = 0;
 
 static ggml_sycl_device_info ggml_sycl_init() {
@@ -196,12 +197,22 @@ static void ggml_check_sycl() try {
         g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
         g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
         g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+        g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
         g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
         GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
         GGML_LOG_INFO("Running with Environment Variables:\n");
         GGML_LOG_INFO("  GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
         GGML_LOG_INFO("  GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+#ifdef GGML_SYCL_GRAPH
         GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+#else
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
+#endif
+#if GGML_SYCL_DNNL
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
+#else
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
+#endif
         GGML_LOG_INFO("  GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
         GGML_LOG_INFO("Build with Macros:\n");
 #if defined(GGML_SYCL_FORCE_MMQ)
@@ -1985,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne10 = src1->ne[0];
-
+    GGML_ASSERT(ne00 == ne10);
 
     const int64_t row_diff = row_high - row_low;
 
     int id;
     SYCL_CHECK(
         CHECK_TRY_ERROR(id = get_current_device_id()));
-#if !GGML_SYCL_DNNL
-    const int64_t ne0 = dst->ne[0];
+
+    const int64_t ne0 = dst->ne[0]; // used by MKL only
     // the main device has a larger memory buffer to hold the results from all GPUs
     // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = id == ctx.device ? ne0 : row_diff;
-#endif
+    int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
 
 #ifdef GGML_SYCL_F16
     bool use_fp16 = true;  // TODO(Yu) SYCL capability check
@@ -2033,25 +2043,29 @@ inline void ggml_sycl_op_mul_mat_sycl(
                                          : src1_as_f16.get();
         ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
 
-#if !GGML_SYCL_DNNL
-        const sycl::half alpha_f16 = 1.0f;
-        const sycl::half beta_f16  = 0.0f;
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
-            *stream, oneapi::math::transpose::trans,
-            oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
-            &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
-            src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
-            dst_f16.get(), dpct::library_data_t::real_half, ldc,
-            dpct::library_data_t::real_half)));
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
-                                  DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-                                  dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
+                                      DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+                                      dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+        }
+        else
 #endif
+        {
+            const sycl::half alpha_f16 = 1.0f;
+            const sycl::half beta_f16  = 0.0f;
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
+                *stream, oneapi::math::transpose::trans,
+                oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
+                &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
+                src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
+                dst_f16.get(), dpct::library_data_t::real_half, ldc,
+                dpct::library_data_t::real_half)));
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+        }
     }
     else {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2072,18 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
         const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
 
-#if !GGML_SYCL_DNNL
-        const float alpha = 1.0f;
-        const float beta  = 0.0f;
-        SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
-            get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
-            src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
-            dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
-                                  DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
-                                  dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
+                                      DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+                                      dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+        }
+        else
 #endif
+        {
+            const float alpha = 1.0f;
+            const float beta  = 0.0f;
+            SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
+                get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
+                src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
+                dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
+        }
     }
     GGML_UNUSED(dst);
     GGML_UNUSED(src1_ddq_i);
@@ -2697,7 +2715,7 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
+static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
                                    const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
                                    size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
                                    int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
@@ -2713,7 +2731,7 @@ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::h
 
     const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
     const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
-    uint8_t *       dst_bytes  = reinterpret_cast<uint8_t *>(dst);
+    uint8_t *       dst_bytes  = static_cast<uint8_t *>(dst);
 
     ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
     ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
@@ -2726,6 +2744,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -2766,7 +2785,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     }
 
     ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
-    char *                           dst_t = reinterpret_cast<char *>(dst_ddf);
 
     dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
     dpct::library_data_t mkl_data_type    = dpct::library_data_t::real_float;
@@ -2783,42 +2801,83 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
+    GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
+    GGML_ASSERT(ne10 == ne00);
 
     // broadcast factors
     const int64_t r2 = ne12 / ne02;
     const int64_t r3 = ne13 / ne03;
 
-    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
-        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
-                                                    oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-                                                    src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
-                                                    src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
-                                                    mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
-    } else {
-        const int ne23 = ne12 * ne13;
-
-        ggml_sycl_pool_alloc<const void *>         ptrs_src(ctx.pool(), 2 * ne23);
-        ggml_sycl_pool_alloc<void *>               ptrs_dst(ctx.pool(), 1 * ne23);
-        ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
-
-        sycl::range<3> block_dims(1, ne12, ne13);
-        queue->submit([&](sycl::handler & cgh) {
-            const void ** ptrs_src_get = ptrs_src.get();
-            void **       ptrs_dst_get = ptrs_dst.get();
-            size_t        nb12_scaled  = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
-            size_t        nb13_scaled  = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
-            cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-                k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
-                                       nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
+#if GGML_SYCL_DNNL
+    if (!g_ggml_sycl_disable_dnn) {
+        auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
+            (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
+
+            DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
+                            src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
+                            src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
+                            dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
+        };
+
+        if (r2 == 1 && r3 == 1) {
+            if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+                dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
+            }
+            else {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
+                    float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+                }
+            }
+        } else {
+            // iterate over batches from smaller set of matrices (matrix 0)
+            for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
+                    float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+                }
+            }
+        }
+    }
+    else
+#endif
+    {
+        if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+            // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
+                                                        oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                                                        src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
+                                                        src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
+                                                        mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
+        } else {
+            const int ne23 = ne12 * ne13;
+
+            ggml_sycl_pool_alloc<const void *>         ptrs_src(ctx.pool(), 2 * ne23);
+            ggml_sycl_pool_alloc<void *>               ptrs_dst(ctx.pool(), 1 * ne23);
+            ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
+
+            sycl::range<3> block_dims(1, ne12, ne13);
+            queue->submit([&](sycl::handler & cgh) {
+                const void ** ptrs_src_get = ptrs_src.get();
+                void **       ptrs_dst_get = ptrs_dst.get();
+                size_t        nb12_scaled  = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
+                size_t        nb13_scaled  = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
+                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                    k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+                                           nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
+                });
             });
-        });
 
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
-            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
-            (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
+                *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
+                (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
+                (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
+        }
     }
 } catch (const sycl::exception & exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
@@ -3713,7 +3772,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
             return GGML_STATUS_SUCCESS;
         }
 
-        sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
+        sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
+
         model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
         ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
         model_sycl_graph.end_recording();