auto data_a = get_memory<const Ta>(a);
auto data_b = get_memory<const Tb>(b);
auto data_c = get_memory<Tc>(c);
- oneapi::mkl::blas::column_major::gemm(
- q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
- data_b, ldb, beta_value, data_c, ldc);
+#ifdef GGML_SYCL_NVIDIA
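+ // On NVIDIA builds, dispatch this GEMM explicitly through oneMKL's cuBLAS backend instead of the generic run-time backend selection on the plain queue.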
+ oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
+ a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
+ beta_value, data_c, ldc);
+#else
+ oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
+ beta_value, data_c, ldc);
+#endif
}
template <typename VecT, class BinaryOperation, class = void>
matrix_info->ld_info[2] = ldc;
matrix_info->groupsize_info = batch_size;
+#ifdef GGML_SYCL_NVIDIA
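+ // Grouped (pointer-array) batched GEMM: same explicit cuBLAS backend selection on NVIDIA targets.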
+ sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
+ oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
+ matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
+ matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
+ matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
+ matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
+ &(matrix_info->groupsize_info));
+#else
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
- q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
- matrix_info->size_info, matrix_info->size_info + 1,
- matrix_info->size_info + 2, matrix_info->value_info,
- reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
- reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
- matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
+ q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
+ matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
+ reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
+ matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
+#endif
q.submit([&](sycl::handler &cgh)
{
auto data_a = get_memory<const Ta>(a);
auto data_b = get_memory<const Tb>(b);
auto data_c = get_memory<Tc>(c);
+#ifdef GGML_SYCL_NVIDIA
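+ // Strided batched GEMM likewise goes through the cuBLAS backend selector on NVIDIA.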
oneapi::mkl::blas::column_major::gemm_batch(
- q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
- stride_a, data_b, ldb, stride_b, beta_value,
- data_c, ldc, stride_c, batch_size);
+ oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
+ alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
+ batch_size);
+#else
+ oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
+ stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
+ stride_c, batch_size);
+#endif
}
} // namespace detail
const float alpha = 1.0f;
const float beta = 0.0f;
#if !GGML_SYCL_DNNL
+# ifdef GGML_SYCL_NVIDIA
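+ // FP32 GEMM: pick the cuBLAS backend explicitly when GGML_SYCL_NVIDIA is defined.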
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
- *stream, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
- dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
- src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
+ oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
+ oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
+ ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
+# else
+ SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
+ *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
+ dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
dst_dd_i, ldc)));
+# endif
#else
auto dnnl_stream = ctx.stream_dnnl(stream);
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
try {
// Perform matrix multiplication using oneMKL GEMM
- oneapi::mkl::blas::column_major::gemm(*stream,
- oneapi::mkl::transpose::nontrans, src1_op,
- ne0, ne1, ne01,
- alpha,
- src0_d, ne00,
- src1_d, ldb,
- beta,
- dst_d, ne0);
+#ifdef GGML_SYCL_NVIDIA
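+ // Same pattern for this GEMM: explicit cuBLAS backend on NVIDIA, plain queue dispatch otherwise.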
+ oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
+ oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
+ ne00, src1_d, ldb, beta, dst_d, ne0);
+#else
+ oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
+ src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
+#endif
}
catch (sycl::exception const& exc) {
std::cerr << exc.what() << std::endl;