--- /dev/null
+#include <sycl/sycl.hpp>
+#include "common.hpp"
+#include "add-id.hpp"
+
+static void add_id_kernel(
+ const float* src0,
+ const float* src1,
+ const int32_t* src2,
+ float* dst,
+ int64_t ne0,
+ int64_t ne1,
+ size_t nb01,
+ size_t nb02,
+ size_t nb11,
+ size_t nb21,
+ sycl::nd_item<3> item_ct1) {
+ const int64_t i1 = item_ct1.get_group(2);
+ const int64_t i2 = item_ct1.get_group(1);
+
+ const int i11 =
+ *(const int32_t*)((const char*)src2 + i1 * sizeof(int32_t) + i2 * nb21);
+
+ const size_t nb1 = ne0 * sizeof(float);
+ const size_t nb2 = ne1 * nb1;
+
+ float* dst_row = (float*)((char*)dst + i1 * nb1 + i2 * nb2);
+ const float* src0_row =
+ (const float*)((const char*)src0 + i1 * nb01 + i2 * nb02);
+ const float* src1_row = (const float*)((const char*)src1 + i11 * nb11);
+
+ for (int64_t i0 = item_ct1.get_local_id(2); i0 < ne0;
+ i0 += item_ct1.get_local_range(2)) {
+ dst_row[i0] = src0_row[i0] + src1_row[i0];
+ }
+}
+
+void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+ const ggml_tensor* src0 = dst->src[0];
+ const ggml_tensor* src1 = dst->src[1];
+ const ggml_tensor* src2 = dst->src[2];
+
+ GGML_TENSOR_TERNARY_OP_LOCALS
+
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+ GGML_ASSERT(nb00 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
+ GGML_ASSERT(nb20 == sizeof(int32_t));
+
+ const float* src0_d = (const float*)src0->data;
+ const float* src1_d = (const float*)src1->data;
+ const int32_t* src2_d = (const int32_t*)src2->data;
+ float* dst_d = (float*)dst->data;
+
+ int threads = std::min((int)ne00, 768); // cols
+ ctx.stream()->parallel_for(
+ sycl::nd_range<3>(
+ sycl::range<3>(1, ne02, ne01) * sycl::range<3>(1, 1, threads),
+ sycl::range<3>(1, 1, threads)),
+ [=](sycl::nd_item<3> item_ct1) {
+ add_id_kernel(
+ src0_d,
+ src1_d,
+ src2_d,
+ dst_d,
+ ne0,
+ ne1,
+ nb01,
+ nb02,
+ nb11,
+ nb21,
+ item_ct1);
+ });
+}
--- /dev/null
+#ifndef GGML_SYCL_ADD_ID_HPP
+#define GGML_SYCL_ADD_ID_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_add_id(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif // GGML_SYCL_ADD_ID_HPP
return sycl::uint2(div_val, mod_val);
}
+static __dpct_inline__ int ggml_sycl_dp4a(const int a, const int b, int c) {
+ return dpct::dp4a(a, b, c);
+}
+
+static __dpct_inline__ float ggml_sycl_e8m0_to_fp32(uint8_t x) {
+ uint32_t bits;
+ if (x == 0) {
+ bits = 0x00400000;
+ } else {
+ bits = (uint32_t) x << 23;
+ }
+
+ float result;
+ memcpy(&result, &bits, sizeof(float));
+ return result;
+}
+
#endif // GGML_SYCL_COMMON_HPP
}
}
+template <typename dst_t>
+static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
+ const int nb = (k + QK_K - 1) / QK_K;
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_mxfp4(vx, y, item_ct1);
+ });
+}
+
template <typename src_t, typename dst_t>
static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
convert_unary_nc_sycl<src_t>(vx, y, k, 1, 1, 1, k, k, k, queue);
}
+
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_row_iq4_xs_sycl;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_sycl;
+ case GGML_TYPE_MXFP4:
+ return dequantize_row_mxfp4_sycl;
case GGML_TYPE_F32:
return convert_unary_sycl<float>;
#ifdef GGML_SYCL_HAS_BF16
return dequantize_row_iq4_xs_sycl;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_sycl;
+ case GGML_TYPE_MXFP4:
+ return dequantize_row_mxfp4_sycl;
case GGML_TYPE_F16:
return convert_unary_sycl<sycl::half>;
#ifdef GGML_SYCL_HAS_BF16
}
}
+template<typename dst_t>
+static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy,
+ const sycl::nd_item<3> &item_ct1) {
+ // auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ const int64_t i = item_ct1.get_group(2);
+ const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
+
+ const int64_t tid = item_ct1.get_local_id(2);
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[ib].qs + 4*il;
+ const float d = ggml_sycl_e8m0_to_fp32(x[ib].e);
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
+ y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]*0.5f;
+ }
+}
#endif // GGML_SYCL_DEQUANTIZE_HPP
: id);
}
+ template <typename T1, typename T2>
+ using dot_product_acc_t = std::conditional_t<
+ std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
+ uint32_t,
+ int32_t>;
+
+ template <typename T>
+ sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val) {
+ return sycl::vec<T, 1>(val)
+ .template as<sycl::vec<
+ std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>,
+ 4>>()
+ .template convert<T>();
+ }
+
template <typename T1, typename T2, typename T3>
- inline auto dp4a(T1 a, T2 b, T3 c)
- {
- return syclcompat::dp4a(a, b, c);
+ inline auto dp4a(T1 a, T2 b, T3 c) {
+ dot_product_acc_t<T1, T2> res = c;
+ auto va = extract_and_sign_or_zero_extend4(a);
+ auto vb = extract_and_sign_or_zero_extend4(b);
+ res += va[0] * vb[0];
+ res += va[1] * vb[1];
+ res += va[2] * vb[2];
+ res += va[3] * vb[3];
+ return res;
}
struct sub_sat
atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
}
+ inline unsigned int byte_level_permute(
+ unsigned int a, unsigned int b, unsigned int s) {
+ unsigned int ret;
+ ret = ((((std::uint64_t)b << 32 | a) >> (s & 0x7) * 8) & 0xff) |
+ (((((std::uint64_t)b << 32 | a) >> ((s >> 4) & 0x7) * 8) & 0xff)
+ << 8) |
+ (((((std::uint64_t)b << 32 | a) >> ((s >> 8) & 0x7) * 8) & 0xff)
+ << 16) |
+ (((((std::uint64_t)b << 32 | a) >> ((s >> 12) & 0x7) * 8) & 0xff)
+ << 24);
+ return ret;
+ }
+
+ inline uint32_t byte_level_permute_custom(
+ uint32_t low32, uint32_t high32, uint32_t sel, int mode = 0) {
+ constexpr uint16_t lookup[6][4] = {
+ {0x3210, 0x4321, 0x5432, 0x6543}, // Forward 4-byte extract
+ {0x5670, 0x6701, 0x7012, 0x0123}, // Backward 4-byte extract
+ {0x0000, 0x1111, 0x2222, 0x3333}, // Replicate 8-bit values
+ {0x3210, 0x3211, 0x3222, 0x3333}, // Edge clamp left
+ {0x0000, 0x1110, 0x2210, 0x3210}, // Edge clamp right
+ {0x1010, 0x3232, 0x1010, 0x3232} // Replicate 16-bit values
+ };
+
+ if (mode >= 1 && mode <= 6) {
+ return byte_level_permute(low32, high32, lookup[mode - 1][sel & 0x3]);
+ } else if (!mode) {
+ return byte_level_permute(low32, high32, sel);
+ }
+ return 0;
+ }
+
} // COPY from DPCT head files
#endif // GGML_SYCL_DPCT_HELPER_HPP
});
}
+__dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {
+ x = sycl::fmin(x, limit);
+ g = sycl::fmax(sycl::fmin(g, limit), -limit);
+
+ float out_glu = x / (1.0f + sycl::native::exp(-x * alpha));
+ out_glu = out_glu * (1.0f + g);
+ return out_glu;
+}
+
+
+template <typename T>
+static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,
+ const int64_t n, const int64_t o0, const int64_t o1,
+ float alpha, float limit, sycl::nd_item<3> item_ct1) {
+ const int64_t i = int64_t(item_ct1.get_local_range(2)) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
+
+ if (i >= k) {
+ return;
+ }
+
+ const int64_t j0 = (i / n) * o0 + (i % n);
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+
+ float xi = x[j0];
+ float gi = g[j1];
+
+ dst[i] = ggml_sycl_op_swiglu_oai_single(xi, gi, alpha, limit);
+}
+
+template <typename T>
+static void swiglu_oai_sycl(const T * x,
+ const T * g,
+ T * dst,
+ const int64_t k,
+ const int64_t n,
+ const int64_t o0,
+ const int64_t o1,
+ const float alpha,
+ const float limit,
+ dpct::queue_ptr stream) {
+ const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
+ });
+}
+
+void ggml_sycl_op_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ void * src0_d = src0->data;
+ void * src1_d = src1 ? src1->data : src0->data;
+ const int64_t src0_o = src0->nb[1];
+ const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+ void * dst_d = dst->data;
+ const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ dpct::queue_ptr stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == dst->type);
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+ GGML_ASSERT(src1->ne[0] == nc);
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ //const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+ const float alpha = ggml_get_op_params_f32(dst, 2);
+ const float limit = ggml_get_op_params_f32(dst, 3);
+
+ float * src0_p = (float *) src0_d;
+ float * src1_p = (float *) src1_d;
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ swiglu_oai_sycl(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
+}
+
static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
ggml_sycl_op_swiglu(ctx, dst);
}
+void ggml_sycl_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+ ggml_sycl_op_swiglu_oai(ctx, dst);
+}
+
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_geglu_erf(ctx, dst);
#include "ggml.h"
#include <limits> // For std::numeric_limits
+#define SYCL_GLU_BLOCK_SIZE 256
+
template <typename T>
T neg_infinity() {
return -std::numeric_limits<T>::infinity();
void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
+#include "ggml-sycl/add-id.hpp"
#include "ggml-sycl/backend.hpp"
#include "ggml-sycl/common.hpp"
#include "ggml-sycl/element_wise.hpp"
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
// mmvq and mmq need the __dp4a instruction which is available for gen12+
// Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
#endif // SYCL_USE_XMX
-
// mmvq path is faster in the CUDA backend.
if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
// Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
case GGML_OP_ADD1: // TODO: more efficient implementation
ggml_sycl_add(ctx, dst);
break;
+ case GGML_OP_ADD_ID:
+ ggml_sycl_add_id(ctx, dst);
+ break;
case GGML_OP_SUB:
ggml_sycl_sub(ctx, dst);
break;
case GGML_GLU_OP_SWIGLU:
ggml_sycl_swiglu(ctx, dst);
break;
+ case GGML_GLU_OP_SWIGLU_OAI:
+ ggml_sycl_swiglu_oai(ctx, dst);
+ break;
case GGML_GLU_OP_GEGLU_ERF:
ggml_sycl_geglu_erf(ctx, dst);
break;
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]);
}
}
ggml_type src0_type = op->src[0]->type;
- if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) {
- // TODO: support MXFP4
+ if (src0_type == GGML_TYPE_BF16 ) {
+ // TODO: support GGML_TYPE_BF16
// FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
return false;
}
+
// TODO: The configuration below needs more work to be supported with oneDNN
- if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1) {
- return false;
+ if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
+ a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {
+ return false;
}
+
// TODO: This specific configuration can fail with oneDNN and needs more debugging
if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
- return true;
case GGML_OP_ADD:
case GGML_OP_ADD1:
+ case GGML_OP_ADD_ID:
case GGML_OP_SUB:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_MUL:
}
}
+static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_MXFP4 == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+
+ {
+ stream->submit([&](sycl::handler & cgh) {
+ cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q<QK_MXFP4, QI_MXFP4, block_mxfp4, VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+
static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
float *dst, const int ncols,
const int nrows,
case GGML_TYPE_IQ4_XS:
mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
+ case GGML_TYPE_MXFP4:
+ mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
default:
GGML_ABORT("fatal error");
}
#include "pad.hpp"
static void pad_f32(const float * src, float * dst,
- const int lp0, const int rp0, const int lp1, const int rp1,
- const int lp2, const int rp2, const int lp3, const int rp3,
- const int ne0, const int ne1, const int ne2, const int ne3) {
- auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ const int lp0, const int rp0, const int lp1, const int rp1,
+ const int lp2, const int rp2, const int lp3, const int rp3,
+ const int ne0, const int ne1, const int ne2, const int ne3,
+ sycl::nd_item<3> item_ct1) {
int i0 = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
int i1 = item_ct1.get_group(1);
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1,
- ne2, ne3);
+ ne2, ne3, item_ct1);
});
}
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
- GGML_ASSERT(src0->nb[1] == src0->ne[0] * static_cast<int>(sizeof(float)));
+ GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
const int src_stride_inner = ncs;
const int src_stride_seq = ncs * d_inner;
typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
const int & iqs);
+static __dpct_inline__ int get_int_b1(const void * x, const int & i32) {
+ const uint8_t * x8 = (const uint8_t *) x;
+
+ int x32 = x8[4*i32 + 0] << 0;
+ x32 |= x8[4*i32 + 1] << 8;
+ x32 |= x8[4*i32 + 2] << 16;
+ x32 |= x8[4*i32 + 3] << 24;
+
+ return x32;
+}
+
+
static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
const uint16_t* x16 =
(const uint16_t*)(x8 + sizeof(int) * i32); // assume at least 2 byte
val2 = v1 | (v2 << 16);
}
+static __dpct_inline__ sycl::int2 get_int_from_table_16(
+ const int& q4, const int8_t* table) {
+ const uint32_t* table32 = (const uint32_t*)table;
+ uint32_t tmp[2];
+ const uint32_t low_high_selection_indices =
+ (0x32103210 | ((q4 & 0x88888888) >> 1));
+#pragma unroll
+ for (uint32_t i = 0; i < 2; ++i) {
+ const uint32_t shift = 16 * i;
+
+ const uint32_t low =
+ dpct::byte_level_permute(table32[0], table32[1], q4 >> shift);
+ const uint32_t high =
+ dpct::byte_level_permute(table32[2], table32[3], q4 >> shift);
+ tmp[i] = dpct::byte_level_permute(
+ low, high, low_high_selection_indices >> shift);
+ }
+ return sycl::int2(
+ dpct::byte_level_permute(tmp[0], tmp[1], 0x6420),
+ dpct::byte_level_permute(tmp[0], tmp[1], 0x7531));
+}
+
#define VDR_Q2_K_Q8_1_MMVQ 1
// contiguous v/x values
return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
}
+#define VDR_MXFP4_Q8_1_MMVQ 2
+#define VDR_MXFP4_Q8_1_MMQ 4
+
+static __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq,
+ const block_q8_1 * __restrict__ bq8_1,
+ const int & iqs) {
+ const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq;
+
+ const int * q8 = (const int *) bq8_1->qs + iqs;
+
+ int sumi = 0;
+#pragma unroll
+ for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
+ const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
+ const sycl::int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+ sumi = ggml_sycl_dp4a(v.x(), q8[l + 0], sumi);
+ sumi = ggml_sycl_dp4a(v.y(), q8[l + 4], sumi);
+ }
+
+ const float d = ggml_sycl_e8m0_to_fp32(bq4->e) * 0.5f * (bq8_1->ds)[0];
+ return d * sumi;
+}
+
+
static __dpct_inline__ float
vec_dot_q5_0_q8_1(const void *__restrict__ vbq,
const block_q8_1 *__restrict__ bq8_1, const int &iqs) {