]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sycl: refactor quantization to q8_1 (#14815)
authorAlberto Cabrera Pérez <redacted>
Mon, 28 Jul 2025 10:05:53 +0000 (11:05 +0100)
committerGitHub <redacted>
Mon, 28 Jul 2025 10:05:53 +0000 (11:05 +0100)
* sycl: quantization to q8_1 refactor

* Refactored src1 copy logic in op_mul_mat

ggml/src/ggml-sycl/backend.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-sycl/quantize.hpp [new file with mode: 0644]

index f839a42bc90c99cd280d114bad6c6dddf9b17436..410a67b0195265a7bf203d2283f29bd43a819a64 100644 (file)
@@ -28,6 +28,7 @@
 #include "mmvq.hpp"
 #include "norm.hpp"
 #include "outprod.hpp"
+#include "quantize.hpp"
 #include "quants.hpp"
 #include "rope.hpp"
 #include "set_rows.hpp"
index a023d6fb4525ba09670af311b44540dde7ddc855..b08941c328b7d4fc1a6467fbc849c047aa861235 100644 (file)
@@ -44,6 +44,7 @@
 #include "ggml-sycl/set_rows.hpp"
 #include "ggml-sycl/sycl_hw.hpp"
 #include "ggml-sycl/getrows.hpp"
+#include "ggml-sycl/quantize.hpp"
 #include "ggml.h"
 
 static bool g_sycl_loaded = false;
@@ -1373,120 +1374,6 @@ typedef void (*ggml_sycl_op_mul_mat_t)(
 
 
 
-template<int QUANT_BLOCK_TILE>
-static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
-
-    if (ix >= kx_padded) {
-        return;
-    }
-
-    const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                   item_ct1.get_local_id(1);
-
-    const int i_padded = iy*kx_padded + ix;
-
-    block_q8_1 * y = (block_q8_1 *) vy;
-
-    const int ib = i_padded / QK8_1; // block index
-    const int iqs = i_padded % QK8_1; // quant index
-    typedef  sycl::vec<float, QUANT_BLOCK_TILE> TC;
-    typedef  sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
-    TC zeros;
-    TQ qzeros;
-#pragma unroll
-    for (int i = 0; i < QUANT_BLOCK_TILE; i++)
-    {
-        zeros[i] = 0.f;
-        qzeros[i] = 0;
-    }
-    const TC xi = ix < kx ? *(const TC *)&x[iy * kx + ix] : zeros;
-    float sum = xi[0];
-    float amax = sycl::fabs(xi[0]);
-#pragma unroll
-    for (int i = 1; i < QUANT_BLOCK_TILE; i++)
-    {
-        sum += xi[i];
-        amax = sycl::fmax(sycl::fabs(xi[i]), amax);
-    }
-    sum = warp_reduce_sum(sum, item_ct1);
-    amax = warp_reduce_max(amax, item_ct1);
-
-    const float d = amax / 127;
-    TQ q = qzeros;
-    if (amax != 0.0f)
-    {
-#pragma unroll
-        for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
-            q[i] = sycl::round(xi[i] / d);
-        }
-    }
-
-    *(TQ *)&y[ib].qs[iqs] = q;
-
-    if (iqs > 0) {
-        return;
-    }
-
-    reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
-    reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
-}
-
-template <int ElementsPerWI>
-static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
-                                                      const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
-    /*
-        Quantizes and reorders the resultant q8 tensor in a per row fashion
-        Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
-    */
-
-    auto subgroup_id = it.get_group(0);
-    auto wi_id       = it.get_local_id(0);
-
-    const int num_blocks_per_row = kx / QK8_1;
-    auto      row                = subgroup_id / num_blocks_per_row;
-    auto      col                = subgroup_id % num_blocks_per_row;
-
-    auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
-    auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
-
-    auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
-    auto ds_ptr    = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
-
-    sycl::vec<float, ElementsPerWI>  wi_f32_vals;
-    sycl::vec<int8_t, ElementsPerWI> quantized_values;
-
-    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
-    wi_f32_vals           = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
-
-    float sum  = 0.0f;
-    float amax = 0.0f;
-
-#pragma unroll(ElementsPerWI)
-    for (int i = 0; i < ElementsPerWI; i++) {
-        sum += wi_f32_vals[i];
-        amax                = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
-        quantized_values[i] = 0;
-    }
-    sum     = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
-    amax    = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
-    float d = amax == 0 ? 1 : amax / 127;
-
-#pragma unroll(ElementsPerWI)
-    for (int i = 0; i < ElementsPerWI; i++) {
-        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
-    }
-
-    d = amax == 0 ? 0 : d;
-
-    *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
-    if (wi_id == 0) {
-        *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
-    }
-}
-
 static void mul_mat_p021_f16_f32(
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1770,32 +1657,6 @@ static  void pool2d_nchw_kernel(
         o_ptr[cur_oh * ow + cur_ow] = res;
 }
 
-static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
-                                   bool reorder_q8_tensor, queue_ptr stream) {
-    if (reorder_q8_tensor) {
-        auto local_range      = std::size_t(WARP_SIZE);
-        auto num_quant_blocks = ky * (kx / QK8_1);
-        auto global_range     = num_quant_blocks * local_range;
-        stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
-                             [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                 quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
-                             });
-    } else {
-        const int            block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
-        const sycl::range<3> num_blocks(1, ky, block_num_x);
-        int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
-        static_assert(QK8_1 % WARP_SIZE == 0);
-        const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
-        {
-            dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-
-            stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
-                                 [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                     quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
-                                 });
-        }
-    }
-}
 
 static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
                                            float *dst, const int ncols_x,
@@ -2372,10 +2233,10 @@ static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
     peer_access_enabled = enable_peer_access;
 }
 
+template <template <int> typename quantize_f>
 static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                 ggml_sycl_op_mul_mat_t op,
-                                 const bool convert_src1_to_q8_1) try {
+                                 ggml_sycl_op_mul_mat_t op) try {
 
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
 
@@ -2470,6 +2331,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
         }
     }
 
+    constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>,
+                                                      no_quantize_q8_1<QK8_1 / WARP_SIZE>>;
     for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
         if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
             continue;
@@ -2495,20 +2358,19 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
             dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
         }
 
-        if (convert_src1_to_q8_1) {
+        if constexpr(quantize_enabled) {
             dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
 
             if (src1_on_device && src1_is_contiguous) {
-                bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
                 scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
                                                      /*num_src=*/2, " : converting src1 to Q8_1");
-                quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream);
-                /*
-                DPCT1010:90: SYCL uses exceptions to report errors and does not
-                use the error codes. The call was replaced with 0. You need to
-                rewrite this code.
-                */
-                SYCL_CHECK(0);
+                try {
+                    quantize_row_q8_1_sycl<quantize_f>(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
+                } catch (sycl::exception const &exc) {
+                    std::cerr << "Quantize_row_q8_1_sycl error" << exc.what() << "Exception caught at file:" << __FILE__
+                              << ", line:" << __LINE__ << std::endl;
+                    std::exit(1);
+                }
             }
         }
 
@@ -2524,11 +2386,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
     // here an event is recorded that signals that the main device has finished calculating the input data
     if (split && used_devices > 1) {
         ggml_sycl_set_device(ctx.device);
-        /*
-        DPCT1024:91: The original code returned the error code that was further
-        consumed by the program logic. This original code was replaced with 0.
-        You may need to rewrite the program logic consuming the error code.
-        */
         SYCL_CHECK(CHECK_TRY_ERROR(
             *src0_extra->events[ctx.device][0] =
                 ctx.stream()->ext_oneapi_submit_barrier()));
@@ -2552,11 +2409,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
 
             // wait for main GPU data if necessary
             if (split && (i != ctx.device || is != 0)) {
-                /*
-                DPCT1009:163: SYCL uses exceptions to report errors and does not
-                use the error codes. The original code was commented out and a
-                warning string was inserted. You need to rewrite this code.
-                */
                 SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
                     {*src0_extra->events[ctx.device][0]})));
             }
@@ -2582,39 +2434,42 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
                 // copy src0, src1 to device if necessary
                 if (src1_is_contiguous) {
                     if (i != ctx.device) {
-                        if (convert_src1_to_q8_1) {
+                        if constexpr (quantize_enabled) {
                             char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
-                          SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
-                                src1_ddq_i, src1_ddq_i_source,
-                                src1_ncols * src1_padded_col_size * q8_1_ts /
-                                    q8_1_bs).wait()));
+                            SYCL_CHECK(
+                                CHECK_TRY_ERROR(stream
+                                                    ->memcpy(src1_ddq_i, src1_ddq_i_source,
+                                                             src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs)
+                                                    .wait()));
                         } else {
-
                             float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
-                            src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
+                            src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10;
 
-                            SYCL_CHECK(CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream,
-                                src1_ddf_i, src1_ddf_i_source,
-                                src1_ncols * ne10 * sizeof(float))));
+                            SYCL_CHECK(
+                                CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source,
+                                                               src1_ncols * ne10 * sizeof(float))));
                         }
                     }
-                } else if (src1_on_device && !src1_is_contiguous) {
-                    SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
-                                   src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
                 } else {
-                    GGML_ABORT("fatal error");
-                }
+                    if (src1_on_device) {
+                        SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0,
+                                                           src1_col_0 + src1_ncols, stream));
+                    } else {
+                        GGML_ABORT("src1 is non-contiguous and not on device");
+                    }
 
-                if (convert_src1_to_q8_1 && !src1_is_contiguous) {
-                    scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
-                                                         /*num_src=*/2, " : converting src1 to Q8_1");
-                    quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream);
-                    /*
-                    DPCT1010:92: SYCL uses exceptions to report errors and does
-                    not use the error codes. The call was replaced with 0. You
-                    need to rewrite this code.
-                    */
-                    SYCL_CHECK(0);
+                    if constexpr (quantize_enabled) {
+                        scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
+                                                             /*num_src=*/2, " : converting src1 to Q8_1");
+                        try {
+                            quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols,
+                                                                  src1_padded_col_size, stream);
+                        } catch (const sycl::exception & exc) {
+                            std::cerr << "Quantize_row_q8_1_sycl error" << exc.what()
+                                      << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+                            std::exit(1);
+                        }
+                    }
                 }
 
                 if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
@@ -2626,12 +2481,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
                 // do the computation
                 SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
                     dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
-                /*
-                DPCT1010:93: SYCL uses exceptions to report errors and does not
-                use the error codes. The call was replaced with 0. You need to
-                rewrite this code.
-                */
-                SYCL_CHECK(0);
 
                 // copy dst to host or other device if necessary
                 if (!dst_on_device) {
@@ -2662,12 +2511,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
 
                 // add event for the main device to wait on until other device is done
                 if (split && (i != ctx.device || is != 0)) {
-                    /*
-                    DPCT1024:94: The original code returned the error code that
-                    was further consumed by the program logic. This original
-                    code was replaced with 0. You may need to rewrite the
-                    program logic consuming the error code.
-                    */
                     SYCL_CHECK(CHECK_TRY_ERROR(
                         *src0_extra->events[i][is] =
                             stream->ext_oneapi_submit_barrier()));
@@ -3351,19 +3194,20 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-        constexpr bool convert_src1_to_q8_1 = false;
         opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
+        ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec);
     } else if (use_mul_mat_vec_q) {
-        constexpr bool convert_src1_to_q8_1 = true;
         opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
+        ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
+        if (extra && extra->optimized_feature.reorder) {
+            ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
+        } else {
+            ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
+        }
     } else if (use_mul_mat_q) {
-        constexpr bool convert_src1_to_q8_1 = true;
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
+        ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q);
     } else {
-        constexpr bool convert_src1_to_q8_1 = false;
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
+        ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl);
     }
 }
 
diff --git a/ggml/src/ggml-sycl/quantize.hpp b/ggml/src/ggml-sycl/quantize.hpp
new file mode 100644 (file)
index 0000000..b5c7a54
--- /dev/null
@@ -0,0 +1,133 @@
+/***************************************************************************
+ *
+ *  Copyright (C) 2025 Codeplay Software Ltd.
+ *  Copyright (C) 2025 Intel Corporation
+ *
+ *  MIT License
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License.
+ *
+ *  quantize.hpp
+ *
+ *  Description:
+ *     Sycl backend specific quantization functions
+ **************************************************************************/
+
+#pragma once
+
+#include <sycl/nd_item.hpp>
+
+#include "ggml-sycl/dpct/helper.hpp"
+
+template <int ElementsPerWI>
+__dpct_inline__ static void quantize_q8_1_impl(const float * __restrict__ x,
+                                               sycl::vec<int8_t, ElementsPerWI> & quantized_values, float & d,
+                                               float & sum, const sycl::nd_item<1> & it) {
+    auto subgroup_id = it.get_group(0);
+    auto wi_id       = it.get_local_id(0);
+
+    sycl::vec<float, ElementsPerWI> wi_f32_vals;
+
+    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
+    wi_f32_vals           = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
+
+    float amax = 0.0f;
+
+#pragma unroll(ElementsPerWI)
+    for (int i = 0; i < ElementsPerWI; i++) {
+        sum += wi_f32_vals[i];
+        amax                = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
+        quantized_values[i] = 0;
+    }
+    sum  = sycl::reduce_over_group(it.get_sub_group(), sum, sycl::plus<float>());
+    amax = sycl::reduce_over_group(it.get_sub_group(), amax, sycl::maximum<float>());
+    d    = amax == 0 ? 1 : amax / 127;
+
+#pragma unroll(ElementsPerWI)
+    for (int i = 0; i < ElementsPerWI; i++) {
+        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
+    }
+
+    d = amax == 0 ? 0 : d;
+}
+
+// No op to control codepath in ggml_sycl_op_mul_mat
+template <int ElementsPerWI> struct no_quantize_q8_1 {
+    void operator()(const float *, void *, int, int, const sycl::nd_item<1> &) const {}
+};
+
+template <int ElementsPerWI> struct quantize_and_reorder_q8_1_soa {
+    __dpct_inline__ void operator()(const float * __restrict__ x, void * reordered_q8_tensor, const int kx,
+                                    const int kx_padded, const sycl::nd_item<1> & it) const {
+        /*
+        Quantizes and reorders the resultant q8 tensor in a per row fashion
+        Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
+    */
+        auto subgroup_id = it.get_group(0);
+        auto wi_id       = it.get_local_id(0);
+
+        sycl::vec<int8_t, ElementsPerWI> quantized_values;
+        float                            d   = 0.0f;
+        float                            sum = 0.0f;
+        quantize_q8_1_impl<ElementsPerWI>(x, quantized_values, d, sum, it);
+
+        const int num_blocks_per_row = kx / QK8_1;
+        auto      row                = subgroup_id / num_blocks_per_row;
+        auto      col                = subgroup_id % num_blocks_per_row;
+        auto      row_offset         = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
+        auto      col_offset         = QK8_1 * col + wi_id * ElementsPerWI;
+
+        auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
+        *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
+
+        auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
+        if (wi_id == 0) {
+            *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
+        }
+    }
+};
+
+template <int ElementsPerWI> struct quantize_q8_1 {
+    __dpct_inline__ void operator()(const float * __restrict__ x, void * q8_tensor, const int kx, const int kx_padded,
+                                    const sycl::nd_item<1> & it) const {
+        auto subgroup_id = it.get_group(0);
+        auto wi_id       = it.get_local_id(0);
+
+        const int num_blocks_per_row = kx / QK8_1;
+        auto      row                = subgroup_id / num_blocks_per_row;
+        const int pitch              = kx_padded / QK8_1;
+
+        sycl::vec<int8_t, ElementsPerWI> quantized_values;
+        float                            d   = 0.0f;
+        float                            sum = 0.0f;
+        quantize_q8_1_impl<ElementsPerWI>(x, quantized_values, d, sum, it);
+
+        block_q8_1 * quant_ptr = (block_q8_1 *) q8_tensor;
+        auto         block_id  = subgroup_id % num_blocks_per_row + row * pitch;
+
+        int8_t * qs                                               = &(quant_ptr[block_id].qs[wi_id * ElementsPerWI]);
+        *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(qs) = quantized_values;
+        if (wi_id == 0) {
+            quant_ptr[block_id].ds = sycl::half2(sycl::half(d), sycl::half(sum));
+        }
+    }
+};
+
+template <template <int> typename quantize_f>
+void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
+                            dpct::queue_ptr stream) {
+    static_assert(QK8_1 % WARP_SIZE == 0);
+    auto local_range      = std::size_t(WARP_SIZE);
+    auto num_quant_blocks = ky * (kx / QK8_1);
+    auto global_range     = num_quant_blocks * local_range;
+    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+
+    stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
+                         [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                             quantize_f<QK8_1 / WARP_SIZE>()(x, vy, kx, kx_padded, it);
+                         });
+}