]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
[SYCL] Optimize mul_mat for Q4_0 on Intel GPU (#12035)
authorNeo Zhang Jianyu <redacted>
Mon, 24 Feb 2025 14:33:23 +0000 (22:33 +0800)
committerGitHub <redacted>
Mon, 24 Feb 2025 14:33:23 +0000 (22:33 +0800)
* opt performance by reorder for Intel GPU

* detect hw type and save opt feature, and print opt feature

* correct name

* support optimize graph once when compute graph, record the opt status in tensor->extra, make CI passed

* add env variable GGML_SYCL_DISABLE_OPT for debug

* use syclex::architecture replace the custom hw define, update the guide for GGML_SYCL_DISABLE_OPT

* add performance data

* mv getrows functions to separeted files

* fix global variables

---------

Co-authored-by: arthw <redacted>
14 files changed:
docs/backend/SYCL.md
examples/sycl/run-llama2.sh
ggml/src/ggml-sycl/CMakeLists.txt
ggml/src/ggml-sycl/common.cpp
ggml/src/ggml-sycl/common.hpp
ggml/src/ggml-sycl/convert.cpp
ggml/src/ggml-sycl/convert.hpp
ggml/src/ggml-sycl/dequantize.hpp
ggml/src/ggml-sycl/dmmv.cpp
ggml/src/ggml-sycl/getrows.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/getrows.hpp [new file with mode: 0644]
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-sycl/sycl_hw.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/sycl_hw.hpp [new file with mode: 0644]

index 0cb39e7927cd6fd9cfa7236ba59edaf6247eca63..5da439e94e09267cebf2f38829e12ce628acd527 100644 (file)
@@ -42,6 +42,16 @@ The following release is verified with good quality:
 
 ## News
 
+- 2025.2
+  - Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC).
+    |GPU|Base tokens/s|Increased tokens/s|Percent|
+    |-|-|-|-|
+    |PVC 1550|39|73|+87%|
+    |Flex 170|39|50|+28%|
+    |Arc770|42|55|+30%|
+    |MTL|13|16|+23%|
+    |ARL-H|14|17|+21%|
+
 - 2024.11
   - Use syclcompat to improve the performance on some platforms. This requires to use oneAPI 2025.0 or newer.
 
@@ -97,8 +107,8 @@ SYCL backend supports Intel GPU Family:
 | Intel Data Center Max Series  | Support | Max 1550, 1100                        |
 | Intel Data Center Flex Series | Support | Flex 170                              |
 | Intel Arc Series              | Support | Arc 770, 730M, Arc A750               |
-| Intel built-in Arc GPU        | Support | built-in Arc GPU in Meteor Lake       |
-| Intel iGPU                    | Support | iGPU in 13700k, i5-1250P, i7-1260P, i7-1165G7 |
+| Intel built-in Arc GPU        | Support | built-in Arc GPU in Meteor Lake, Arrow Lake    |
+| Intel iGPU                    | Support | iGPU in 13700k,iGPU in 13400, i5-1250P, i7-1260P, i7-1165G7 |
 
 *Notes:*
 
@@ -660,8 +670,10 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 | Name              | Value            | Function                                                                                                                  |
 |-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
 | GGML_SYCL_DEBUG   | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG                                                                             |
+| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
 | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
 
+
 ## Known Issues
 
 - `Split-mode:[row]` is not supported.
index 3b9ba3b2da4914c5e9f4582623cc0e6b1d00e7eb..8ce71e37000a10df905877f8f7993cc3158b14ba 100755 (executable)
@@ -3,7 +3,7 @@
 #  MIT license
 #  Copyright (C) 2024 Intel Corporation
 #  SPDX-License-Identifier: MIT
-
+export ONEAPI_DEVICE_SELECTOR="level_zero:0"
 source /opt/intel/oneapi/setvars.sh
 
 #export GGML_SYCL_DEBUG=1
@@ -13,7 +13,7 @@ source /opt/intel/oneapi/setvars.sh
 INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
 MODEL_FILE=models/llama-2-7b.Q4_0.gguf
 NGL=33
-CONEXT=8192
+CONEXT=4096
 
 if [ $# -gt 0 ]; then
     GGML_SYCL_DEVICE=$1
index 3579a311aac0784ab1971823132d84473455c132..3ad044432a27d623b12c695713b7f99b3eac2da8 100644 (file)
@@ -1,3 +1,5 @@
+message(STATUS  "GGML_SYCL_TARGET=${GGML_SYCL_TARGET}")
+
 if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$")
     message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD")
 endif()
index 022e7b7637bd3c67cb9f681bfb84cf259fd9c436..9069c47865f68868628a05bfa332bc10a9fa1e0a 100644 (file)
@@ -99,3 +99,20 @@ catch (sycl::exception const &exc) {
             << ", line:" << __LINE__ << std::endl;
   std::exit(1);
 }
+
+
+void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams) {
+    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+        for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
+            if (extra->events[i][is] != nullptr) {
+                SYCL_CHECK(CHECK_TRY_ERROR(dpct::destroy_event(extra->events[i][is])));
+            }
+        }
+        if (extra->data_device[i] != nullptr && streams.size()>0) {
+            ggml_sycl_set_device(i);
+            SYCL_CHECK(
+                CHECK_TRY_ERROR(sycl::free(extra->data_device[i], *(streams[i]))));
+        }
+    }
+    delete extra;
+}
index a5cab5065fc801e25a986770520fa41d7ca3e06f..7c503a1b10e4aabf63867e4edd985dec59e24014 100644 (file)
@@ -19,6 +19,9 @@
 #include "dpct/helper.hpp"
 #include "ggml-sycl.h"
 #include "presets.hpp"
+#include "sycl_hw.hpp"
+
+
 #if GGML_SYCL_DNNL
 #include "dnnl.hpp"
 #include "dnnl_sycl.hpp"
 void* ggml_sycl_host_malloc(size_t size);
 void ggml_sycl_host_free(void* ptr);
 
+
 extern int g_ggml_sycl_debug;
+extern int g_ggml_sycl_disable_optimize;
+
 #define GGML_SYCL_DEBUG(...)        \
   do {                              \
     if (g_ggml_sycl_debug)          \
@@ -182,18 +188,24 @@ inline dpct::err0 ggml_sycl_set_device(const int device) try {
 }
 
 //////////////////////
+struct optimize_feature {
+    bool reorder=false;
+};
+
+struct sycl_device_info {
+    int     cc;                 // compute capability
+    // int     nsm;                // number of streaming multiprocessors
+    // size_t  smpb;               // max. shared memory per block
+    bool    vmm;                // virtual memory support
+    size_t  total_vram;
+    sycl_hw_info hw_info;
+    optimize_feature opt_feature;
+};
+
 
 struct ggml_sycl_device_info {
     int device_count;
 
-    struct sycl_device_info {
-        int     cc;                 // compute capability
-        // int     nsm;                // number of streaming multiprocessors
-        // size_t  smpb;               // max. shared memory per block
-        bool    vmm;                // virtual memory support
-        size_t  total_vram;
-    };
-
     sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {};
 
     std::array<float, GGML_SYCL_MAX_DEVICES> default_tensor_split = {};
@@ -260,17 +272,46 @@ struct ggml_tensor_extra_gpu {
                                        // tensors
   dpct::event_ptr events[GGML_SYCL_MAX_DEVICES]
                         [GGML_SYCL_MAX_STREAMS]; // events for synchronizing multiple GPUs
+  optimize_feature optimized_feature;
 };
 
+void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams={});
+
+inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) {
+    optimize_feature opt;
+
+    opt.reorder =
+        (arch == syclex::architecture::intel_gpu_dg1 ||
+         arch == syclex::architecture::intel_gpu_acm_g10 ||
+         arch == syclex::architecture::intel_gpu_acm_g11 ||
+         arch == syclex::architecture::intel_gpu_acm_g12 ||
+         arch == syclex::architecture::intel_gpu_pvc ||
+         arch == syclex::architecture::intel_gpu_pvc_vg ||
+         arch == syclex::architecture::intel_gpu_mtl_u ||
+         arch == syclex::architecture::intel_gpu_mtl_s ||
+         arch == syclex::architecture::intel_gpu_mtl_h ||
+         arch == syclex::architecture::intel_gpu_arl_u ||
+         arch == syclex::architecture::intel_gpu_arl_s ||
+         arch == syclex::architecture::intel_gpu_arl_h ||
+         arch == syclex::architecture::intel_gpu_bmg_g21 ||
+         arch == syclex::architecture::intel_gpu_lnl_m
+        );
+
+    return opt;
+}
+
 struct ggml_backend_sycl_context {
     int device;
     std::string name;
+    optimize_feature opt_feature;
+    bool optimized_graph=false;
 
     queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } };
 
     explicit ggml_backend_sycl_context(int device) :
         device(device),
         name(GGML_SYCL_NAME + std::to_string(device)) {
+        opt_feature = ggml_sycl_info().devices[device].opt_feature;
     }
 
     queue_ptr stream(int device, int stream) {
@@ -680,5 +721,4 @@ bool gpu_has_xmx(sycl::device &dev);
 void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                  const ggml_tensor *src1, ggml_tensor *dst,
                                  const ggml_sycl_op_flatten_t op);
-
 #endif // GGML_SYCL_COMMON_HPP
index 05b01db2d8b05e2a1eee390e70e654750ee0865f..86b200e07030fa14ced39bd6eb9070f8881fcbcf 100644 (file)
@@ -125,6 +125,25 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
     }
 }
 
+template <typename dst_t>
+static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k,
+                                     dpct::queue_ptr stream) {
+
+    dpct::has_capability_or_fail(stream->get_device(),
+                                    {sycl::aspect::fp16});
+
+    int constexpr WARP_K = WARP_SIZE * QK4_0;
+    const int n_warp = (k + WARP_K - 1) / WARP_K;
+    GGML_ASSERT(k % 2 == 0);
+    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
+        sycl::range<3>(1, 1, WARP_SIZE),
+        sycl::range<3>(1, 1, WARP_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
+            dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
+        });
+
+}
+
 template <typename dst_t>
 static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
@@ -452,10 +471,15 @@ static void convert_unary_sycl(const void *__restrict__ vx,
     }
 }
 
-to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) {
+to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) {
     switch (type) {
         case GGML_TYPE_Q4_0:
-            return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
+            if (dst->src[0]->extra &&
+                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
+                return dequantize_row_q4_0_sycl_reorder;
+            } else {
+                return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
+            }
         case GGML_TYPE_Q4_1:
             return dequantize_block_sycl<QK4_1, QR4_1, dequantize_q4_1>;
         case GGML_TYPE_Q5_0:
@@ -499,10 +523,15 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) {
     }
 }
 
-to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
+to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
     switch (type) {
         case GGML_TYPE_Q4_0:
-            return dequantize_row_q4_0_sycl;
+            if (dst->src[0]->extra &&
+                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
+                return dequantize_row_q4_0_sycl_reorder;
+            } else {
+                return dequantize_row_q4_0_sycl;
+            }
         case GGML_TYPE_Q4_1:
             return dequantize_row_q4_1_sycl;
         case GGML_TYPE_Q5_0:
index 0ce2874aaaef9cb519211b0343e54fcb7dd06ec6..355dae22b40758d6695dcd5519d8ad640e83d7d0 100644 (file)
@@ -21,7 +21,7 @@ using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
 typedef to_t_sycl_t<float> to_fp32_sycl_t;
 typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
 
-to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type);
-to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type);
+to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst);
+to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst);
 
 #endif // GGML_SYCL_CONVERT_HPP
index b8304c3a274a2879760f9bc65645b567cc7883aa..651c2160d248fb06ee983c137483d77dd0ed6891 100644 (file)
@@ -16,6 +16,8 @@
 #include "common.hpp"
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
+typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs,
+                                            const int iqs, dfloat2 &v);
 
 static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
@@ -40,6 +42,29 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
 #endif // GGML_SYCL_F16
 }
 
+static __dpct_inline__ void dequantize_q4_0_reorder(const void *d_ptr, const int64_t ib, const void *qs,
+                                            const int iqs, dfloat2 &v) {
+    // const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const dfloat d = (const dfloat)*((const sycl::half*)d_ptr+ib);
+
+    const int vui = *((const uint8_t *)qs+iqs);
+
+    v.x() = vui & 0xF;
+    v.y() = vui >> 4;
+
+#ifdef GGML_SYCL_F16
+    // v = v - {8.0f, 8.0f};
+    // v = v * {d, d};
+    v.s0() = (v.s0() - 8.0f) * d;
+    v.s1() = (v.s1() - 8.0f) * d;
+
+#else
+    v.x() = (v.x() - 8.0f) * d;
+    v.y() = (v.y() - 8.0f) * d;
+#endif // GGML_SYCL_F16
+}
+
 static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q4_1 * x = (const block_q4_1 *) vx;
@@ -167,6 +192,36 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri
     }
 }
 
+template<typename dst_t>
+static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
+                                  const sycl::nd_item<3> &item_ct1) {
+
+    const int64_t i = item_ct1.get_group(2);
+    auto k=nb32;
+    // assume 32 threads
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int lane_ib = i * WARP_SIZE + tid;
+
+    if (lane_ib >= k / QK4_0) {
+        return;
+    }
+
+    dst_t * y_ptr = yy + lane_ib * QK4_0;
+
+    auto qs = (const uint8_t*)vx + lane_ib * QK4_0 / 2;
+    auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k / 2) + lane_ib;
+
+    const float d = float(*s_ptr);
+
+#pragma unroll
+    for (int l = 0; l < QK4_0 / 2; ++l) {
+        int vq = qs[l];
+        y_ptr[l + 0] = d * ((vq & 0xF) - 8);
+        y_ptr[l + 16] = d * ((vq >> 4) - 8);
+    }
+
+}
+
 template<typename dst_t>
 static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
                                   const sycl::nd_item<3> &item_ct1) {
index 0d097357ce79be724d681fadd3b060b326e7e2d7..99d3859de897926fe6382d4297f1217424895128 100644 (file)
@@ -3,7 +3,6 @@
 #include "dequantize.hpp"
 #include "presets.hpp"
 
-
 static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const sycl::half *x = (const sycl::half *)vx;
 
@@ -91,6 +90,112 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
     }
 }
 
+template <int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_reorder>
+static void dequantize_mul_mat_vec_reorder(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
+                                   const sycl::nd_item<3> &item_ct1) {
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int tid = item_ct1.get_local_id(2);
+
+
+    const int ncols_left = ncols % (QK4_0*WARP_SIZE);
+    const int ncols_align = ncols - ncols_left;
+    const int iter_stride = 8*2*GGML_SYCL_DMMV_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter //64/16=4, 512/16/2= 16
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+// partial sum for each thread
+#ifdef GGML_SYCL_F16
+    sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
+#else
+    float tmp = 0.0f;
+#endif // GGML_SYCL_F16
+    const char *d_ptr = (const char*)vx+ncols*nrows/2;
+    int i=0;
+    for (i = 0; i < ncols_align; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
+        const int iybs = col - col%qk; // y block start index
+
+// processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+            dfloat2 v;
+            dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);
+
+            // matrix multiplication
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+#ifdef GGML_SYCL_F16
+            dfloat2 t1{y[iybs + iqs + j / qr + 0],
+                        y[iybs + iqs + j / qr + y_offset]};
+
+            tmp += v * t1;
+#else
+            tmp += v.x() * y[iybs + iqs + j / qr + 0];
+            tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
+#endif // GGML_SYCL_F16
+        }
+    }
+
+    for (; i < ncols; i += iter_stride) {
+        if (tid>=ncols_left/QK4_0) continue;
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
+        const int iybs = col - col%qk; // y block start index
+
+// processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+            dfloat2 v;
+            dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);
+
+            // matrix multiplication
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+#ifdef GGML_SYCL_F16
+            dfloat2 t1{y[iybs + iqs + j / qr + 0],
+                        y[iybs + iqs + j / qr + y_offset]};
+
+            tmp += v * t1;
+#else
+            tmp += v.x() * y[iybs + iqs + j / qr + 0];
+            tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
+#endif // GGML_SYCL_F16
+        }
+    }
+
+    // sum up partial sums and write back result
+    const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
+    for (int mask = mask_start; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (tid == 0) {
+#ifdef GGML_SYCL_F16
+        dst[row] = tmp.x() + tmp.y();
+#else
+        dst[row] = tmp;
+#endif // GGML_SYCL_F16
+    }
+}
+
 static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
                                          float *dst, const int ncols,
                                          const int nrows,
@@ -759,6 +864,28 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
     }
 }
 
+static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y,
+                                             float *dst, const int ncols,
+                                             const int nrows,
+                                             dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
+                    vx, y, dst, ncols, nrows, item_ct1);
+            });
+    }
+}
+
 
 static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
                                              float *dst, const int ncols,
@@ -953,7 +1080,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
 
     const int64_t ne00 = src0->ne[0];
     const int64_t row_diff = row_high - row_low;
-
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
 #ifdef GGML_SYCL_F16
@@ -967,7 +1093,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
 
     if (src1_convert_f16) {
         src1_dfloat = src1_dfloat_a.alloc(ne00);
-        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
         GGML_ASSERT(to_fp16_sycl != nullptr);
         to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream);
     }
@@ -977,7 +1103,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            if ((ggml_tensor_extra_gpu*)dst->src[0]->extra &&
+                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
+                dequantize_mul_mat_vec_q4_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            } else {
+                dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            }
             break;
         case GGML_TYPE_Q4_1:
             dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
@@ -1020,4 +1151,5 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
     GGML_UNUSED(src1_ddq_i);
     GGML_UNUSED(src1_ncols);
     GGML_UNUSED(src1_padded_row_size);
+    GGML_UNUSED(ctx);
 }
diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp
new file mode 100644 (file)
index 0000000..51c19f6
--- /dev/null
@@ -0,0 +1,308 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include "ggml-impl.h"
+#include "common.hpp"
+#include "dequantize.hpp"
+#include "getrows.hpp"
+
+
+template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void k_get_rows(
+            const void * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12,
+            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
+
+    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+                     item_ct1.get_local_id(2)) *
+                    2;
+    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1);
+    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) /
+                    ne12;
+    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) %
+                    ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+
+    const int ib = i00/qk; // block index
+    const int iqs = (i00%qk)/qr; // quant index
+    const int iybs = i00 - i00%qk; // dst block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel(src0_row, ib, iqs, v);
+
+    dst_row[iybs + iqs + 0] = v.x();
+    dst_row[iybs + iqs + y_offset] = v.y();
+}
+
+template<int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_recorder, typename dst_t>
+static void k_get_rows_reorder(
+            const void * src0, const void *src0_dq, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12,
+            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
+
+    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+                     item_ct1.get_local_id(2)) *
+                    2;
+    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1);
+    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) /
+                    ne12;
+    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) %
+                    ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+    auto ncols = ne00;
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+
+    const int src0_off = i01 * ncols + i00;
+    const int ib = src0_off / QK4_0; // block index
+    const int iqs = (i00%qk)/qr; // x quant index
+    const int iybs = i00 - i00%qk; // dst block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel_recorder((const void *)src0_dq, ib, (const void *)src0, src0_off/2, v);
+
+    dst_row[iybs + iqs + 0] = v.x();
+    dst_row[iybs + iqs + y_offset] = v.y();
+
+    GGML_UNUSED(nb01);
+    GGML_UNUSED(nb02);
+    GGML_UNUSED(nb03);
+}
+
+template<typename src0_t, typename dst_t>
+static void k_get_rows_float(
+            const src0_t * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12,
+            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
+
+    const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+                    item_ct1.get_local_id(2);
+    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1);
+    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) /
+                    ne12;
+    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) %
+                    ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+    dst_row[i00] = src0_row[i00];
+}
+
+template <int qk, int qr, dequantize_kernel_t dq>
+static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                          ggml_tensor *dst, const void *src0_dd,
+                          const int32_t *src1_dd, float *dst_dd,
+                          queue_ptr stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
+    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
+    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    GGML_ASSERT(ne00 % 2 == 0);
+
+    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             k_get_rows<qk, qr, dq>(
+                                 src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
+                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
+                         });
+
+    GGML_UNUSED(dst);
+    GGML_UNUSED(ctx);
+}
+
+template <int qk, int qr, dequantize_kernel_t_reorder dq_reorder>
+static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                          ggml_tensor *dst, const void *src0_dd,
+                          const int32_t *src1_dd, float *dst_dd,
+                          queue_ptr stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
+    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
+    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    GGML_ASSERT(ne00 % 2 == 0);
+
+    const uint8_t* src0_q = (const uint8_t*)src0_dd;
+    const size_t ncols = ne00;
+    const size_t nrows = ne01;
+    const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2);
+    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                         [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
+                             k_get_rows_reorder<qk, qr, dq_reorder>(
+                                 src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2,
+                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
+                         });
+
+    GGML_UNUSED(dst);
+    GGML_UNUSED(ctx);
+}
+
+
+template <typename src0_t>
+static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                const ggml_tensor *src1, ggml_tensor *dst,
+                                const src0_t *src0_dd, const int32_t *src1_dd,
+                                float *dst_dd, queue_ptr stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
+    const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
+    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
+                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
+            });
+    }
+
+    GGML_UNUSED(dst);
+    GGML_UNUSED(ctx);
+}
+
+void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                  const ggml_tensor *src1, ggml_tensor *dst,
+                                  const float *src0_d, const float *src1_d,
+                                  float *dst_d, const queue_ptr &stream) {
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
+
+    const int32_t * src1_i32 = (const int32_t *) src1_d;
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
+                                src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_F32:
+            get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q4_0:
+            if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) {
+                get_rows_sycl_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            } else {
+                get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            }
+            break;
+        case GGML_TYPE_Q4_1:
+            get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        default:
+            // TODO: k-quants
+            GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
diff --git a/ggml/src/ggml-sycl/getrows.hpp b/ggml/src/ggml-sycl/getrows.hpp
new file mode 100644 (file)
index 0000000..cdbe6c2
--- /dev/null
@@ -0,0 +1,23 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_GETROWS_HPP
+#define GGML_SYCL_GETROWS_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+    const ggml_tensor *src1, ggml_tensor *dst,
+    const float *src0_d, const float *src1_d,
+    float *dst_d, const queue_ptr &stream);
+
+#endif // GGML_SYCL_GETROWS_HPP
index d4c97ad17b8a0d072e13e37883897e25ca70e630..792e0569ca6cf4d2012ad352e769a26095a34aa0 100644 (file)
 #include "ggml-sycl/backend.hpp"
 #include "ggml-sycl/presets.hpp"
 #include "ggml-sycl/gemm.hpp"
+#include "ggml-sycl/sycl_hw.hpp"
+#include "ggml-sycl/getrows.hpp"
 
 static bool g_sycl_loaded = false;
 int g_ggml_sycl_debug = 0;
+int g_ggml_sycl_disable_optimize = 0;
 
 static ggml_sycl_device_info ggml_sycl_init() {
     ggml_sycl_device_info info = {};
@@ -64,14 +67,18 @@ static ggml_sycl_device_info ggml_sycl_init() {
     for (int i = 0; i < info.device_count; ++i) {
         info.devices[i].vmm = 0;
         dpct::device_info prop;
+        sycl::device device = dpct::dev_mgr::instance().get_device(i);
+
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
-            prop, dpct::dev_mgr::instance().get_device(i))));
+            prop, device)));
 
         info.default_tensor_split[i] = total_vram;
         total_vram += prop.get_global_mem_size();
 
         info.devices[i].cc =
             100 * prop.get_major_version() + 10 * prop.get_minor_version();
+        info.devices[i].hw_info = get_device_hw_info(&device);
+        info.devices[i].opt_feature = check_gpu_optimize_feature(info.devices[i].hw_info.arch);
 
         info.max_work_group_sizes[i] = prop.get_max_work_group_size();
     }
@@ -110,6 +117,27 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
             global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
 }
 
+void print_device_opt_feature(int device_count) {
+    GGML_LOG_INFO("SYCL Optimization Feature:\n");
+    GGML_LOG_INFO(
+        "|ID|        Device Type|Reorder|\n");
+    GGML_LOG_INFO(
+        "|--|-------------------|-------|\n");
+    std::map<std::string, size_t> DeviceNums;
+    for (int id = 0; id < device_count; ++id) {
+      sycl::device device = dpct::dev_mgr::instance().get_device(id);
+      std::string backend_type = get_device_backend_and_type(device);
+      int type_id = DeviceNums[backend_type]++;
+      std::stringstream device_type;
+      device_type << "[" << backend_type << ":" << std::to_string(type_id)
+                  << "]";
+      std::string device_type_s = device_type.str();
+      device_type_s = std::regex_replace(device_type_s, std::regex("ext_oneapi_"), "");
+      GGML_LOG_INFO("|%2d|%19s|%7s|\n", id, device_type_s.c_str(),
+        ggml_sycl_info().devices[id].opt_feature.reorder ? "Y": "N");
+    }
+
+}
 void ggml_backend_sycl_print_sycl_devices() {
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
     int device_count = dpct::dev_mgr::instance().device_count();
@@ -138,6 +166,8 @@ void ggml_backend_sycl_print_sycl_devices() {
                   << "]";
       print_device_detail(id, device, device_type.str());
     }
+
+    print_device_opt_feature(device_count);
 }
 
 static inline int get_sycl_env(const char *env_name, int default_val) {
@@ -159,17 +189,21 @@ static void ggml_check_sycl() try {
 
     if (!initialized) {
         g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
+        g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
         GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
-        GGML_LOG_INFO("GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
+        GGML_LOG_INFO("Running with Environment Variables:\n");
+        GGML_LOG_INFO("  GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+        GGML_LOG_INFO("Build with Macros:\n");
 #if defined(GGML_SYCL_FORCE_MMQ)
-        GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ:   yes\n");
+        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: yes\n");
 #else
-        GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ:   no\n");
+        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: no\n");
 #endif
 #if defined(GGML_SYCL_F16)
-        GGML_LOG_INFO("GGML_SYCL_F16: yes\n");
+        GGML_LOG_INFO("  GGML_SYCL_F16: yes\n");
 #else
-        GGML_LOG_INFO("GGML_SYCL_F16: no\n");
+        GGML_LOG_INFO("  GGML_SYCL_F16: no\n");
 #endif
 
 /* NOT REMOVE, keep it for next optimize for XMX.
@@ -241,19 +275,27 @@ struct ggml_backend_sycl_buffer_context {
     void * dev_ptr = nullptr;
     queue_ptr stream;
     std::string name;
+    optimize_feature opt_feature;
+    std::vector<ggml_tensor_extra_gpu *> tensor_extras;
 
-     ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
+    ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
         device(device), dev_ptr(dev_ptr), stream(stream) {
             check_allow_gpu_index(device);
             name = (GGML_SYCL_NAME + std::to_string(device));
+            opt_feature = ggml_sycl_info().devices[device].opt_feature;
         }
 
-
     ~ggml_backend_sycl_buffer_context() {
         if (dev_ptr != nullptr) {
             ggml_sycl_set_device(device);
             SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
         }
+
+        //release extra used by tensors
+        for (ggml_tensor_extra_gpu * extra : tensor_extras) {
+            release_extra_gpu(extra);
+        }
+
     }
 };
 
@@ -291,6 +333,9 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         return;
     }
 
+    ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
+    tensor->extra = extra;
+    ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
 
     if (ggml_is_quantized(tensor->type)) {
         // initialize padding to 0 to avoid possible NaN values
@@ -316,7 +361,6 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                 size_t size) try {
 
     ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-
     ggml_sycl_set_device(ctx->device);
     auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
     SYCL_CHECK(
@@ -660,32 +704,7 @@ struct ggml_backend_sycl_split_buffer_type_context {
 struct ggml_backend_sycl_split_buffer_context {
     ~ggml_backend_sycl_split_buffer_context() try {
         for (ggml_tensor_extra_gpu * extra : tensor_extras) {
-            for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-                for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
-                    if (extra->events[i][is] != nullptr) {
-                        /*
-                        DPCT1009:206: SYCL uses exceptions to report errors and
-                        does not use the error codes. The original code was
-                        commented out and a warning string was inserted. You
-                        need to rewrite this code.
-                        */
-                        SYCL_CHECK(CHECK_TRY_ERROR(
-                            dpct::destroy_event(extra->events[i][is])));
-                    }
-                }
-                if (extra->data_device[i] != nullptr) {
-                    /*
-                    DPCT1009:207: SYCL uses exceptions to report errors and does
-                    not use the error codes. The original code was commented out
-                    and a warning string was inserted. You need to rewrite this
-                    code.
-                    */
-                    ggml_sycl_set_device(i);
-                    SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(
-                        extra->data_device[i], *(streams[i]))));
-                }
-            }
-            delete extra;
+            release_extra_gpu(extra, streams);
         }
     }
     catch (sycl::exception const &exc) {
@@ -723,7 +742,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
     ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
 
     ctx->tensor_extras.push_back(extra);
-        ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
+    ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
 
     for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
         int64_t row_low, row_high;
@@ -1337,83 +1356,6 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
     reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
 }
 
-template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void k_get_rows(
-            const void * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12,
-            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
-    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                     item_ct1.get_local_id(2)) *
-                    2;
-    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) /
-                    ne12;
-    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) %
-                    ne12;
-
-    if (i00 >= ne00) {
-        return;
-    }
-
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
-
-    const int ib = i00/qk; // block index
-    const int iqs = (i00%qk)/qr; // quant index
-    const int iybs = i00 - i00%qk; // dst block start index
-    const int y_offset = qr == 1 ? 1 : qk/2;
-
-    // dequantize
-    dfloat2 v;
-    dequantize_kernel(src0_row, ib, iqs, v);
-
-    dst_row[iybs + iqs + 0] = v.x();
-    dst_row[iybs + iqs + y_offset] = v.y();
-}
-
-template<typename src0_t, typename dst_t>
-static void k_get_rows_float(
-            const src0_t * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12,
-            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
-    const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                    item_ct1.get_local_id(2);
-    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) /
-                    ne12;
-    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) %
-                    ne12;
-
-    if (i00 >= ne00) {
-        return;
-    }
-
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
-
-    dst_row[i00] = src0_row[i00];
-}
-
 static void mul_mat_p021_f16_f32(
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1896,81 +1838,6 @@ static  void pool2d_nchw_kernel(
         o_ptr[cur_oh * ow + cur_ow] = res;
 }
 
-template <int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                          ggml_tensor *dst, const void *src0_dd,
-                          const int32_t *src1_dd, float *dst_dd,
-                          queue_ptr stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
-    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
-    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
-    // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
-
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
-
-    GGML_ASSERT(ne00 % 2 == 0);
-
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             k_get_rows<qk, qr, dq>(
-                                 src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
-                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
-                         });
-
-    GGML_UNUSED(dst);
-    GGML_UNUSED(ctx);
-}
-
-template <typename src0_t>
-static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const src0_t *src0_dd, const int32_t *src1_dd,
-                                float *dst_dd, queue_ptr stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
-    const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
-    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
-    // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
-
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
-
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
-                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
-            });
-    }
-
-    GGML_UNUSED(dst);
-    GGML_UNUSED(ctx);
-}
-
 static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
                                    const int ky, const int kx_padded,
                                    queue_ptr stream) {
@@ -2494,52 +2361,6 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                  const float *src0_d, const float *src1_d,
-                                  float *dst_d, const queue_ptr &stream) {
-
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
-    const int32_t * src1_i32 = (const int32_t *) src1_d;
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
-                                src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_F32:
-            get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q4_0:
-            get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q4_1:
-            get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q5_0:
-            get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q5_1:
-            get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q8_0:
-            get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        default:
-            // TODO: k-quants
-            GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
-            GGML_ABORT("fatal error");
-            break;
-    }
-}
-
-
 static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                 const ggml_tensor *src1, ggml_tensor *dst,
                                 const float *src0_d, const float *src1_d,
@@ -2589,11 +2410,10 @@ inline void ggml_sycl_op_mul_mat_sycl(
     if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
         use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] &&
         dst->op_params[0] == GGML_PREC_DEFAULT) {
-
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n");
         ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
         if (src0->type != GGML_TYPE_F16) {
-            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type);
+            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst);
             GGML_ASSERT(to_fp16_sycl != nullptr);
             size_t ne = row_diff*ne00;
             src0_as_f16.alloc(ne);
@@ -2605,7 +2425,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
         ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
         if (src1->type != GGML_TYPE_F16) {
-            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
             GGML_ASSERT(to_fp16_sycl != nullptr);
             size_t ne = src1_ncols*ne10;
             src1_as_f16.alloc(ne);
@@ -2626,13 +2446,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
             src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
             dst_f16.get(), dpct::library_data_t::real_half, ldc,
             dpct::library_data_t::real_half)));
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
 #else
         auto dnnl_stream = ctx.stream_dnnl(stream);
         DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
             src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
 #endif
     }
@@ -2641,13 +2461,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
         ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
         ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
         if (src0->type != GGML_TYPE_F32) {
-            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type);
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst);
             GGML_ASSERT(to_fp32_sycl != nullptr);
             src0_ddq_as_f32.alloc(row_diff*ne00);
             to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
         }
         if (src1->type != GGML_TYPE_F32) {
-            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type);
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst);
             GGML_ASSERT(to_fp32_sycl != nullptr);
             src1_ddq_as_f32.alloc(src1_ncols*ne10);
             to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
@@ -3085,7 +2905,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
     for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
         const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
         const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
-
         for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
             if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
                 continue;
@@ -3393,7 +3212,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
     // convert src1 to fp16
     ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
     if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
         const int64_t ne_src1 = ggml_nelements(src1);
         src1_f16_alloc.alloc(ne_src1);
         GGML_ASSERT(to_fp16_sycl != nullptr);
@@ -3509,6 +3328,7 @@ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
 }
 
 static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+
     const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
     int64_t min_compute_capability = INT_MAX;
 
@@ -3570,6 +3390,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
+        // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
     } else if (use_mul_mat_vec_q) {
         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
     } else if (use_mul_mat_q) {
@@ -4251,10 +4072,72 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
+void reorder_qw(char *data_device, const int ncols, const int nrows,
+                size_t size, size_t offset, dpct::queue_ptr stream) {
+    auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
+    SYCL_CHECK(
+        CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
+            .wait()));
+    GGML_ASSERT((size % sizeof(block_q4_0) == 0));
+    GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
+    int offset_blks = offset / sizeof(block_q4_0);
+    auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;;
+    auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
+
+    stream->parallel_for(
+        size / sizeof(block_q4_0),
+            [=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+            const block_q4_0* x = (const block_q4_0*)tmp_buf;
+            const int ib = i;
+
+            for (int j = 0; j < QK4_0/2; j ++)
+            {
+                *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
+            }
+            *(d_ptr + ib) = x[ib].d;
+        });
+
+    sycl::free(tmp_buf, *stream);
+}
+
+void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
+    char*data_device = (char*)src0->data;
+    size_t ncols = src0->ne[0];
+    size_t nrows = src0->ne[1];
+    size_t size = ggml_nbytes(src0);
+
+    reorder_qw(data_device, ncols, nrows, size, 0, stream);
+}
+
+void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
+    ggml_tensor *src0 = dst->src[0];
+    ggml_tensor *src1 = dst->src[1];
+
+    if (dst->op == GGML_OP_MUL_MAT && src0->type == GGML_TYPE_Q4_0 &&
+        src1->ne[2]==1 && src1->ne[3]==1) {
+        reorder_qw(src0, stream);
+        ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
+        GGML_ASSERT(extra);
+        extra->optimized_feature.reorder = true; //used to decode/dequan in next steps.
+    }
+}
+
+void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
+    dpct::queue_ptr stream = ctx->stream();
+    if (ctx->optimized_graph) {
+       return;
+    }
+    ctx->optimized_graph = true;
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream);
+    }
+}
 static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
     ggml_sycl_set_main_device(sycl_ctx->device);
 
+    if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx);
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
diff --git a/ggml/src/ggml-sycl/sycl_hw.cpp b/ggml/src/ggml-sycl/sycl_hw.cpp
new file mode 100644 (file)
index 0000000..da121ff
--- /dev/null
@@ -0,0 +1,13 @@
+#include "sycl_hw.hpp"
+
+
+sycl_hw_info get_device_hw_info(sycl::device *device_ptr) {
+  sycl_hw_info res;
+  int32_t id = device_ptr->get_info<sycl::ext::intel::info::device::device_id>();
+  res.device_id = id;
+
+  syclex::architecture arch = device_ptr->get_info<syclex::info::device::architecture>();
+  res.arch = arch;
+
+  return res;
+}
diff --git a/ggml/src/ggml-sycl/sycl_hw.hpp b/ggml/src/ggml-sycl/sycl_hw.hpp
new file mode 100644 (file)
index 0000000..bf68945
--- /dev/null
@@ -0,0 +1,23 @@
+#ifndef SYCL_HW_HPP
+#define SYCL_HW_HPP
+
+#include <algorithm>
+#include <stdio.h>
+#include <vector>
+#include <map>
+
+#include <sycl/sycl.hpp>
+
+namespace syclex = sycl::ext::oneapi::experimental;
+
+struct sycl_hw_info {
+  syclex::architecture arch;
+  int32_t device_id;
+};
+
+bool is_in_vector(std::vector<int> &vec, int item);
+
+sycl_hw_info get_device_hw_info(sycl::device *device_ptr);
+
+
+#endif // SYCL_HW_HPP