]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
SYCL : Move to compile time oneMKL interface backend selection for NVIDIA backend...
authorNicolò Scipione <redacted>
Wed, 4 Dec 2024 01:29:20 +0000 (02:29 +0100)
committerGitHub <redacted>
Wed, 4 Dec 2024 01:29:20 +0000 (09:29 +0800)
* [SYCL] Move to Compile Time backend selection on oneMKL Interface for NVIDIA backend

Move to compile time selection to backend to avoid latency at run time.
Add it to all mkl gemm calls and only for NVIDIA backend.

Signed-off-by: nscipione <redacted>
* Formatting

* Address PR comments to increase readibility

---------

Signed-off-by: nscipione <redacted>
ggml/src/ggml-sycl/CMakeLists.txt
ggml/src/ggml-sycl/dpct/helper.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-sycl/outprod.cpp

index 83f223fd7b6fc8f1a8033e9e0540368ff9b213ca..3579a311aac0784ab1971823132d84473455c132 100644 (file)
@@ -68,7 +68,8 @@ else()
         target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
     elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
-        target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl)
+        add_compile_definitions(GGML_SYCL_NVIDIA)
+        target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl_blas_cublas)
     elseif (GGML_SYCL_TARGET STREQUAL "AMD")
         if (NOT GGML_SYCL_DEVICE_ARCH)
             message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")
index c2f28bb49579e9877cfe0042eafd95b2ef2055fe..d1b5dd87c69222d359989af120adbf4de2e001c9 100644 (file)
@@ -1689,9 +1689,14 @@ namespace dpct
             auto data_a = get_memory<const Ta>(a);
             auto data_b = get_memory<const Tb>(b);
             auto data_c = get_memory<Tc>(c);
-            oneapi::mkl::blas::column_major::gemm(
-                q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
-                data_b, ldb, beta_value, data_c, ldc);
+#ifdef GGML_SYCL_NVIDIA
+            oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
+                                                  a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
+                                                  beta_value, data_c, ldc);
+#else
+            oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
+                                                  beta_value, data_c, ldc);
+#endif
         }
 
         template <typename VecT, class BinaryOperation, class = void>
@@ -1754,14 +1759,22 @@ namespace dpct
             matrix_info->ld_info[2] = ldc;
             matrix_info->groupsize_info = batch_size;
 
+#ifdef GGML_SYCL_NVIDIA
+            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
+                oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
+                matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
+                matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
+                matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
+                matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
+                &(matrix_info->groupsize_info));
+#else
             sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
-                q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
-                matrix_info->size_info, matrix_info->size_info + 1,
-                matrix_info->size_info + 2, matrix_info->value_info,
-                reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
-                reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
-                matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
+                q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
+                matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
+                reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
+                matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
                 matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
+#endif
 
             q.submit([&](sycl::handler &cgh)
                      {
@@ -1783,10 +1796,16 @@ namespace dpct
             auto data_a = get_memory<const Ta>(a);
             auto data_b = get_memory<const Tb>(b);
             auto data_c = get_memory<Tc>(c);
+#ifdef GGML_SYCL_NVIDIA
             oneapi::mkl::blas::column_major::gemm_batch(
-                q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
-                stride_a, data_b, ldb, stride_b, beta_value,
-                data_c, ldc, stride_c, batch_size);
+                oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
+                alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
+                batch_size);
+#else
+            oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
+                                                        stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
+                                                        stride_c, batch_size);
+#endif
         }
 
     } // namespace detail
index 1310981e52f4c970d78fb37f815e827ebefff974..135efb521a980ec37a3e5f20d38e957f546511d7 100644 (file)
@@ -2573,12 +2573,17 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const float alpha = 1.0f;
         const float beta = 0.0f;
 #if !GGML_SYCL_DNNL
+#    ifdef GGML_SYCL_NVIDIA
         SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
-            *stream, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
-            dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
-            src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
+            oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
+            oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
+            ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
+#    else
+        SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
+            *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
+            dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
             dst_dd_i, ldc)));
+#    endif
 #else
         auto dnnl_stream = ctx.stream_dnnl(stream);
          DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
index e61cdc2ca5d5377fb2e9c8a4d9b414d05c894366..ef9af0b7633ab6b304f9de544c34bf36a910284f 100644 (file)
@@ -40,14 +40,14 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
 
     try {
         // Perform matrix multiplication using oneMKL GEMM
-        oneapi::mkl::blas::column_major::gemm(*stream,
-            oneapi::mkl::transpose::nontrans, src1_op,
-            ne0, ne1, ne01,
-            alpha,
-            src0_d, ne00,
-            src1_d, ldb,
-            beta,
-            dst_d, ne0);
+#ifdef GGML_SYCL_NVIDIA
+        oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
+                                              oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
+                                              ne00, src1_d, ldb, beta, dst_d, ne0);
+#else
+        oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
+                                              src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
+#endif
     }
     catch (sycl::exception const& exc) {
         std::cerr << exc.what() << std::endl;