#define GGML_SYCL_BACKEND_HPP
#include "binbcast.hpp"
-#include "concat.hpp"
#include "common.hpp"
+#include "concat.hpp"
#include "conv.hpp"
#include "convert.hpp"
+#include "cpy.hpp"
#include "dequantize.hpp"
#include "dmmv.hpp"
+#include "element_wise.hpp"
+#include "gla.hpp"
+#include "im2col.hpp"
#include "mmq.hpp"
#include "mmvq.hpp"
-#include "rope.hpp"
#include "norm.hpp"
+#include "outprod.hpp"
+#include "quants.hpp"
+#include "rope.hpp"
#include "softmax.hpp"
#include "tsembd.hpp"
-#include "im2col.hpp"
#include "wkv.hpp"
-#include "outprod.hpp"
-#include "element_wise.hpp"
-#include "cpy.hpp"
-#include "gla.hpp"
-#endif // GGML_SYCL_BACKEND_HPP
+#endif // GGML_SYCL_BACKEND_HPP
extern int g_ggml_sycl_debug;
extern int g_ggml_sycl_disable_optimize;
+extern int g_ggml_sycl_prioritize_dmmv;
#define GGML_SYCL_DEBUG(...) \
do { \
int g_ggml_sycl_debug = 0;
int g_ggml_sycl_disable_optimize = 0;
int g_ggml_sycl_disable_graph = 0;
+int g_ggml_sycl_prioritize_dmmv = 0;
static ggml_sycl_device_info ggml_sycl_init() {
ggml_sycl_device_info info = {};
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+ g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
GGML_LOG_INFO("Running with Environment Variables:\n");
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+ GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
GGML_LOG_INFO("Build with Macros:\n");
#if defined(GGML_SYCL_FORCE_MMQ)
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
std::exit(1);
}
+enum class mul_mat_algo {
+ DMMV = 0,
+ MMVQ = 1,
+ MUL_MAT_SYCL = 2,
+};
+
inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
// TODO: accuracy issues in MMQ
GGML_UNUSED(type);
return false;
}
+inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return true;
+ default:
+ return false;
+ }
+}
+
+inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return true;
+ default:
+ return false;
+ }
+}
+
+inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return true;
+ default:
+ return false;
+ }
+}
+
static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
GGML_ASSERT((size % sizeof(block_q4_0) == 0));
GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
int offset_blks = offset / sizeof(block_q4_0);
- auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;;
+ auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
stream->parallel_for(
reorder_qw(data_device, ncols, nrows, size, 0, stream);
}
-/*
-* This function could be called when the OP (mul_mat) function support reorder optimizition.
-*/
-static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1,
- ggml_tensor * dst) {
- if (!g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
- ctx->opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf.
- dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases.
- src0->type == GGML_TYPE_Q4_0 &&
- src1->ne[2]==1 && src1->ne[3]==1) {
+static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
+ return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
+ ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf.
+ dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases.
+ dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
+}
- ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
- if (!extra) return; //only happen in CI/UT permute case.
+static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
+ ggml_tensor * dst, mul_mat_algo mm_algorithm) {
+ if (!should_reorder_tensor(*ctx, dst)) {
+ return;
+ }
- if (extra->optimized_feature.reorder) return; //skip the tensor which is handled for reorder.
+ ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
+ if (!extra || extra->optimized_feature.reorder) {
+ return; // Skip permutations and already reordered tensors
+ }
- reorder_qw(src0, ctx->stream());
- extra->optimized_feature.reorder = true; //used to decode/dequan in next steps.
+ switch (mm_algorithm) {
+ case mul_mat_algo::DMMV:
+ if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
+ return;
+ }
+ break;
+ case mul_mat_algo::MMVQ:
+ if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
+ return;
+ }
+ break;
+ case mul_mat_algo::MUL_MAT_SYCL:
+ if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
+ return;
+ }
+ break;
}
+
+ reorder_qw(src0, ctx->stream());
+ extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
}
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
int64_t min_compute_capability = INT_MAX;
if (split) {
- ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
+ ggml_backend_sycl_split_buffer_type_context * buft_ctx =
+ (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
auto & tensor_split = buft_ctx->tensor_split;
for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
// skip devices that are not going to do any work:
}
}
} else {
- min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
+ min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
}
// check data types and tensor shapes for custom matrix multiplication kernels:
use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
#endif // SYCL_USE_XMX
+
// mmvq path is faster in the CUDA backend.
- if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
+ if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
+ // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
+ // is enabled takes precedence over DMMV, the current if-else implementation
+ // requires disabling DMMV if both conditions are met
+ || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+ }
if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
// TODO: Refactor and cleanup of mul mat dispatching.
// KQ + KQV multi-batch
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
} else if (use_dequantize_mul_mat_vec) {
- opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
- // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
+ constexpr bool convert_src1_to_q8_1 = false;
+ opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
} else if (use_mul_mat_vec_q) {
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+ constexpr bool convert_src1_to_q8_1 = true;
+ opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
} else if (use_mul_mat_q) {
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
+ constexpr bool convert_src1_to_q8_1 = true;
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
} else {
- opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
+ constexpr bool convert_src1_to_q8_1 = false;
+ // MUL_MAT_SYCL supports reorder
+ opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MUL_MAT_SYCL);
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
}
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
}
#include "mmvq.hpp"
+
+#include "ggml.h"
+#include "common.hpp"
+#include "quants.hpp"
#include "vecdotq.hpp"
-#include <cassert>
+
+template <typename reorder_vec_dot_q_sycl>
+static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) {
+ using block_type = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>;
+ using block_traits = typename block_type::traits;
+
+ const auto sg = nd_item.get_sub_group();
+ const int sg_range = sg.get_group_linear_range();
+ const int workgroup_id = nd_item.get_group_linear_id();
+ const int sg_id = sg.get_group_linear_id();
+ const int row = workgroup_id * sg_range + sg_id;
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / block_traits::qk;
+ constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
+ constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
+
+ static_assert(blocks_per_subgroup > 0);
+ static_assert(block_elements_per_subgroup > 0);
+
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ float partial_sum = 0.0f;
+ for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
+ const int ibx = row * blocks_per_row + i; // x block index
+ // TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
+ const int bx_offset = block_type::get_block_offset(ibx);
+ const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
+
+ // Y block index that aligns with ibx
+ const int iby = i * block_type::block_to_q8_1_ratio();
+
+#pragma unroll
+ for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
+ // x block quant index when casting the quants to int
+ const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
+
+ partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs);
+ }
+ }
+
+ auto sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum, std::plus<>());
+
+ if (sg.leader()) {
+ dst[row] = sum;
+ }
+}
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
}
}
-static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
- float *dst, const int ncols,
- const int nrows,
+static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
+ const int nrows, dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK4_0 == 0);
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
+ constexpr size_t num_subgroups = 16;
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
+
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
+
+ stream->submit([&](sycl::handler & cgh) {
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
+ nd_item);
+ });
+ });
+}
+
+static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK4_0 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
- {
-
- stream->submit([&](sycl::handler &cgh) {
- cgh.parallel_for(
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1)
- [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
- mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
- VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
- vx, vy, dst, ncols, nrows, item_ct1);
- });
+ {
+ stream->submit([&](sycl::handler & cgh) {
+ cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
});
}
}
}
}
-void ggml_sycl_op_mul_mat_vec_q(
- ggml_backend_sycl_context & ctx,
- const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
- const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
- float *dst_dd_i, const int64_t row_low, const int64_t row_high,
- const int64_t src1_ncols, const int64_t src1_padded_col_size,
- const dpct::queue_ptr &stream) {
-
+void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
+ ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low,
+ const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_col_size,
+ const dpct::queue_ptr & stream) {
const int64_t ne10 = src1->ne[0];
GGML_ASSERT(ne10 % QK8_1 == 0);
- const int64_t ne00 = src0->ne[0];
+ const int64_t ne00 = src0->ne[0];
const int64_t row_diff = row_high - row_low;
int id;
- SYCL_CHECK(
- CHECK_TRY_ERROR(id = get_current_device_id()));
+ SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id()));
const size_t q8_1_ts = sizeof(block_q8_1);
const size_t q8_1_bs = QK8_1;
// the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the kernel writes into
- for (int i = 0; i < src1_ncols; i++)
- {
+ for (int i = 0; i < src1_ncols; i++) {
const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
- const char* src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
- float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
+ const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
+ float * dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
switch (src0->type) {
- case GGML_TYPE_Q4_0:
- mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_Q4_1:
- mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_Q5_0:
- mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_Q5_1:
- mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_Q8_0:
- mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_Q2_K:
- mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_Q3_K:
- mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_Q4_K:
- mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_Q5_K:
- mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_Q6_K:
- mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_IQ1_S:
- mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_IQ1_M:
- mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_IQ2_XXS:
- mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_IQ2_XS:
- mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_IQ2_S:
- mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_IQ3_XXS:
- mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_IQ3_S:
- mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_IQ4_NL:
- mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- case GGML_TYPE_IQ4_XS:
- mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
- break;
- default:
- GGML_ABORT("fatal error");
+ case GGML_TYPE_Q4_0:
+ if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
+ ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n");
+ reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ } else {
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n");
+ mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ }
+ break;
+ case GGML_TYPE_Q4_1:
+ mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q2_K:
+ mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q3_K:
+ mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q4_K:
+ mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_K:
+ mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q6_K:
+ mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ1_S:
+ mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ1_M:
+ mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ2_XXS:
+ mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ2_XS:
+ mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ2_S:
+ mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ3_XXS:
+ mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ3_S:
+ mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ4_NL:
+ mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ4_XS:
+ mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ default:
+ GGML_ABORT("fatal error");
}
}
GGML_UNUSED(src1);
--- /dev/null
+//
+// MIT license
+// Copyright (C) 2025 Codeplay Software Ltd.
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_QUANTS_HPP
+#define GGML_SYCL_QUANTS_HPP
+
+#include "ggml-common.h"
+#include "ggml.h"
+
+namespace ggml_sycl_reordered {
+
+
+// The reordered block moves quants (qs) and scales(d) to two
+// uniform regions of memory that is contiguous in the same tensor.
+// What this means is that instead of having:
+// [d0, qs0] [d1, qs1] [d2, qs2] ... [dN, qsN]
+// We have:
+// [qs0, qs1, qs2, ..., qsN] [d0, d1, d2, ..., dN]
+//
+// Notes: out-of-bounds qs will run into d values
+// Aligment relies on the allocated size of qs
+
+template <ggml_type type> struct block_q_t;
+
+
+// qk number of weights / quants in a block
+// qr number of weights in a byte (described as 'before dequantization')
+// for quantization types that has low and high bits split, qr is calculated with
+// using the lower bits, e.g for Q6 quants QR6 is 2
+// qi number of 32 bit integers needed to represent all the quants from a block (`qs` field)
+// See ggml-common.h to see how these are calculated
+template <> struct block_q_t<GGML_TYPE_Q4_0> {
+ struct traits {
+ static constexpr uint32_t qk = QK4_0;
+ static constexpr uint32_t qi = QI4_0;
+ static constexpr uint32_t qr = QR4_0;
+ static constexpr uint32_t vdr_mmvq = 2;
+ };
+
+ static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+
+ static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
+ return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half);
+ }
+
+ static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
+};
+
+} // namespace ggml_sycl_reordered
+
+#endif // GGML_SYCL_QUANTS_HPP
//
// MIT license
-// Copyright (C) 2024 Intel Corporation
+// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: MIT
//
#define GGML_SYCL_VECDOTQ_HPP
#include "dpct/helper.hpp"
+#include "ggml.h"
+#include "quants.hpp"
-typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
+typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
+ const int & iqs);
static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
const uint16_t* x16 =
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
+template <ggml_type T> struct reorder_vec_dot_q_sycl {
+ static_assert(T != T, "ggml_type for reorder vecdot not implemented");
+};
+
+template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
+ static constexpr ggml_type gtype = GGML_TYPE_Q4_0;
+
+ using q4_0_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_0>;
+ using q4_0_traits = typename q4_0_block::traits;
+
+ __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, const sycl::half2 & ds8) {
+ int sumi = 0;
+
+#pragma unroll
+ for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
+ const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+ const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+ // SIMD dot product of quantized values
+ sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);
+ sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
+ }
+
+ const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();
+
+ // second part effectively subtracts 8 from each quant value
+ return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());
+ }
+
+ __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
+ const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+ const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
+ const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
+ int v[q4_0_traits::vdr_mmvq];
+ int u[2 * q4_0_traits::vdr_mmvq];
+
+#pragma unroll
+
+ for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
+ v[i] = get_int_from_uint8(bq4_0, iqs + i);
+ u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+ u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi);
+ }
+
+ return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds);
+ };
+};
+
#define VDR_Q4_0_Q8_1_MMVQ 2
#define VDR_Q4_0_Q8_1_MMQ 4
template <int vdr>
-static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u,
- const float &d4,
- const sycl::half2 &ds8) {
+static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4,
+ const sycl::half2 & ds8) {
int sumi = 0;
#pragma unroll
for (int i = 0; i < vdr; ++i) {
sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
}
- const sycl::float2 ds8f =
- ds8.convert<float, sycl::rounding_mode::automatic>();
+ const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();
// second part effectively subtracts 8 from each quant value
return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y());
const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
int v[VDR_Q4_0_Q8_1_MMVQ];
- int u[2*VDR_Q4_0_Q8_1_MMVQ];
+ int u[2 * VDR_Q4_0_Q8_1_MMVQ];
#pragma unroll
for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
- v[i] = get_int_from_uint8(bq4_0->qs, iqs + i);
- u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
- u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
+ v[i] = get_int_from_uint8(bq4_0->qs, iqs + i);
+ u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+ u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
}
return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);