]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
SYCL: Introducing memory host pool (llama/11251)
authorNicolò Scipione <redacted>
Sun, 19 Jan 2025 13:33:34 +0000 (14:33 +0100)
committerGeorgi Gerganov <redacted>
Mon, 3 Feb 2025 20:00:57 +0000 (22:00 +0200)
* Implement host pool for matrix_info

Creating a new memory pool on the host to store memory location for
matrix_info needed to launch gemm_batch from oneMKL/oneMath.
Removing complex support in gemm_batch since it is not used in llama.cpp

* Remove unnecessary headers and cast

* Reorder member variable to avoid warning on initialization

* Formatting

* Remove unused variable

* Address PR review feedback - remove warning

---------

Signed-off-by: nscipione <redacted>
ggml/src/ggml-sycl/common.hpp
ggml/src/ggml-sycl/dpct/helper.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp

index e9500f3a1682b629fede5692fda151554f714cb3..abad847ca81999da53e20457a29f1a85303195d6 100644 (file)
@@ -333,8 +333,12 @@ struct ggml_backend_sycl_context {
     // pool
     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
 
+    std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
+
     static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
 
+    static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);
+
     ggml_sycl_pool & pool(int device) {
         if (pools[device] == nullptr) {
             pools[device] = new_pool_for_device(stream(device,0), device);
@@ -345,6 +349,15 @@ struct ggml_backend_sycl_context {
     ggml_sycl_pool & pool() {
         return pool(device);
     }
+
+    ggml_sycl_pool & host_pool(int device) {
+        if (host_pools[device] == nullptr) {
+            host_pools[device] = new_pool_for_host(stream(device, 0), device);
+        }
+        return *host_pools[device];
+    }
+
+    ggml_sycl_pool & host_pool() { return host_pool(device); }
 };
 
 // common device functions
index e167948e7a3f0cdf75ba60b57b66299931cbd5dc..c96395be613126282a6ed425a6aba9400787bd29 100644 (file)
@@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
     return device_type.str();
 }
 
+template <typename Ts> struct matrix_info_t {
+    oneapi::mkl::transpose transpose_info[2];
+    Ts                     value_info[2];
+    std::int64_t           size_info[3];
+    std::int64_t           ld_info[3];
+    std::int64_t           groupsize_info;
+};
+
 namespace dpct
 {
     typedef sycl::queue *queue_ptr;
@@ -1727,26 +1735,13 @@ namespace dpct
         };
 
         template <class Ta, class Tb, class Tc, class Ts>
-        inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                                    oneapi::mkl::transpose b_trans, int m, int n, int k,
-                                    const void *alpha, const void **a, int lda,
-                                    const void **b, int ldb, const void *beta, void **c,
-                                    int ldc, int batch_size)
-        {
-            struct matrix_info_t
-            {
-                oneapi::mkl::transpose transpose_info[2];
-                Ts value_info[2];
-                std::int64_t size_info[3];
-                std::int64_t ld_info[3];
-                std::int64_t groupsize_info;
-            };
-
+        inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
+                                    int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
+                                    int ldb, const void * beta, void ** c, int ldc, int batch_size,
+                                    matrix_info_t<float> * matrix_info) {
             Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
             Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
 
-            matrix_info_t *matrix_info =
-                (matrix_info_t *)std::malloc(sizeof(matrix_info_t));
             matrix_info->transpose_info[0] = a_trans;
             matrix_info->transpose_info[1] = b_trans;
             matrix_info->value_info[0] = alpha_value;
@@ -1763,23 +1758,18 @@ namespace dpct
             sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
                 oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
                 matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
-                matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
-                matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
-                matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
-                &(matrix_info->groupsize_info));
+                matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
+                reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
+                matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
+                reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
 #else
             sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
                 q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
-                matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
+                matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
                 reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
-                matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
-                matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
+                matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
+                reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
 #endif
-
-            q.submit([&](sycl::handler &cgh)
-                     {
-    cgh.depends_on(e);
-    cgh.host_task([=] { std::free(matrix_info); }); });
         }
 
         template <class Ta, class Tb, class Tc, class Ts>
@@ -2422,25 +2412,11 @@ namespace dpct
     /// \param [in] ldc Leading dimension of C.
     /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
     /// \param [in] scaling_type Data type of the scaling factors.
-    inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                           oneapi::mkl::transpose b_trans, int m, int n, int k,
-                           const void *alpha, const void *a[],
-                           library_data_t a_type, int lda, const void *b[],
-                           library_data_t b_type, int ldb, const void *beta,
-                           void *c[], library_data_t c_type, int ldc,
-                           int batch_size, library_data_t scaling_type)
-    {
-        if (scaling_type == library_data_t::real_float &&
-            c_type == library_data_t::complex_float)
-        {
-            scaling_type = library_data_t::complex_float;
-        }
-        else if (scaling_type == library_data_t::real_double &&
-                 c_type == library_data_t::complex_double)
-        {
-            scaling_type = library_data_t::complex_double;
-        }
-
+    inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
+                           int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
+                           const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
+                           library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
+                           matrix_info_t<float> * matrix_info) {
         std::uint64_t key =
             detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
         switch (key)
@@ -2449,48 +2425,24 @@ namespace dpct
             library_data_t::real_float, library_data_t::real_float,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<float, float, float, float>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
-                batch_size);
+            detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
+                                                                beta, c, ldc, batch_size, matrix_info);
             break;
         }
         case detail::get_type_combination_id(
             library_data_t::real_double, library_data_t::real_double,
             library_data_t::real_double, library_data_t::real_double):
         {
-            detail::gemm_batch_impl<double, double, double, double>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
-                batch_size);
-            break;
-        }
-        case detail::get_type_combination_id(
-            library_data_t::complex_float, library_data_t::complex_float,
-            library_data_t::complex_float, library_data_t::complex_float):
-        {
-            detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
-                                    std::complex<float>, std::complex<float>>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
-                batch_size);
-            break;
-        }
-        case detail::get_type_combination_id(
-            library_data_t::complex_double, library_data_t::complex_double,
-            library_data_t::complex_double, library_data_t::complex_double):
-        {
-            detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
-                                    std::complex<double>, std::complex<double>>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
-                batch_size);
+            detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
+                                                                    beta, c, ldc, batch_size, matrix_info);
             break;
         }
         case detail::get_type_combination_id(
             library_data_t::real_half, library_data_t::real_half,
             library_data_t::real_half, library_data_t::real_half):
         {
-            detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
-                                    sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
-                                                a, lda, b, ldb, beta, c, ldc,
-                                                batch_size);
+            detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
 #ifdef __INTEL_MKL__
@@ -2498,19 +2450,16 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_bfloat16, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
-                                    oneapi::mkl::bfloat16, float>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
-                batch_size);
+            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
         case detail::get_type_combination_id(
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
-                                    float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
-                                           b, ldb, beta, c, ldc, batch_size);
+            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
 #endif
@@ -2522,10 +2471,9 @@ namespace dpct
                 dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
             float beta_float =
                 dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
-            detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
-                                    float>(q, a_trans, b_trans, m, n, k, &alpha_float,
-                                           a, lda, b, ldb, &beta_float, c, ldc,
-                                           batch_size);
+            detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
+                q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
+                matrix_info);
             break;
         }
         case detail::get_type_combination_id(
@@ -2533,8 +2481,7 @@ namespace dpct
             library_data_t::real_float, library_data_t::real_float):
         {
             detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
-                batch_size);
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
         case detail::get_type_combination_id(
@@ -2542,8 +2489,7 @@ namespace dpct
             library_data_t::real_float, library_data_t::real_float):
         {
             detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
-                batch_size);
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
         case detail::get_type_combination_id(
@@ -2557,8 +2503,7 @@ namespace dpct
             sycl::half alpha_half(alpha_value);
             sycl::half beta_half(beta_value);
             detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
-                q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
-                batch_size);
+                q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
             break;
         }
         default:
index 5272ca45423f255ae8e346493ab156bb9eb824a8..ed4d8bb8baee1fc9df408ae991faae68e52ec28a 100644 (file)
@@ -1173,6 +1173,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
     }
 };
 
+struct ggml_sycl_pool_host : public ggml_sycl_pool {
+    queue_ptr qptr;
+    int       device;
+
+    inline static int counter{ 0 };
+
+    struct ggml_sycl_buffer {
+        void * ptr  = nullptr;
+        size_t size = 0;
+    };
+
+    // Set arbitrarly to 64
+    static constexpr int          MAX_POOL_SIZE{ 64 };
+    std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
+    size_t                        pool_size   = 0;
+
+    explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
+
+    ~ggml_sycl_pool_host() {
+        for (int i = 0; i < MAX_POOL_SIZE; ++i) {
+            ggml_sycl_buffer & b = buffer_pool[i];
+            if (b.ptr != nullptr) {
+                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
+                b.ptr = nullptr;
+                pool_size -= b.size;
+                b.size = 0;
+            }
+        }
+        counter = 0;
+    }
+
+    void * alloc(size_t size, size_t * actual_size) override {
+        if (counter == MAX_POOL_SIZE) {
+            ggml_sycl_buffer b               = buffer_pool[0];
+            void *           ptr             = b.ptr;
+            *actual_size                     = b.size;
+            counter                          = 1;
+            return ptr;
+        }
+        ggml_sycl_buffer & b = buffer_pool[counter];
+
+        if (b.ptr == nullptr) {
+            void * ptr;
+
+            SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
+            if (!ptr) {
+                GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
+                return nullptr;
+            }
+            pool_size += size;
+            *actual_size = size;
+            counter      = counter + 1;
+            return ptr;
+        } else {
+            ++counter;
+            b.size = size;
+            return b.ptr;
+        }
+    }
+
+    void free(void * ptr, size_t size) override {
+        // if the pool is not completed add the pointer to it in place of the first nullptr found.
+        // Otherwise do nothing, pointers will be freed once the pool is deallocated.
+        for (int i = 0; i < MAX_POOL_SIZE; ++i) {
+            ggml_sycl_buffer & b = buffer_pool[i];
+            if (b.ptr == nullptr) {
+                b.ptr  = ptr;
+                b.size = size;
+                return;
+            }
+        }
+    }
+};
+
+std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
+    // return pool for the host to speed up memory management
+    return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device));
+}
+
 std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
     // TBD: NO VMM support
     // if (ggml_sycl_info().devices[device].vmm) {
@@ -3363,6 +3442,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
 
         ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_sycl_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+        ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
 
         sycl::range<3> block_dims(1, ne12, ne13);
         /*
@@ -3391,14 +3471,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
             });
         }
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const void **)(ptrs_src.get() + 0 * ne23),
-            dpct::library_data_t::real_half, nb01 / nb00,
-            (const void **)(ptrs_src.get() + 1 * ne23),
-            dpct::library_data_t::real_half, nb11 / nb10, beta,
-            (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
-            cu_compute_type)));
+            *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
+            (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
+            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
+            (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
     }
 }
 catch (sycl::exception const &exc) {