else static_assert(0);
}
- static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
- const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+ // matrix A has m rows, k columns
+ // matrix B has k rows, n columns
+ // nra - number of elements to skip when moving to the next row in A
+ // nrb - number of elements to skip when moving to the next row in B
+ // nca - number of elements to skip when moving to the next column in A
+ // ncb - number of elements to skip when moving to the next column in B
+ // stride_a - number of elements to skip when moving to the next A matrix
+ // stride_b - number of elements to skip when moving to the next B matrix
+ // batches_a - number of A matrices
+ // batches_b - number of B matrices
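+ //
+ // e.g. a packed row major m x k matrix uses nra = k, nca = 1, stride_a = m * k;
+ // swapping the two (nra = 1, nca = m) reads the same buffer as a column major
+ // matrix, so a transposed input costs only different strides (see row_gemm below)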
+ static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+ const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
+ const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
+ void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
+
auto stream = ctx.stream_dnnl(q);
auto eng = ctx.engine_dnnl(q);
- dnnl::memory::dims a_dims = { m, k };
- dnnl::memory::dims b_dims = { k, n };
- dnnl::memory::dims c_dims = { m, n };
- const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
- const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
- const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+
+ // { # batches, # rows, # columns }
+ dnnl::memory::dims a_dims = { batches_a, m, k };
+ dnnl::memory::dims b_dims = { batches_b, k, n };
+ dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
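+ // oneDNN matmul broadcasts over the batch dimension when one side has a
+ // single batch, hence max(batches_a, batches_b) output matrices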
+
+ // { # elements to skip to the next matrix, to the next row, to the next column }
+ dnnl::memory::dims a_strides = { stride_a, nra, nca };
+ dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
+
+ const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+ const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
+ const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
dnnl::primitive_attr primitive_attr;
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
matmul_prim.execute(stream, matmul_args);
}
+
+ // matrices A and B are column major, both having k rows
+ // matrix A has m columns, matrix B has n columns
+ // output: matrix C = A transposed * B (m rows, n columns, stored row major)
+ static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+ const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+
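+ // A^T is read by giving A row stride k and column stride 1 (a column major
+ // k x m buffer read row major is its transpose); B keeps nrb = 1, ncb = k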
+ gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
+ }
};
#endif
int g_ggml_sycl_debug = 0;
int g_ggml_sycl_disable_optimize = 0;
int g_ggml_sycl_disable_graph = 0;
+int g_ggml_sycl_disable_dnn = 0;
int g_ggml_sycl_prioritize_dmmv = 0;
static ggml_sycl_device_info ggml_sycl_init() {
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
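+ // GGML_SYCL_DISABLE_DNN=1 skips the oneDNN GEMM paths at runtime and falls
+ // back to the oneMath/MKL implementations (only meaningful when built with
+ // GGML_SYCL_DNNL)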
+ g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
GGML_LOG_INFO("Running with Environment Variables:\n");
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+#ifdef GGML_SYCL_GRAPH
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+#else
+ GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
+#endif
+#if GGML_SYCL_DNNL
+ GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
+#else
+ GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
+#endif
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
GGML_LOG_INFO("Build with Macros:\n");
#if defined(GGML_SYCL_FORCE_MMQ)
const int64_t ne00 = src0->ne[0];
const int64_t ne10 = src1->ne[0];
-
+ GGML_ASSERT(ne00 == ne10);
const int64_t row_diff = row_high - row_low;
int id;
SYCL_CHECK(
CHECK_TRY_ERROR(id = get_current_device_id()));
-#if !GGML_SYCL_DNNL
- const int64_t ne0 = dst->ne[0];
+
+ const int64_t ne0 = dst->ne[0]; // used by MKL only
// the main device has a larger memory buffer to hold the results from all GPUs
// ldc == nrows of the matrix that cuBLAS writes into
- int ldc = id == ctx.device ? ne0 : row_diff;
-#endif
+ int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
#ifdef GGML_SYCL_F16
bool use_fp16 = true; // TODO(Yu) SYCL capability check
: src1_as_f16.get();
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
-#if !GGML_SYCL_DNNL
- const sycl::half alpha_f16 = 1.0f;
- const sycl::half beta_f16 = 0.0f;
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
- *stream, oneapi::math::transpose::trans,
- oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
- &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
- src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
- dst_f16.get(), dpct::library_data_t::real_half, ldc,
- dpct::library_data_t::real_half)));
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
- to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-#else
- DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
- DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
- dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
- to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+#if GGML_SYCL_DNNL
+ if (!g_ggml_sycl_disable_dnn) {
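+ // oneDNN path: compute (src1^T * src0) row major, which is byte-for-byte the
+ // column major (src0^T * src1) that the dpct::gemm fallback below produces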
+ DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
+ DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+ dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+ to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+ }
+ else
#endif
+ {
+ const sycl::half alpha_f16 = 1.0f;
+ const sycl::half beta_f16 = 0.0f;
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
+ *stream, oneapi::math::transpose::trans,
+ oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
+ &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
+ src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
+ dst_f16.get(), dpct::library_data_t::real_half, ldc,
+ dpct::library_data_t::real_half)));
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+ to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+ }
}
else {
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
-#if !GGML_SYCL_DNNL
- const float alpha = 1.0f;
- const float beta = 0.0f;
- SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
- get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
- src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
- dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
-#else
- DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
- DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
- dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+#if GGML_SYCL_DNNL
+ if (!g_ggml_sycl_disable_dnn) {
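+ // same operand order as the f16 path above: (src1^T * src0) row major
+ // matches the column major result of the oneMath fallback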
+ DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
+ DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+ dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+ }
+ else
#endif
+ {
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+ SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
+ get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
+ src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
+ dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
+ }
}
GGML_UNUSED(dst);
GGML_UNUSED(src1_ddq_i);
std::exit(1);
}
-static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
+static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
- uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);
+ uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
}
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
- char * dst_t = reinterpret_cast<char *>(dst_ddf);
dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
+ GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
+ GGML_ASSERT(ne10 == ne00);
// broadcast factors
const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;
- if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
- // there is no broadcast and src0, src1 are contiguous across dims 2, 3
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
- oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
- src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
- src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
- mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
- } else {
- const int ne23 = ne12 * ne13;
-
- ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
- ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
- ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
-
- sycl::range<3> block_dims(1, ne12, ne13);
- queue->submit([&](sycl::handler & cgh) {
- const void ** ptrs_src_get = ptrs_src.get();
- void ** ptrs_dst_get = ptrs_dst.get();
- size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
- size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
- cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
- k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
- nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
+#if GGML_SYCL_DNNL
+ if (!g_ggml_sycl_disable_dnn) {
+ auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
+ (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
+
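+ // A = src1: s11 elements to the next row, 1 to the next column, s12 per matrix
+ // B = src0: 1 element to the next row, nb01/nb00 to the next column, nb02/nb00 per matrix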
+ DnnlGemmWrapper::gemm(ctx, ne11, ne01, ne10,
+ src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
+ src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
+ dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
+ };
+
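+ // three cases: no broadcast with dims 2,3 contiguous (one batched call),
+ // no broadcast but strided in dim 3 (one call per ie03 slice), and
+ // broadcast (each src0 matrix batched against its r2*r3 src1 matrices)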
+ if (r2 == 1 && r3 == 1) {
+ if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+ dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12 * ne13, ne02 * ne03);
+ }
+ else {
+ for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+ const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
+ const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
+ float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
+ dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+ }
+ }
+ } else {
+ // broadcast case: iterate over the smaller set of matrices (src0) and
+ // batch the r2*r3 src1 matrices multiplied against each one
+ for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
+ for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+ const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
+ const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
+ float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
+ dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+ }
+ }
+ }
+ }
+ else
+#endif
+ {
+ if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+ // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
+ oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+ src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
+ src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
+ mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
+ } else {
+ const int ne23 = ne12 * ne13;
+
+ ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
+ ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
+ ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
+
+ sycl::range<3> block_dims(1, ne12, ne13);
+ queue->submit([&](sycl::handler & cgh) {
+ const void ** ptrs_src_get = ptrs_src.get();
+ void ** ptrs_dst_get = ptrs_dst.get();
+ size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
+ size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
+ cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+ k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+ nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
+ });
});
- });
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
- *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
- (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
- (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
- (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
+ *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+ (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
+ (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
+ (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
+ }
}
} catch (const sycl::exception & exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
return GGML_STATUS_SUCCESS;
}
- sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
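+ // assume_buffer_outlives_graph: promise that buffers recorded into the graph
+ // stay alive for the graph's lifetime, letting the runtime skip extra tracking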
+ sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
+
model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
model_sycl_graph.end_recording();