return device_type.str();
}
+// Plain-old-data descriptor for one grouped GEMM call, sized so its arrays can
+// be handed directly to oneMKL's grouped gemm_batch API (group count == 1):
+//   transpose_info[0..1] -> op(A), op(B);  value_info[0..1] -> alpha, beta;
+//   size_info[0..2]      -> m, n, k;       ld_info[0..2]    -> lda, ldb, ldc.
+// Storage is caller-owned so it outlives the asynchronous MKL call.
+template <typename Ts> struct matrix_info_t {
+ oneapi::mkl::transpose transpose_info[2];
+ Ts value_info[2];
+ std::int64_t size_info[3];
+ std::int64_t ld_info[3];
+ std::int64_t groupsize_info;
+};
+
// Minimal dpct compatibility alias: a raw, non-owning pointer to a SYCL queue,
// used as the queue handle type throughout this backend.
namespace dpct
{
typedef sycl::queue *queue_ptr;
};
// Dispatches one single-group batched GEMM through oneMKL. The new version
// writes all per-call metadata (transpose flags, alpha/beta, sizes, leading
// dimensions, group size) into caller-provided matrix_info storage instead of
// malloc'ing it here and freeing it in a trailing host_task — removing one
// heap allocation and one extra queue submission per call.
template <class Ta, class Tb, class Tc, class Ts>
- inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
- oneapi::mkl::transpose b_trans, int m, int n, int k,
- const void *alpha, const void **a, int lda,
- const void **b, int ldb, const void *beta, void **c,
- int ldc, int batch_size)
- {
- struct matrix_info_t
- {
- oneapi::mkl::transpose transpose_info[2];
- Ts value_info[2];
- std::int64_t size_info[3];
- std::int64_t ld_info[3];
- std::int64_t groupsize_info;
- };
-
+ inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
+ int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
+ int ldb, const void * beta, void ** c, int ldc, int batch_size,
+ matrix_info_t<float> * matrix_info) {
// alpha/beta may live in device memory; get_value copies them to the host.
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
- matrix_info_t *matrix_info =
- (matrix_info_t *)std::malloc(sizeof(matrix_info_t));
matrix_info->transpose_info[0] = a_trans;
matrix_info->transpose_info[1] = b_trans;
matrix_info->value_info[0] = alpha_value;
// NOTE(review): assignments to value_info[1], size_info[], ld_info[] and
// groupsize_info are not visible in this chunk (diff context omitted) —
// presumably they are set just above the call; confirm against full file.
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
- matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
- matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
- matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
- &(matrix_info->groupsize_info));
+ matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
+ reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
+ matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
+ reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
#else
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
- matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
+ matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
- matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
- matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
+ matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
+ reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
#endif
-
- q.submit([&](sycl::handler &cgh)
- {
- cgh.depends_on(e);
- cgh.host_task([=] { std::free(matrix_info); }); });
}
// Type-dispatching front end for batched GEMM: builds a combination key from
// the A/B/C/scaling data types and forwards to the matching gemm_batch_impl
// instantiation. The new signature threads the caller-owned matrix_info
// buffer through to the implementation (see gemm_batch_impl).
// NOTE(review): several `case` labels and context lines are missing from this
// chunk (diff context omitted); code below is kept byte-identical.
template <class Ta, class Tb, class Tc, class Ts>
/// \param [in] ldc Leading dimension of C.
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
/// \param [in] scaling_type Data type of the scaling factors.
- inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
- oneapi::mkl::transpose b_trans, int m, int n, int k,
- const void *alpha, const void *a[],
- library_data_t a_type, int lda, const void *b[],
- library_data_t b_type, int ldb, const void *beta,
- void *c[], library_data_t c_type, int ldc,
- int batch_size, library_data_t scaling_type)
- {
- if (scaling_type == library_data_t::real_float &&
- c_type == library_data_t::complex_float)
- {
- scaling_type = library_data_t::complex_float;
- }
- else if (scaling_type == library_data_t::real_double &&
- c_type == library_data_t::complex_double)
- {
- scaling_type = library_data_t::complex_double;
- }
-
+ inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
+ int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
+ const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
+ library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
+ matrix_info_t<float> * matrix_info) {
std::uint64_t key =
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
switch (key)
library_data_t::real_float, library_data_t::real_float,
library_data_t::real_float, library_data_t::real_float):
{
- detail::gemm_batch_impl<float, float, float, float>(
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
- batch_size);
+ detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
+ beta, c, ldc, batch_size, matrix_info);
break;
}
case detail::get_type_combination_id(
library_data_t::real_double, library_data_t::real_double,
library_data_t::real_double, library_data_t::real_double):
{
- detail::gemm_batch_impl<double, double, double, double>(
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
- batch_size);
- break;
- }
- case detail::get_type_combination_id(
- library_data_t::complex_float, library_data_t::complex_float,
- library_data_t::complex_float, library_data_t::complex_float):
- {
- detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
- std::complex<float>, std::complex<float>>(
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
- batch_size);
- break;
- }
- case detail::get_type_combination_id(
- library_data_t::complex_double, library_data_t::complex_double,
- library_data_t::complex_double, library_data_t::complex_double):
- {
- detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
- std::complex<double>, std::complex<double>>(
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
- batch_size);
+ detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
+ beta, c, ldc, batch_size, matrix_info);
break;
}
case detail::get_type_combination_id(
library_data_t::real_half, library_data_t::real_half,
library_data_t::real_half, library_data_t::real_half):
{
- detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
- sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
- a, lda, b, ldb, beta, c, ldc,
- batch_size);
+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
break;
}
#ifdef __INTEL_MKL__
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
library_data_t::real_bfloat16, library_data_t::real_float):
{
- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
- oneapi::mkl::bfloat16, float>(
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
- batch_size);
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
break;
}
case detail::get_type_combination_id(
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
library_data_t::real_float, library_data_t::real_float):
{
- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
- float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
- b, ldb, beta, c, ldc, batch_size);
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
break;
}
#endif
dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
float beta_float =
dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
- detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
- float>(q, a_trans, b_trans, m, n, k, &alpha_float,
- a, lda, b, ldb, &beta_float, c, ldc,
- batch_size);
+ detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
+ q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
+ matrix_info);
break;
}
case detail::get_type_combination_id(
library_data_t::real_float, library_data_t::real_float):
{
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
- batch_size);
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
break;
}
case detail::get_type_combination_id(
library_data_t::real_float, library_data_t::real_float):
{
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
- batch_size);
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
break;
}
case detail::get_type_combination_id(
sycl::half alpha_half(alpha_value);
sycl::half beta_half(beta_value);
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
- q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
- batch_size);
+ q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
break;
}
default:
}
};
+// Fixed-capacity pool of pinned host (sycl::malloc_host) buffers, handed out
+// round-robin via `counter`. Intended to amortize host allocations used for
+// staging data (e.g. the matrix_info descriptors for batched GEMM).
+// NOTE(review): `counter` is a static member shared by ALL instances of this
+// pool, and the destructor resets it to 0 — with more than one live host pool
+// this cross-talks between instances; confirm single-instance usage.
+struct ggml_sycl_pool_host : public ggml_sycl_pool {
+    queue_ptr qptr;
+    int       device;
+
+    inline static int counter{ 0 };
+
+    // One slot: a pinned host allocation and the size it was allocated with.
+    struct ggml_sycl_buffer {
+        void * ptr  = nullptr;
+        size_t size = 0;
+    };
+
+    // Set arbitrarily to 64
+    static constexpr int MAX_POOL_SIZE{ 64 };
+    std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
+    size_t pool_size = 0;
+
+    explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
+
+    // Releases every cached pinned buffer back to the SYCL runtime.
+    ~ggml_sycl_pool_host() {
+        for (int i = 0; i < MAX_POOL_SIZE; ++i) {
+            ggml_sycl_buffer & b = buffer_pool[i];
+            if (b.ptr != nullptr) {
+                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
+                b.ptr = nullptr;
+                pool_size -= b.size;
+                b.size = 0;
+            }
+        }
+        counter = 0;
+    }
+
+    // Returns a pinned host buffer of at least `size` bytes; `*actual_size`
+    // receives the usable size of the returned buffer. Returns nullptr on
+    // allocation failure.
+    void * alloc(size_t size, size_t * actual_size) override {
+        if (counter == MAX_POOL_SIZE) {
+            // Pool exhausted: recycle slot 0 and restart the round-robin.
+            // NOTE(review): the recycled buffer may be smaller than `size`;
+            // callers appear to request uniform sizes — confirm.
+            ggml_sycl_buffer b = buffer_pool[0];
+            void * ptr   = b.ptr;
+            *actual_size = b.size;
+            counter      = 1;
+            return ptr;
+        }
+        ggml_sycl_buffer & b = buffer_pool[counter];
+
+        if (b.ptr == nullptr) {
+            // Empty slot: allocate a fresh pinned buffer.
+            void * ptr;
+
+            SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
+            if (!ptr) {
+                // %zu is the portable format specifier for size_t (was %lu).
+                GGML_LOG_ERROR("%s: can't allocate %zu Bytes of memory on host\n", __func__, size);
+                return nullptr;
+            }
+            pool_size += size;
+            *actual_size = size;
+            counter = counter + 1;
+            b.ptr  = ptr;   // remember the allocation so the slot can be reused / freed
+            b.size = size;
+            return ptr;
+        } else {
+            // Reuse the cached buffer. Grow the slot first if it is too small —
+            // previously an undersized buffer was returned unchanged (overflow
+            // hazard) and *actual_size was left unset.
+            if (b.size < size) {
+                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
+                pool_size -= b.size;
+                b.ptr  = nullptr;
+                b.size = 0;
+                void * ptr;
+                SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
+                if (!ptr) {
+                    GGML_LOG_ERROR("%s: can't allocate %zu Bytes of memory on host\n", __func__, size);
+                    return nullptr;
+                }
+                b.ptr  = ptr;
+                b.size = size;
+                pool_size += size;
+            }
+            ++counter;
+            *actual_size = b.size;
+            return b.ptr;
+        }
+    }
+
+    void free(void * ptr, size_t size) override {
+        // if the pool is not completed add the pointer to it in place of the first nullptr found.
+        // Otherwise do nothing, pointers will be freed once the pool is deallocated.
+        for (int i = 0; i < MAX_POOL_SIZE; ++i) {
+            ggml_sycl_buffer & b = buffer_pool[i];
+            if (b.ptr == nullptr) {
+                b.ptr = ptr;
+                b.size = size;
+                return;
+            }
+        }
+    }
+};
+
+std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
+    // Factory for the pinned-host-memory pool; using a dedicated host pool
+    // speeds up host-side memory management for this backend context.
+    return std::make_unique<ggml_sycl_pool_host>(qptr, device);
+}
+
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
// TBD: NO VMM support
// if (ggml_sycl_info().devices[device].vmm) {
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
+ ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
sycl::range<3> block_dims(1, ne12, ne13);
/*
});
}
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
- *main_stream, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
- (const void **)(ptrs_src.get() + 0 * ne23),
- dpct::library_data_t::real_half, nb01 / nb00,
- (const void **)(ptrs_src.get() + 1 * ne23),
- dpct::library_data_t::real_half, nb11 / nb10, beta,
- (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
- cu_compute_type)));
+ *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
+ (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
+ (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
+ (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
}
}
catch (sycl::exception const &exc) {