]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : reorganize source code + improve CMake (#865)
authorGeorgi Gerganov <redacted>
Wed, 26 Jun 2024 16:33:53 +0000 (19:33 +0300)
committerGitHub <redacted>
Wed, 26 Jun 2024 16:33:53 +0000 (19:33 +0300)
* scripts : update sync [no ci]

* ggml : move headers one up [no ci]

* files : reorganize + update CMake

ggml-ci

* cmake : build normal ggml library

ggml-ci

* cmake : link math library to test + remove ci for code cov

ggml-ci

* files : move public headers to include

ggml-ci

104 files changed:
.github/workflows/ci.yml
.gitignore
CMakeLists.txt
cmake/FindSIMD.cmake [new file with mode: 0644]
examples/gpt-2/main-alloc.cpp
examples/gpt-2/main-backend.cpp
examples/gpt-2/main-batched.cpp
examples/gpt-2/main-ctx.cpp
examples/gpt-2/main-sched.cpp
examples/gpt-2/quantize.cpp
examples/gpt-j/README.md
examples/gpt-j/main.cpp
examples/gpt-j/quantize.cpp
examples/magika/main.cpp
examples/mnist/main-cnn.cpp
examples/mnist/main-cpu.cpp
examples/mnist/main-mtl.cpp
examples/mnist/main-mtl.m
examples/mnist/main.cpp
examples/python/README.md
examples/simple/simple-backend.cpp
examples/yolo/yolov3-tiny.cpp
include/ggml-alloc.h [new file with mode: 0644]
include/ggml-backend.h [new file with mode: 0644]
include/ggml-blas.h [new file with mode: 0644]
include/ggml-cuda.h [new file with mode: 0644]
include/ggml-kompute.h [new file with mode: 0644]
include/ggml-metal.h [new file with mode: 0644]
include/ggml-rpc.h [new file with mode: 0644]
include/ggml-sycl.h [new file with mode: 0644]
include/ggml-vulkan.h [new file with mode: 0644]
include/ggml.h [new file with mode: 0644]
include/ggml/ggml-alloc.h [deleted file]
include/ggml/ggml-backend.h [deleted file]
include/ggml/ggml.h [deleted file]
scripts/sync-llama-am.sh
scripts/sync-llama.sh
scripts/sync-whisper-am.sh
scripts/sync-whisper.sh
spm-headers/ggml-alloc.h
spm-headers/ggml-backend.h
spm-headers/ggml-metal.h [new symlink]
spm-headers/ggml.h
src/CMakeLists.txt
src/ggml-backend.c
src/ggml-blas.h [deleted file]
src/ggml-cuda.cu
src/ggml-cuda.h [deleted file]
src/ggml-cuda/common.cuh
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/mma.cuh
src/ggml-cuda/mmq.cu
src/ggml-cuda/mmq.cuh
src/ggml-cuda/mmvq.cuh
src/ggml-cuda/unary.cu
src/ggml-cuda/unary.cuh
src/ggml-impl.h
src/ggml-kompute.h [deleted file]
src/ggml-metal.h [deleted file]
src/ggml-metal.m
src/ggml-quants.c
src/ggml-rpc.cpp
src/ggml-rpc.h [deleted file]
src/ggml-sycl.cpp
src/ggml-sycl.h [deleted file]
src/ggml-sycl/backend.hpp
src/ggml-sycl/common.hpp
src/ggml-sycl/convert.cpp [new file with mode: 0644]
src/ggml-sycl/convert.hpp [new file with mode: 0644]
src/ggml-sycl/dequantize.hpp [new file with mode: 0644]
src/ggml-sycl/dmmv.cpp [new file with mode: 0644]
src/ggml-sycl/dmmv.hpp [new file with mode: 0644]
src/ggml-sycl/dpct/helper.hpp
src/ggml-sycl/mmq.cpp [new file with mode: 0644]
src/ggml-sycl/mmq.hpp [new file with mode: 0644]
src/ggml-sycl/mmvq.cpp [new file with mode: 0644]
src/ggml-sycl/mmvq.hpp [new file with mode: 0644]
src/ggml-sycl/presets.hpp
src/ggml-sycl/vecdotq.hpp [new file with mode: 0644]
src/ggml-vulkan.cpp
src/ggml-vulkan.h [deleted file]
src/ggml.c
tests/CMakeLists.txt
tests/test-arange.cpp
tests/test-backend-buffer.cpp
tests/test-backend-ops.cpp
tests/test-conv-transpose.c
tests/test-conv1d.cpp
tests/test-conv2d.cpp
tests/test-customop.c
tests/test-dup.c
tests/test-mul-mat.cpp
tests/test-mul-mat0.c
tests/test-pool.c
tests/test-rel-pos.c
tests/test-timestep_embedding.cpp
tests/test0.c
tests/test0.zig
tests/test1.c
tests/test1.zig
tests/test2.c
tests/test2.zig
tests/test3.c
tests/test3.zig

index 817e669530be8bc0caf21dd30b88be4b2a2ad1d9..2276ec7d73ff926f6a23f039a591747ffa77b4e9 100644 (file)
@@ -29,7 +29,7 @@ jobs:
 
     - name: Configure CMake
       working-directory: ./build
-      run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_TEST_COVERAGE=ON -DGGML_CLBLAST=ON ..
+      run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_CLBLAST=ON ..
 
     - name: Build
       working-directory: ./build
@@ -39,13 +39,6 @@ jobs:
       working-directory: ./build
       run: ctest --verbose --timeout 900
 
-    - name: Test Coverage
-      working-directory: ./build
-      run: |
-        llvm-profdata merge -sparse tests/*.profraw -o ggml.profdata
-        llvm-cov      report ./bin/test-grad0 -instr-profile=ggml.profdata
-        llvm-cov      report ./bin/test-opt   -instr-profile=ggml.profdata
-
   test-macos-metal:
     runs-on: macos-13
     env:
@@ -61,7 +54,7 @@ jobs:
 
     - name: Configure CMake
       working-directory: ./build
-      run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_TEST_COVERAGE=ON -DGGML_METAL=OFF ..
+      run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_METAL=OFF ..
 
     - name: Build
       working-directory: ./build
@@ -71,13 +64,6 @@ jobs:
       working-directory: ./build
       run: ctest --verbose --timeout 900
 
-    - name: Test Coverage
-      working-directory: ./build
-      run: |
-        xcrun llvm-profdata merge -sparse tests/*.profraw -o ggml.profdata
-        xcrun llvm-cov      report ./bin/test-grad0 -instr-profile=ggml.profdata
-        xcrun llvm-cov      report ./bin/test-opt   -instr-profile=ggml.profdata
-
   build:
 
     strategy:
@@ -112,7 +98,7 @@ jobs:
 
     - name: Configure CMake
       working-directory: ./build
-      run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_TEST_COVERAGE=ON -DGGML_METAL=OFF ..
+      run: cmake -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DGGML_METAL=OFF ..
 
     - name: Build
       working-directory: ./build
@@ -121,19 +107,3 @@ jobs:
     - name: Test
       working-directory: ./build
       run: ctest --verbose --timeout 900
-
-    - name: Test Coverage for Ubuntu
-      if: matrix.os == 'ubuntu-latest'
-      working-directory: ./build
-      run: |
-        llvm-profdata merge -sparse tests/*.profraw -o ggml.profdata
-        llvm-cov      report ./bin/test-grad0 -instr-profile=ggml.profdata
-        llvm-cov      report ./bin/test-opt   -instr-profile=ggml.profdata
-
-    - name: Test Coverage for MacOS
-      if: matrix.os == 'macos-latest'
-      working-directory: ./build
-      run: |
-        xcrun llvm-profdata merge -sparse tests/*.profraw -o ggml.profdata
-        xcrun llvm-cov      report ./bin/test-grad0 -instr-profile=ggml.profdata
-        xcrun llvm-cov      report ./bin/test-opt   -instr-profile=ggml.profdata
index dd2ca4b975e33c29affd5d64e0ceb9e3c5d740e3..6b6c6a346d43fd226b8ca2362c8aee222719523e 100644 (file)
@@ -1,13 +1,5 @@
 build/
-build-blas/
-build-debug/
-build-release/
-build-sanitize-addr/
-build-sanitize-thread/
-build-cov/
-build-ci-debug/
-build-ci-release/
-build-cublas/
+build-*/
 out/
 tmp/
 models/
@@ -19,6 +11,7 @@ CMakeSettings.json
 .vscode/
 .clangd
 
+.venv/
 .exrc
 .cache
 .DS_Store
index f8f418bfa388c0f9e4aa404d74468f75dbfda051..f3763f7eb9fef117888eefac70c49877ec330d7e 100644 (file)
@@ -1,20 +1,29 @@
-cmake_minimum_required (VERSION 3.12)
-project(ggml VERSION 0.1.0)
+cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
+project("ggml" C CXX)
+include(CheckIncludeFileCXX)
 
-set(CMAKE_EXPORT_COMPILE_COMMANDS "on")
-set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
-set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 
-if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
+if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
+    set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
+    set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
+endif()
+
+if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
     set(GGML_STANDALONE ON)
-    include(cmake/GitVars.cmake)
-    include(cmake/BuildTypes.cmake)
+
+    set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+
+    # configure project version
+    # TODO
 else()
     set(GGML_STANDALONE OFF)
 endif()
 
 if (EMSCRIPTEN)
     set(BUILD_SHARED_LIBS_DEFAULT OFF)
+
+    option(GGML_WASM_SINGLE_FILE "ggml: embed WASM inside the generated ggml.js" ON)
 else()
     if (MINGW)
         set(BUILD_SHARED_LIBS_DEFAULT OFF)
@@ -23,7 +32,13 @@ else()
     endif()
 endif()
 
-# options
+option(BUILD_SHARED_LIBS "ggml: build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})
+
+#
+# option list
+#
+
+# TODO: mark all options as advanced when not GGML_STANDALONE
 
 if (APPLE)
     set(GGML_METAL_DEFAULT ON)
@@ -35,179 +50,132 @@ else()
     set(GGML_BLAS_VENDOR_DEFAULT "Generic")
 endif()
 
-option(BUILD_SHARED_LIBS            "ggml: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
-
-option(GGML_ALL_WARNINGS            "ggml: enable all compiler warnings"                   ON)
-option(GGML_ALL_WARNINGS_3RD_PARTY  "ggml: enable all compiler warnings in 3rd party libs" OFF)
-
-option(GGML_SANITIZE_THREAD         "ggml: enable thread sanitizer"    OFF)
-option(GGML_SANITIZE_ADDRESS        "ggml: enable address sanitizer"   OFF)
-option(GGML_SANITIZE_UNDEFINED      "ggml: enable undefined sanitizer" OFF)
+# general
+option(GGML_STATIC "ggml: static link libraries"         OFF)
+option(GGML_NATIVE "ggml: enable -march=native flag"     ON)
+option(GGML_LTO    "ggml: enable link time optimization" OFF)
+option(GGML_CCACHE "ggml: use ccache if available"       ON)
 
-option(GGML_BUILD_TESTS             "ggml: build tests"    ${GGML_STANDALONE})
-option(GGML_BUILD_EXAMPLES          "ggml: build examples" ${GGML_STANDALONE})
+# debug
+option(GGML_ALL_WARNINGS           "ggml: enable all compiler warnings"                   ON)
+option(GGML_ALL_WARNINGS_3RD_PARTY "ggml: enable all compiler warnings in 3rd party libs" OFF)
+option(GGML_GPROF                  "ggml: enable gprof"                                   OFF)
 
-option(GGML_TEST_COVERAGE           "ggml: enable test coverage" OFF)
-
-option(GGML_PERF                    "ggml: enable perf timings"               OFF)
-option(GGML_NO_ACCELERATE           "ggml: disable Accelerate framework"      OFF)
-option(GGML_BLAS                    "ggml: use BLAS"                          ${GGML_BLAS_DEFAULT})
-set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
-                                    "ggml: BLAS library vendor")
-option(GGML_HIPBLAS                 "ggml: use hipBLAS"                       OFF)
-option(GGML_CUDA                    "ggml: use CUDA"                          OFF)
-option(GGML_CUBLAS                  "ggml: use CUDA (deprecated)"             OFF)
-option(GGML_METAL                   "ggml: use Metal"                         ${GGML_METAL_DEFAULT})
-option(GGML_METAL_NDEBUG            "ggml: disable Metal debugging"           OFF)
-option(GGML_METAL_SHADER_DEBUG      "ggml: compile Metal with -fno-fast-math" OFF)
-option(GGML_METAL_EMBED_LIBRARY     "ggml: embed Metal library"               OFF)
-option(GGML_RPC                     "ggml: use RPC"                           OFF)
-option(GGML_VULKAN                  "ggml: use Vulkan"                        OFF)
+# build
+option(GGML_FATAL_WARNINGS    "ggml: enable -Werror flag"    OFF)
 
-option(GGML_CUDA_FORCE_DMMV                 "ggml: use dmmv instead of mmvq CUDA kernels"     OFF)
-option(GGML_CUDA_FORCE_MMQ                  "ggml: use mmq kernels instead of cuBLAS"         OFF)
-set(GGML_CUDA_DMMV_X      "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels")
-set(GGML_CUDA_MMV_Y        "1" CACHE STRING "ggml: y block size for mmv CUDA kernels")
-option(GGML_CUDA_F16                        "ggml: use 16 bit floats for some calculations"   OFF)
-set(GGML_CUDA_KQUANTS_ITER "2" CACHE STRING "ggml: iters./thread per block for Q2_K/Q6_K")
-set(GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
-                                            "ggml: max. batch size for using peer access")
 # sanitizers
+option(GGML_SANITIZE_THREAD    "ggml: enable thread sanitizer"    OFF)
+option(GGML_SANITIZE_ADDRESS   "ggml: enable address sanitizer"   OFF)
+option(GGML_SANITIZE_UNDEFINED "ggml: enable undefined sanitizer" OFF)
 
-if (GGML_SANITIZE_THREAD)
-    set(CMAKE_C_FLAGS   "${CMAKE_C_FLAGS}   -fsanitize=thread")
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread")
-endif()
-
-if (GGML_SANITIZE_ADDRESS)
-    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}     -fsanitize=address -fno-omit-frame-pointer")
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer")
+# instruction set specific
+if (GGML_NATIVE)
+    set(INS_ENB OFF)
+else()
+    set(INS_ENB ON)
 endif()
 
-if (GGML_SANITIZE_UNDEFINED)
-    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}     -fsanitize=undefined")
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=undefined")
-endif()
+option(GGML_CPU_HBM     "ggml: use memkind for CPU HBM" OFF)
 
-# instruction set specific
-option(GGML_AVX                     "ggml: enable AVX"                                     ON)
-option(GGML_AVX2                    "ggml: enable AVX2"                                    ON)
-option(GGML_AVX512                  "ggml: enable AVX512"                                  OFF)
-option(GGML_AVX512_VBMI             "ggml: enable AVX512-VBMI"                             OFF)
-option(GGML_AVX512_VNNI             "ggml: enable AVX512-VNNI"                             OFF)
-option(GGML_FMA                     "ggml: enable FMA"                                     ON)
-# in MSVC F16C is implied with AVX2/AVX512
+option(GGML_AVX         "ggml: enable AVX"              ${INS_ENB})
+option(GGML_AVX2        "ggml: enable AVX2"             ${INS_ENB})
+option(GGML_AVX512      "ggml: enable AVX512"           OFF)
+option(GGML_AVX512_VBMI "ggml: enable AVX512-VBMI"      OFF)
+option(GGML_AVX512_VNNI "ggml: enable AVX512-VNNI"      OFF)
+option(GGML_AVX512_BF16 "ggml: enable AVX512-BF16"      OFF)
+option(GGML_FMA         "ggml: enable FMA"              ${INS_ENB})
 if (NOT MSVC)
-    option(GGML_F16C                "ggml: enable F16C"                                    ON)
+    option(GGML_F16C    "ggml: enable F16C"             ${INS_ENB}) # in MSVC F16C is implied with AVX2/AVX512
 endif()
+option(GGML_LASX        "ggml: enable lasx"             ON)
+option(GGML_LSX         "ggml: enable lsx"              ON)
+option(GGML_SVE         "ggml: enable SVE"              OFF)
 
-#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffast-math")
-#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
-#set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=native")
+if (WIN32)
+    set(GGML_WIN_VER "0x602" CACHE STRING "ggml: Windows Version")
+endif()
 
-# warning flags
+# ggml core
+set(GGML_SCHED_MAX_COPIES  "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
 
-if (GGML_ALL_WARNINGS)
-    if (NOT MSVC)
-        set(c_flags   -Wall -Wpedantic -Wformat=2 -Wno-unused -Wstrict-prototypes)
-        set(cxx_flags -Wall -Wpedantic -Wformat=2)
-    else()
-        # todo : windows
-    endif()
-
-    add_compile_options(
-        "$<$<COMPILE_LANGUAGE:C>:${c_flags}>"
-        "$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags}>"
-    )
-endif()
+# 3rd party libs / backends
+option(GGML_ACCELERATE                      "ggml: enable Accelerate framework"               ON)
+option(GGML_BLAS                            "ggml: use BLAS"                                  ${GGML_BLAS_DEFAULT})
+set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
+                                            "ggml: BLAS library vendor")
+option(GGML_LLAMAFILE                       "ggml: use ggml SGEMM"                            OFF)
 
-if (NOT MSVC)
-    # TODO: temporary disabled until we figure out ggml-metal.m
-    #add_compile_options(
-    #    "$<$<COMPILE_LANGUAGE:C>:-Werror=vla>"
-    #    "$<$<COMPILE_LANGUAGE:CXX>:-Werror=vla>"
-    #    "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler;-Werror=vla>"
-    #)
-endif()
+option(GGML_CUDA                            "ggml: use CUDA"                                  OFF)
+option(GGML_CUDA_FORCE_DMMV                 "ggml: use dmmv instead of mmvq CUDA kernels"     OFF)
+option(GGML_CUDA_FORCE_MMQ                  "ggml: use mmq kernels instead of cuBLAS"         OFF)
+set   (GGML_CUDA_DMMV_X   "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels")
+set   (GGML_CUDA_MMV_Y     "1" CACHE STRING "ggml: y block size for mmv CUDA kernels")
+option(GGML_CUDA_F16                        "ggml: use 16 bit floats for some calculations"   OFF)
+set   (GGML_CUDA_KQUANTS_ITER "2" CACHE STRING
+                                            "ggml: iters./thread per block for Q2_K/Q6_K")
+set   (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
+                                            "ggml: max. batch size for using peer access")
+option(GGML_CUDA_NO_PEER_COPY               "ggml: do not use peer to peer copies"            OFF)
+option(GGML_CUDA_NO_VMM                     "ggml: do not try to use CUDA VMM"                OFF)
+option(GGML_CUDA_FA_ALL_QUANTS              "ggml: compile all quants for FlashAttention"     OFF)
+
+option(GGML_CURL                            "ggml: use libcurl to download model from an URL" OFF)
+option(GGML_HIPBLAS                         "ggml: use hipBLAS"                               OFF)
+option(GGML_HIP_UMA                         "ggml: use HIP unified memory architecture"       OFF)
+option(GGML_VULKAN                          "ggml: use Vulkan"                                OFF)
+option(GGML_VULKAN_CHECK_RESULTS            "ggml: run Vulkan op checks"                      OFF)
+option(GGML_VULKAN_DEBUG                    "ggml: enable Vulkan debug output"                OFF)
+option(GGML_VULKAN_MEMORY_DEBUG             "ggml: enable Vulkan memory debug output"         OFF)
+option(GGML_VULKAN_VALIDATE                 "ggml: enable Vulkan validation"                  OFF)
+option(GGML_VULKAN_RUN_TESTS                "ggml: run Vulkan tests"                          OFF)
+option(GGML_KOMPUTE                         "ggml: use Kompute"                               OFF)
+option(GGML_METAL                           "ggml: use Metal"                                 ${GGML_METAL_DEFAULT})
+option(GGML_METAL_NDEBUG                    "ggml: disable Metal debugging"                   OFF)
+option(GGML_METAL_SHADER_DEBUG              "ggml: compile Metal with -fno-fast-math"         OFF)
+option(GGML_METAL_EMBED_LIBRARY             "ggml: embed Metal library"                       ${GGML_METAL})
+set   (GGML_METAL_MACOSX_VERSION_MIN "" CACHE STRING
+                                            "ggml: metal minimum macOS version")
+set   (GGML_METAL_STD "" CACHE STRING       "ggml: metal standard version (-std flag)")
+option(GGML_OPENMP                          "ggml: use OpenMP"                                ON)
+option(GGML_RPC                             "ggml: use RPC"                                   OFF)
+option(GGML_SYCL                            "ggml: use SYCL"                                  OFF)
+option(GGML_SYCL_F16                        "ggml: use 16 bit floats for sycl calculations"   OFF)
+set   (GGML_SYCL_TARGET "INTEL" CACHE STRING
+                                            "ggml: sycl target device")
+
+# extra artifacts
+option(GGML_BUILD_TESTS    "ggml: build tests"    ${GGML_STANDALONE})
+option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
 
 #
-# POSIX conformance
+# dependencies
 #
 
-# clock_gettime came in POSIX.1b (1993)
-# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
-# posix_memalign came in POSIX.1-2001 / SUSv3
-# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985)
-add_compile_definitions(_XOPEN_SOURCE=600)
-
-# Somehow in OpenBSD whenever POSIX conformance is specified
-# some string functions rely on locale_t availability,
-# which was introduced in POSIX.1-2008, forcing us to go higher
-if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
-    remove_definitions(-D_XOPEN_SOURCE=600)
-    add_compile_definitions(_XOPEN_SOURCE=700)
-endif()
-
-# Data types, macros and functions related to controlling CPU affinity
-# are available on Linux through GNU extensions in libc
-if (CMAKE_SYSTEM_NAME MATCHES "Linux")
-    add_compile_definitions(_GNU_SOURCE)
-endif()
-
-# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
-# and on macOS its availability depends on enabling Darwin extensions
-# similarly on DragonFly, enabling BSD extensions is necessary
-if (CMAKE_SYSTEM_NAME MATCHES "Darwin")
-    add_compile_definitions(_DARWIN_C_SOURCE)
-endif()
-if (CMAKE_SYSTEM_NAME MATCHES "DragonFly")
-    add_compile_definitions(_DARWIN_C_SOURCE)
-endif()
-
-# alloca is a non-standard interface that is not visible on BSDs when
-# POSIX conformance is specified, but not all of them provide a clean way
-# to enable it in such cases
-if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD")
-    add_compile_definitions(__BSD_VISIBLE)
-endif()
-if (CMAKE_SYSTEM_NAME MATCHES "NetBSD")
-    add_compile_definitions(_NETBSD_SOURCE)
-endif()
-if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
-    add_compile_definitions(_BSD_SOURCE)
-endif()
+set(CMAKE_C_STANDARD 11)
+set(CMAKE_C_STANDARD_REQUIRED true)
 
-if (WHISPER_PERF)
-    set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
+if (GGML_SYCL)
+    set(CMAKE_CXX_STANDARD 17)
+else()
+    set(CMAKE_CXX_STANDARD 11)
 endif()
+set(CMAKE_CXX_STANDARD_REQUIRED true)
 
-# dependencies
-
-set(CMAKE_C_STANDARD   11)
-set(CMAKE_CXX_STANDARD 11)
+set(THREADS_PREFER_PTHREAD_FLAG ON)
 
 find_package(Threads REQUIRED)
 
-# main
-
-if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
-    set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
-    set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "RelWithDebInfo")
-endif ()
-
-if (GGML_BUILD_TESTS)
-    if (GGML_TEST_COVERAGE)
-        if (CMAKE_C_COMPILER_ID MATCHES "Clang")
-            set(CMAKE_C_FLAGS   "${CMAKE_C_FLAGS}   -fprofile-instr-generate -fcoverage-mapping")
-            set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fprofile-instr-generate -fcoverage-mapping")
-        else()
-            message(WARNING "Test coverage is only supported for Clang")
-        endif()
-    endif()
-endif()
+#
+# build the library
+#
 
 add_subdirectory(src)
 
+#
+# tests and examples
+#
+
 if (GGML_BUILD_TESTS)
     enable_testing()
     add_subdirectory(tests)
@@ -217,8 +185,54 @@ if (GGML_BUILD_EXAMPLES)
     add_subdirectory(examples)
 endif ()
 
-configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ggml.pc.in
-               ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
-               @ONLY)
-install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
+#
+# install
+#
+
+include(GNUInstallDirs)
+include(CMakePackageConfigHelpers)
+
+set(GGML_PUBLIC_HEADERS
+    include/ggml.h
+    include/ggml-alloc.h
+    include/ggml-backend.h
+    "${GGML_HEADERS_CUDA}"
+    "${GGML_HEADERS_METAL}"
+    "${GGML_HEADERS_EXTRA}")
+
+set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
+#if (GGML_METAL)
+#    set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal")
+#endif()
+install(TARGETS ggml PUBLIC_HEADER)
+
+if (BUILD_SHARED_LIBS)
+    install(TARGETS ggml LIBRARY)
+endif()
+
+if (GGML_METAL)
+    install(
+        FILES src/ggml-metal.metal
+        PERMISSIONS
+            OWNER_READ
+            OWNER_WRITE
+            GROUP_READ
+            WORLD_READ
+        DESTINATION ${CMAKE_INSTALL_BINDIR})
+
+    if (NOT GGML_METAL_EMBED_LIBRARY)
+        install(
+            FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
+            DESTINATION ${CMAKE_INSTALL_BINDIR}
+        )
+    endif()
+endif()
+
+if (GGML_STANDALONE)
+    configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ggml.pc.in
+        ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
+        @ONLY)
+
+    install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
         DESTINATION share/pkgconfig)
+endif()
diff --git a/cmake/FindSIMD.cmake b/cmake/FindSIMD.cmake
new file mode 100644 (file)
index 0000000..5533668
--- /dev/null
@@ -0,0 +1,100 @@
+include(CheckCSourceRuns)
+
+set(AVX_CODE "
+    #include <immintrin.h>
+    int main()
+    {
+        __m256 a;
+        a = _mm256_set1_ps(0);
+        return 0;
+    }
+")
+
+set(AVX512_CODE "
+    #include <immintrin.h>
+    int main()
+    {
+        __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0);
+        __m512i b = a;
+        __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
+        return 0;
+    }
+")
+
+set(AVX2_CODE "
+    #include <immintrin.h>
+    int main()
+    {
+        __m256i a = {0};
+        a = _mm256_abs_epi16(a);
+        __m256i x;
+        _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
+        return 0;
+    }
+")
+
+set(FMA_CODE "
+    #include <immintrin.h>
+    int main()
+    {
+        __m256 acc = _mm256_setzero_ps();
+        const __m256 d = _mm256_setzero_ps();
+        const __m256 p = _mm256_setzero_ps();
+        acc = _mm256_fmadd_ps( d, p, acc );
+        return 0;
+    }
+")
+
+macro(check_sse type flags)
+    set(__FLAG_I 1)
+    set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
+    foreach (__FLAG ${flags})
+        if (NOT ${type}_FOUND)
+            set(CMAKE_REQUIRED_FLAGS ${__FLAG})
+            check_c_source_runs("${${type}_CODE}" HAS_${type}_${__FLAG_I})
+            if (HAS_${type}_${__FLAG_I})
+                set(${type}_FOUND TRUE CACHE BOOL "${type} support")
+                set(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags")
+            endif()
+            math(EXPR __FLAG_I "${__FLAG_I}+1")
+        endif()
+    endforeach()
+    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
+
+    if (NOT ${type}_FOUND)
+        set(${type}_FOUND FALSE CACHE BOOL "${type} support")
+        set(${type}_FLAGS "" CACHE STRING "${type} flags")
+    endif()
+
+    mark_as_advanced(${type}_FOUND ${type}_FLAGS)
+endmacro()
+
+# flags are for MSVC only!
+check_sse("AVX" " ;/arch:AVX")
+if (NOT ${AVX_FOUND})
+    set(GGML_AVX OFF)
+else()
+    set(GGML_AVX ON)
+endif()
+
+check_sse("AVX2" " ;/arch:AVX2")
+check_sse("FMA" " ;/arch:AVX2")
+if ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))
+    set(GGML_AVX2 OFF)
+else()
+    set(GGML_AVX2 ON)
+endif()
+
+check_sse("AVX512" " ;/arch:AVX512")
+if (NOT ${AVX512_FOUND})
+    set(GGML_AVX512 OFF)
+else()
+    set(GGML_AVX512 ON)
+endif()
index 7a3197e65c15c25d33faeccf8db13febdee5aafe..ba8906573c9dd74bd4a9576b1495803f43fad6cb 100644 (file)
@@ -1,6 +1,6 @@
-#include "ggml/ggml.h"
-#include "ggml/ggml-alloc.h"
-#include "ggml/ggml-backend.h"
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
 
 #include "common.h"
 #include "common-ggml.h"
index 714c158fb206a4168dd5cb089da0a548b648d133..db8e7f20e3271e4a44f70b16d3d59c576c697fd6 100644 (file)
@@ -1,6 +1,6 @@
-#include "ggml/ggml.h"
-#include "ggml/ggml-alloc.h"
-#include "ggml/ggml-backend.h"
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
 
 #ifdef GGML_USE_CUDA
 #include "ggml-cuda.h"
index 6dbf5e3bc119360b645196ee489e212a2ffe1490..ecf08b2b7ca002047b1bbeed876118b01bfa34b5 100644 (file)
@@ -1,6 +1,6 @@
-#include "ggml/ggml.h"
-#include "ggml/ggml-alloc.h"
-#include "ggml/ggml-backend.h"
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
 
 #ifdef GGML_USE_CUDA
 #include "ggml-cuda.h"
index 5dd114177ff676ea79a5f14b2ff5187670c13038..01da41a2421f6fa92a77a2d17c03012d480b1d09 100644 (file)
@@ -1,4 +1,4 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include "common.h"
 #include "common-ggml.h"
index 11c72973d2273c52e7d3a8515e35cb764fa1687d..36571045e48fa14dd6ae1aea91581781eac2e478 100644 (file)
@@ -1,6 +1,6 @@
-#include "ggml/ggml.h"
-#include "ggml/ggml-alloc.h"
-#include "ggml/ggml-backend.h"
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
 
 #ifdef GGML_USE_CUDA
 #include "ggml-cuda.h"
index 9d8d53a67b1b1881da73eb7aa67074c84d866595..f81c04e8cd91848e9906cabbcfdf973198772ce9 100644 (file)
@@ -1,4 +1,4 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include "common.h"
 #include "common-ggml.h"
index e5cc7959ef09bbb86d2dd83a250f70120c179723..eac5a7313dc670dce0dd558bb5aaa95b0a4fda73 100644 (file)
@@ -147,7 +147,7 @@ sys.    0m7.103s
 ## Implementation details
 
 The high level implementation of the model is contained in the [main.cpp](main.cpp) file. The core computations are
-performed by the [ggml](https://github.com/ggerganov/ggml/blob/master/include/ggml/ggml.h) library.
+performed by the [ggml](https://github.com/ggerganov/ggml/blob/master/include/ggml.h) library.
 
 
 #### Matrix multiplication
index ae55fb7029500e4cee0852e3942ad6ccb07273e0..54ff611346907ef5700acac3bb71150e048e656e 100644 (file)
@@ -1,4 +1,4 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include "common.h"
 #include "common-ggml.h"
index 437053b7d86b6bfe30d34a4810235ed837f45e43..c6f258c4a0eec4498167d594f66e68da0c4afafc 100644 (file)
@@ -1,4 +1,4 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include "common.h"
 #include "common-ggml.h"
index 61383ec70c739b0e27e36d89f1f20fb59639831c..7c22d6cee2d7aeeee86e7ac7559f355f694f2b5b 100644 (file)
@@ -1,6 +1,7 @@
-#include "ggml/ggml.h"
-#include "ggml/ggml-alloc.h"
-#include "ggml/ggml-backend.h"
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+
 #include <algorithm>
 #include <cmath>
 #include <numeric>
index b0135035bfc3d24b06c166a06abf4c9d90438352..1ea48e653970dba039f566fa6ba5f8b546f7854c 100644 (file)
@@ -1,4 +1,4 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include "common.h"
 
index 3b759b0f3c0272de127ae0cdd50f90f0a6c53c5d..bfef474da92cd88639062009c53849b028e82377 100644 (file)
@@ -10,7 +10,7 @@
 // $ ./bin/mnist-cpu ./models/mnist/mnist.ggml ../examples/mnist/models/mnist/t10k-images.idx3-ubyte
 //
 
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include <algorithm>
 #include <cmath>
index 7d0eec8f186bb13c23e3f067cc9513a7850673be..db2401f28ad2a845e4abbcba0dbbbc6b5d909342 100644 (file)
@@ -10,7 +10,7 @@
 // $ ./bin/mnist-mtl ./models/mnist/mnist.ggml ../examples/mnist/models/mnist/t10k-images.idx3-ubyte
 //
 
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include "main-mtl.h"
 
index 4b7717920a69b0017d91135c9f7adf118ec81394..ea929e20fda751285724d30fa4ca65cc3d31e918 100644 (file)
@@ -1,6 +1,6 @@
 #import "main-mtl.h"
 
-#import "ggml/ggml.h"
+#import "ggml.h"
 
 #import <Foundation/Foundation.h>
 #import <Metal/Metal.h>
index 358085861b2333c01e1165472330bf918dc4b6e8..4f509fe13adc4ea2695a6f38054802779558415d 100644 (file)
@@ -1,4 +1,4 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include "common.h"
 
index 69287f88f3f7a96d0a3b35488bf0d02647def769..3cf5d2445e1c9e9e8d3831f79637769af3e05310 100644 (file)
@@ -93,7 +93,7 @@ You can also edit [api.h](./api.h) to control which files should be included in
 In fact, if you wanted to only generate bindings for the current version of the `ggml` repo itself (instead of `llama.cpp`; you'd loose support for k-quants), you could run:
 
 ```bash
-API=../../include/ggml/ggml.h python regenerate.py
+API=../../include/ggml.h python regenerate.py
 ```
 
 ## Develop
@@ -109,7 +109,7 @@ pytest
 This example's goal is to showcase [cffi](https://cffi.readthedocs.io/)-generated bindings that are trivial to use and update, but there are already alternatives in the wild:
 
 - https://github.com/abetlen/ggml-python: these bindings seem to be hand-written and use [ctypes](https://docs.python.org/3/library/ctypes.html). It has [high-quality API reference docs](https://ggml-python.readthedocs.io/en/latest/api-reference/#ggml.ggml) that can be used with these bindings too, but it doesn't expose Metal, CUDA, MPI or OpenCL calls, doesn't support transparent (de/re)quantization like this example does (see [ggml.utils](./ggml/utils.py) module), and won't pick up your local changes.
-  
+
 - https://github.com/abetlen/llama-cpp-python: these expose the C++ `llama.cpp` interface, which this example cannot easily be extended to support (`cffi` only generates bindings of C libraries)
 
 - [pybind11](https://github.com/pybind/pybind11) and [nanobind](https://github.com/wjakob/nanobind) are two alternatives to cffi that support binding C++ libraries, but it doesn't seem either of them have an automatic generator (writing bindings is rather time-consuming).
index 4ae6f3c8e64756788d4e6d420db641958e02327a..64da9542f5a89309cb26dd637596d406de5c3fb6 100644 (file)
@@ -1,6 +1,6 @@
 #include "ggml.h"
-#include "ggml/ggml-alloc.h"
-#include "ggml/ggml-backend.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
 
 #ifdef GGML_USE_CUDA
 #include "ggml-cuda.h"
index b8c81347c2e6f6e4a5307c042cbdabbf42e12ec4..cb446a4a0b993feee81c2eb8e919926f3ba84938 100644 (file)
@@ -1,4 +1,4 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
 #include "yolo-image.h"
 
 #include <cmath>
diff --git a/include/ggml-alloc.h b/include/ggml-alloc.h
new file mode 100644 (file)
index 0000000..434c13b
--- /dev/null
@@ -0,0 +1,76 @@
+#pragma once
+
+#include "ggml.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+typedef struct ggml_backend * ggml_backend_t;
+
+// Tensor allocator
+struct ggml_tallocr {
+    ggml_backend_buffer_t buffer;
+    void * base;
+    size_t alignment;
+    size_t offset;
+};
+
+GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer);
+GGML_API void                ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);
+
+// Graph allocator
+/*
+  Example usage:
+    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_bacckend_cpu_buffer_type());
+
+    // optional: create a worst-case graph and reserve the buffers to avoid reallocations
+    ggml_gallocr_reserve(galloc, build_graph(max_batch));
+
+    // allocate the graph
+    struct ggml_cgraph * graph = build_graph(batch);
+    ggml_gallocr_alloc_graph(galloc, graph);
+
+    printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0));
+
+    // evaluate the graph
+    ggml_backend_graph_compute(backend, graph);
+*/
+
+// special tensor flags for use with the graph allocator:
+//   ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses
+//   ggml_set_output(): output tensors are never freed and never overwritten
+
+typedef struct ggml_gallocr * ggml_gallocr_t;
+
+GGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft);
+GGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs);
+GGML_API void           ggml_gallocr_free(ggml_gallocr_t galloc);
+
+// pre-allocate buffers from a measure graph - does not allocate or modify the graph
+// call with a worst-case graph to avoid buffer reallocations
+// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
+// returns false if the buffer allocation failed
+GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
+GGML_API bool ggml_gallocr_reserve_n(
+    ggml_gallocr_t galloc,
+    struct ggml_cgraph * graph,
+    const int * node_buffer_ids,
+    const int * leaf_buffer_ids);
+
+// automatic reallocation if the topology changes when using a single buffer
+// returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)
+GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
+
+GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
+
+// Utils
+// Create a buffer and allocate all the tensors in a ggml_context
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/include/ggml-backend.h b/include/ggml-backend.h
new file mode 100644 (file)
index 0000000..4a38eeb
--- /dev/null
@@ -0,0 +1,236 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+    typedef struct ggml_backend_event * ggml_backend_event_t;
+    typedef struct ggml_backend * ggml_backend_t;
+    typedef void * ggml_backend_graph_plan_t;
+
+    //
+    // Backend buffer
+    //
+
+    // buffer type
+    GGML_API           const char *          ggml_backend_buft_name            (ggml_backend_buffer_type_t buft);
+    GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer    (ggml_backend_buffer_type_t buft, size_t size);
+    GGML_API           size_t                ggml_backend_buft_get_alignment   (ggml_backend_buffer_type_t buft);
+    GGML_API           size_t                ggml_backend_buft_get_max_size    (ggml_backend_buffer_type_t buft);
+    GGML_API GGML_CALL size_t                ggml_backend_buft_get_alloc_size  (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
+    GGML_API           bool                  ggml_backend_buft_is_host         (ggml_backend_buffer_type_t buft);
+
+    // buffer
+    enum ggml_backend_buffer_usage {
+        GGML_BACKEND_BUFFER_USAGE_ANY = 0,
+        GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,
+    };
+
+    GGML_API           const char *               ggml_backend_buffer_name          (ggml_backend_buffer_t buffer);
+    GGML_API           void                       ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
+    GGML_API           void *                     ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
+    GGML_API           size_t                     ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);
+    GGML_API GGML_CALL void                       ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API           size_t                     ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
+    GGML_API           size_t                     ggml_backend_buffer_get_max_size  (ggml_backend_buffer_t buffer);
+    GGML_API           size_t                     ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API           void                       ggml_backend_buffer_clear         (ggml_backend_buffer_t buffer, uint8_t value);
+    GGML_API           bool                       ggml_backend_buffer_is_host       (ggml_backend_buffer_t buffer);
+    GGML_API           void                       ggml_backend_buffer_set_usage     (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
+    GGML_API           ggml_backend_buffer_type_t ggml_backend_buffer_get_type      (ggml_backend_buffer_t buffer);
+    GGML_API           void                       ggml_backend_buffer_reset         (ggml_backend_buffer_t buffer);
+
+    //
+    // Backend
+    //
+
+    GGML_API ggml_guid_t  ggml_backend_guid(ggml_backend_t backend);
+    GGML_API const char * ggml_backend_name(ggml_backend_t backend);
+    GGML_API void         ggml_backend_free(ggml_backend_t backend);
+
+    GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
+    GGML_API ggml_backend_buffer_t      ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
+    GGML_API size_t                     ggml_backend_get_alignment(ggml_backend_t backend);
+    GGML_API size_t                     ggml_backend_get_max_size(ggml_backend_t backend);
+
+    GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+
+    GGML_API GGML_CALL void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+
+    GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
+
+    GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API void                      ggml_backend_graph_plan_free  (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+
+    GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+    GGML_API enum ggml_status ggml_backend_graph_compute      (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
+    GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
+    GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op);
+
+    // tensor copy between different backends
+    GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
+
+    // asynchronous copy
+    // the copy is performed after all the currently queued operations in backend_src
+    // backend_dst will wait for the copy to complete before performing other operations
+    // automatic fallback to sync copy if async is not supported
+    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+    // events
+    GGML_API ggml_backend_event_t   ggml_backend_event_new        (ggml_backend_t backend);
+    GGML_API void                   ggml_backend_event_free       (ggml_backend_event_t event);
+    GGML_API void                   ggml_backend_event_record     (ggml_backend_event_t event);
+    GGML_API void                   ggml_backend_event_synchronize(ggml_backend_event_t event);
+    GGML_API void                   ggml_backend_event_wait       (ggml_backend_t backend, ggml_backend_event_t event);
+
+    //
+    // CPU backend
+    //
+
+    GGML_API ggml_backend_t ggml_backend_cpu_init(void);
+
+    GGML_API GGML_CALL bool ggml_backend_is_cpu                (ggml_backend_t backend);
+    GGML_API           void ggml_backend_cpu_set_n_threads     (ggml_backend_t backend_cpu, int n_threads);
+    GGML_API           void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
+
+    // Create a backend buffer from an existing pointer
+    GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
+
+    GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
+
+#ifdef GGML_USE_CPU_HBM
+    GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
+#endif
+
+    //
+    // Backend registry
+    //
+
+    // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
+
+    GGML_API size_t                     ggml_backend_reg_get_count(void);
+    GGML_API size_t                     ggml_backend_reg_find_by_name(const char * name);
+    GGML_API ggml_backend_t             ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional)
+    GGML_API const char *               ggml_backend_reg_get_name(size_t i);
+    GGML_API ggml_backend_t             ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
+    GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
+    GGML_API ggml_backend_buffer_t      ggml_backend_reg_alloc_buffer(size_t i, size_t size);
+
+    //
+    // Backend scheduler
+    //
+
+    // The backend scheduler allows for multiple backends to be used together
+    // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
+    // The backends are selected based on:
+    // - the backend that supports the operation
+    // - the location of the pre-allocated tensors (e.g. the weights)
+    /*
+      Example usage:
+
+        // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned
+        // preferrably to run on the same backend as the buffer
+        ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+
+        sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false);
+
+        // initialize buffers from a max size graph (optional)
+        reserve_graph = build_graph(sched, max_batch_size);
+
+        // manually assign nodes to a backend (optional, should not be needed in most cases)
+        struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
+        ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu);
+
+        ggml_backend_sched_reserve(sched, reserve_graph);
+
+        // compute
+        graph = build_graph(sched);
+        ggml_backend_sched_graph_compute(sched, graph);
+
+        // if there are graph inputs:
+        ggml_backend_sched_reset(sched);
+        ggml_backend_sched_alloc_graph(sched, graph);
+        ggml_backend_tensor_set(input_tensor, ...);
+        ggml_backend_sched_graph_compute(sched, graph);
+    }
+    */
+
+    struct ggml_backend_sched;
+    typedef struct ggml_backend_sched * ggml_backend_sched_t;
+
+    // when ask == true, the scheduler wants to know if the user wants to observe this node
+    // this allows the scheduler to batch nodes together in order to evaluate them in a single call
+    //
+    // when ask == false, the scheduler is passing the node tensor to the user for observation
+    // if the user returns false, the scheduler will cancel the graph compute
+    //
+    typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
+
+    // Initialize a backend scheduler
+    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
+    GGML_API void                 ggml_backend_sched_free(ggml_backend_sched_t sched);
+
+    // Initialize backend buffers from a measure graph
+    GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+
+    GGML_API int                  ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
+    GGML_API ggml_backend_t       ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
+
+    // Get the number of splits of the last graph
+    GGML_API int                  ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
+    GGML_API int                  ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);
+
+    GGML_API size_t               ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
+
+    GGML_API void                 ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
+    GGML_API ggml_backend_t       ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
+
+    // Allocate and compute graph on the backend scheduler
+    GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API enum ggml_status     ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API enum ggml_status     ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API void                 ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
+
+    // Reset all assignments and allocators - must be called before changing the node backends
+    GGML_API void                 ggml_backend_sched_reset(ggml_backend_sched_t sched);
+
+    // Set a callback to be called for each resulting node during graph compute
+    GGML_API void                 ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
+
+    //
+    // Utils
+    //
+
+    struct ggml_backend_graph_copy {
+        ggml_backend_buffer_t buffer;
+        struct ggml_context * ctx_allocated;
+        struct ggml_context * ctx_unallocated;
+        struct ggml_cgraph * graph;
+    };
+
+    // Copy a graph to a different backend
+    GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
+    GGML_API void                           ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
+
+    typedef bool (*GGML_CALL ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
+
+    // Compare the output of two backends
+    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
+
+    // Tensor initialization
+    GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
+    GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
+
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/include/ggml-blas.h b/include/ggml-blas.h
new file mode 100644 (file)
index 0000000..f2e37de
--- /dev/null
@@ -0,0 +1,23 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+// backend API
+GGML_API GGML_CALL ggml_backend_t ggml_backend_blas_init(void);
+
+GGML_API GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend);
+
+// number of threads used for conversion to float
+// for openblas and blis, this will also set the number of threads used for blas operations
+GGML_API GGML_CALL void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
+
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/include/ggml-cuda.h b/include/ggml-cuda.h
new file mode 100644 (file)
index 0000000..d7903c6
--- /dev/null
@@ -0,0 +1,44 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef GGML_USE_HIPBLAS
+#define GGML_CUDA_NAME "ROCm"
+#define GGML_CUBLAS_NAME "hipBLAS"
+#else
+#define GGML_CUDA_NAME "CUDA"
+#define GGML_CUBLAS_NAME "cuBLAS"
+#endif
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+#define GGML_CUDA_MAX_DEVICES       16
+
+// backend API
+GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
+
+GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend);
+
+// device buffer
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
+
+// split tensor buffer that splits matrices by rows across multiple devices
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
+
+// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
+
+GGML_API GGML_CALL int  ggml_backend_cuda_get_device_count(void);
+GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
+GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
+
+GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
+GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
+
+GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data);
+#ifdef  __cplusplus
+}
+#endif
diff --git a/include/ggml-kompute.h b/include/ggml-kompute.h
new file mode 100644 (file)
index 0000000..1714654
--- /dev/null
@@ -0,0 +1,46 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct ggml_vk_device {
+    int index;
+    int type; // same as VkPhysicalDeviceType
+    size_t heapSize;
+    const char * name;
+    const char * vendor;
+    int subgroupSize;
+    uint64_t bufferAlignment;
+    uint64_t maxAlloc;
+};
+
+struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
+bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
+bool ggml_vk_has_vulkan(void);
+bool ggml_vk_has_device(void);
+struct ggml_vk_device ggml_vk_current_device(void);
+
+//
+// backend API
+//
+
+// forward declaration
+typedef struct ggml_backend * ggml_backend_t;
+
+GGML_API ggml_backend_t ggml_backend_kompute_init(int device);
+
+GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);
+
+GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/include/ggml-metal.h b/include/ggml-metal.h
new file mode 100644 (file)
index 0000000..e7543ae
--- /dev/null
@@ -0,0 +1,66 @@
+// An interface allowing to compute ggml_cgraph with Metal
+//
+// This is a fully functional interface that extends ggml with GPU support for Apple devices.
+// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)
+//
+// How it works?
+//
+// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this
+// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you
+// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)
+//
+// You only need to make sure that all memory buffers that you used during the graph creation
+// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is
+// used during the graph evaluation to determine the arguments of the compute kernels.
+//
+// Synchronization between device and host memory (for example for input and output tensors)
+// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.
+//
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#include <stddef.h>
+#include <stdbool.h>
+
+// max memory buffers that can be mapped to the device
+#define GGML_METAL_MAX_BUFFERS 64
+
+struct ggml_tensor;
+struct ggml_cgraph;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//
+// backend API
+// user-code should use only these functions
+//
+
+GGML_API void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data);
+
+GGML_API ggml_backend_t ggml_backend_metal_init(void);
+
+GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
+
+GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
+
+GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
+
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
+
+// helper to check if the device supports a specific family
+// ideally, the user code should be doing these checks
+// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
+
+// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
+GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
+
+#ifdef __cplusplus
+}
+#endif
+
diff --git a/include/ggml-rpc.h b/include/ggml-rpc.h
new file mode 100644 (file)
index 0000000..aa14483
--- /dev/null
@@ -0,0 +1,24 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+#define GGML_RPC_MAX_SERVERS       16
+
+// backend API
+GGML_API GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
+GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend);
+
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
+
+GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
+
+GGML_API GGML_CALL void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/include/ggml-sycl.h b/include/ggml-sycl.h
new file mode 100644 (file)
index 0000000..43ab151
--- /dev/null
@@ -0,0 +1,42 @@
+//
+//  MIT license
+//  Copyright (C) 2024 Intel Corporation
+//  SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#define GGML_SYCL_NAME "SYCL"
+#define GGML_SYCL_MAX_DEVICES 48
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+// backend API
+GGML_API ggml_backend_t ggml_backend_sycl_init(int device);
+
+// devide buffer
+GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
+
+// split tensor buffer that splits matrices by rows across multiple devices
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
+
+// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
+GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
+
+GGML_API void   ggml_backend_sycl_print_sycl_devices(void);
+GGML_API GGML_CALL void   ggml_sycl_get_gpu_list(int *id_list, int max_len);
+GGML_API GGML_CALL void   ggml_sycl_get_device_description(int device, char *description, size_t description_size);
+GGML_API GGML_CALL int   ggml_backend_sycl_get_device_count();
+GGML_API GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
+
+// SYCL doesn't support registering host memory, keep here for reference
+// GGML_API GGML_CALL bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
+// GGML_API GGML_CALL void ggml_backend_sycl_unregister_host_buffer(void * buffer);
+#ifdef  __cplusplus
+}
+#endif
diff --git a/include/ggml-vulkan.h b/include/ggml-vulkan.h
new file mode 100644 (file)
index 0000000..af661c2
--- /dev/null
@@ -0,0 +1,29 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+#define GGML_VK_NAME "Vulkan"
+#define GGML_VK_MAX_DEVICES 16
+
+GGML_API void ggml_vk_instance_init(void);
+
+// backend API
+GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num);
+
+GGML_API GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend);
+GGML_API GGML_CALL int  ggml_backend_vk_get_device_count(void);
+GGML_API GGML_CALL void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);
+GGML_API GGML_CALL void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);
+
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
+// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/include/ggml.h b/include/ggml.h
new file mode 100644 (file)
index 0000000..d895c9a
--- /dev/null
@@ -0,0 +1,2427 @@
+#pragma once
+
+//
+// GGML Tensor Library
+//
+// This documentation is still a work in progress.
+// If you wish some specific topics to be covered, feel free to drop a comment:
+//
+//   https://github.com/ggerganov/whisper.cpp/issues/40
+//
+// ## Overview
+//
+// This library implements:
+//
+//  - a set of tensor operations
+//  - automatic differentiation
+//  - basic optimization algorithms
+//
+// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes,
+// but is not limited to, the following:
+//
+//  - linear regression
+//  - support vector machines
+//  - neural networks
+//
+// The library allows the user to define a certain function using the available tensor operations. This function
+// definition is represented internally via a computation graph. Each tensor operation in the function definition
+// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the
+// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized
+// using one of the available optimization algorithms.
+//
+// For example, here we define the function: f(x) = a*x^2 + b
+//
+//   {
+//       struct ggml_init_params params = {
+//           .mem_size   = 16*1024*1024,
+//           .mem_buffer = NULL,
+//       };
+//
+//       // memory allocation happens here
+//       struct ggml_context * ctx = ggml_init(params);
+//
+//       struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+//
+//       ggml_set_param(ctx, x); // x is an input variable
+//
+//       struct ggml_tensor * a  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+//       struct ggml_tensor * b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+//       struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
+//       struct ggml_tensor * f  = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
+//
+//       ...
+//   }
+//
+// Notice that the function definition above does not involve any actual computation. The computation is performed only
+// when the user explicitly requests it. For example, to compute the function's value at x = 2.0:
+//
+//   {
+//       ...
+//
+//       struct ggml_cgraph * gf = ggml_new_graph(ctx);
+//       ggml_build_forward_expand(gf, f);
+//
+//       // set the input variable and parameter values
+//       ggml_set_f32(x, 2.0f);
+//       ggml_set_f32(a, 3.0f);
+//       ggml_set_f32(b, 4.0f);
+//
+//       ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
+//
+//       printf("f = %f\n", ggml_get_f32_1d(f, 0));
+//
+//       ...
+//   }
+//
+// The actual computation is performed in the ggml_graph_compute() function.
+//
+// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the
+// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know
+// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory
+// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was
+// actually needed.
+//
+// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic
+// differentiation and optimization algorithms.
+//
+// The described approach allows to define the function graph once and then compute its forward or backward graphs
+// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way
+// the user can avoid the memory allocation overhead at runtime.
+//
+// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class
+// citizens, but in theory the library can be extended to support FP8 and integer data types.
+//
+// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary
+// and binary operations. Most of the available operations fall into one of these two categories. With time, it became
+// clear that the library needs to support more complex operations. The way to support these operations is not clear
+// yet, but a few examples are demonstrated in the following operations:
+//
+//   - ggml_permute()
+//   - ggml_conv_1d_1s()
+//   - ggml_conv_1d_2s()
+//
+// For each tensor operator, the library implements a forward and backward computation function. The forward function
+// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the
+// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a
+// calculus class, or watch the following video:
+//
+//   What is Automatic Differentiation?
+//   https://www.youtube.com/watch?v=wG_nF1awSSY
+//
+//
+// ## Tensor data (struct ggml_tensor)
+//
+// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of
+// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains
+// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example:
+//
+//   {
+//       struct ggml_tensor * c = ggml_add(ctx, a, b);
+//
+//       assert(c->src[0] == a);
+//       assert(c->src[1] == b);
+//   }
+//
+// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the
+// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows
+// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and
+// permutation. All tensor operations have to take the stride into account and not assume that the tensor is
+// contiguous in memory.
+//
+// The data of the tensor is accessed via the "data" pointer. For example:
+//
+//   {
+//       const int nx = 2;
+//       const int ny = 3;
+//
+//       struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny);
+//
+//       for (int y = 0; y < ny; y++) {
+//           for (int x = 0; x < nx; x++) {
+//               *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y;
+//           }
+//       }
+//
+//       ...
+//   }
+//
+// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used.
+//
+// ## The matrix multiplication operator (ggml_mul_mat)
+//
+// TODO
+//
+//
+// ## Multi-threading
+//
+// TODO
+//
+//
+// ## Overview of ggml.c
+//
+// TODO
+//
+//
+// ## SIMD optimizations
+//
+// TODO
+//
+//
+// ## Debugging ggml
+//
+// TODO
+//
+//
+
+#ifdef GGML_SHARED
+#    if defined(_WIN32) && !defined(__MINGW32__)
+#        ifdef GGML_BUILD
+#            define GGML_API __declspec(dllexport)
+#        else
+#            define GGML_API __declspec(dllimport)
+#        endif
+#    else
+#        define GGML_API __attribute__ ((visibility ("default")))
+#    endif
+#else
+#    define GGML_API
+#endif
+
+#ifdef GGML_MULTIPLATFORM
+#    if defined(_WIN32)
+#        define GGML_CALL
+#    else
+#        define GGML_CALL __attribute__((__ms_abi__))
+#    endif
+#else
+#    define GGML_CALL
+#endif
+
+// TODO: support for clang
+#ifdef __GNUC__
+#    define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
+#elif defined(_MSC_VER)
+#    define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
+#else
+#    define GGML_DEPRECATED(func, hint) func
+#endif
+
+#ifndef __GNUC__
+#    define GGML_ATTRIBUTE_FORMAT(...)
+#elif defined(__MINGW32__)
+#    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#define GGML_FILE_MAGIC   0x67676d6c // "ggml"
+#define GGML_FILE_VERSION 1
+
+#define GGML_QNT_VERSION        2    // bump this on quantization format changes
+#define GGML_QNT_VERSION_FACTOR 1000 // do not change this
+
+#define GGML_MAX_DIMS           4
+#define GGML_MAX_PARAMS         2048
+#define GGML_MAX_CONTEXTS       64
+#define GGML_MAX_SRC            10
+#ifndef GGML_MAX_NAME
+#define GGML_MAX_NAME           64
+#endif
+#define GGML_MAX_OP_PARAMS      64
+#define GGML_DEFAULT_N_THREADS  4
+#define GGML_DEFAULT_GRAPH_SIZE 2048
+#if UINTPTR_MAX == 0xFFFFFFFF
+    #define GGML_MEM_ALIGN 4
+#else
+    #define GGML_MEM_ALIGN 16
+#endif
+
+#define GGML_EXIT_SUCCESS 0
+#define GGML_EXIT_ABORTED 1
+
+#define GGUF_MAGIC "GGUF"
+
+#define GGUF_VERSION 3
+
+#define GGUF_DEFAULT_ALIGNMENT 32
+
+#define GGML_UNUSED(x) (void)(x)
+
+#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
+
+#define GGML_ASSERT(x) \
+    do { \
+        if (!(x)) { \
+            fflush(stdout); \
+            fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            ggml_print_backtrace(); \
+            abort(); \
+        } \
+    } while (0)
+
+#ifndef NDEBUG
+#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached")
+#elif defined(__GNUC__)
+#define GGML_UNREACHABLE() __builtin_unreachable()
+#elif defined(_MSC_VER)
+#define GGML_UNREACHABLE() __assume(0)
+#else
+#define GGML_UNREACHABLE() ((void) 0)
+#endif
+
+// used to copy the number of elements and stride in bytes of tensors into local variables.
+// main purpose is to reduce code duplication and improve readability.
+//
+// example:
+//
+//    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
+//    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb);
+//
+#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
+    const type prefix##0 = (pointer)->array[0]; \
+    GGML_UNUSED(prefix##0);
+#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
+    GGML_TENSOR_LOCALS_1    (type, prefix, pointer, array) \
+    const type prefix##1 = (pointer)->array[1]; \
+    GGML_UNUSED(prefix##1);
+#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
+    GGML_TENSOR_LOCALS_2    (type, prefix, pointer, array) \
+    const type prefix##2 = (pointer)->array[2]; \
+    GGML_UNUSED(prefix##2);
+#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
+    GGML_TENSOR_LOCALS_3  (type, prefix, pointer, array) \
+    const type prefix##3 = (pointer)->array[3]; \
+    GGML_UNUSED(prefix##3);
+
+#define GGML_TENSOR_UNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
+#define GGML_TENSOR_BINARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
+#define GGML_TENSOR_BINARY_OP_LOCALS01 \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    enum ggml_status {
+        GGML_STATUS_ALLOC_FAILED = -2,
+        GGML_STATUS_FAILED = -1,
+        GGML_STATUS_SUCCESS = 0,
+        GGML_STATUS_ABORTED = 1,
+    };
+
+    // get ggml_status name string
+    GGML_API GGML_CALL const char * ggml_status_to_string(enum ggml_status status);
+
+    // ieee 754-2008 half-precision float16
+    // todo: make this not an integral type
+    typedef uint16_t ggml_fp16_t;
+    GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t);
+    GGML_API ggml_fp16_t ggml_fp32_to_fp16(float);
+    GGML_API void        ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t);
+    GGML_API void        ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t);
+
+    // google brain half-precision bfloat16
+    typedef struct { uint16_t bits; } ggml_bf16_t;
+    GGML_API ggml_bf16_t ggml_fp32_to_bf16(float);
+    GGML_API float       ggml_bf16_to_fp32(ggml_bf16_t);  // consider just doing << 16
+    GGML_API void        ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t);
+    GGML_API void        ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t);
+
+    struct ggml_object;
+    struct ggml_context;
+
+    // NOTE: always add types at the end of the enum to keep backward compatibility
+    enum ggml_type {
+        GGML_TYPE_F32     = 0,
+        GGML_TYPE_F16     = 1,
+        GGML_TYPE_Q4_0    = 2,
+        GGML_TYPE_Q4_1    = 3,
+        // GGML_TYPE_Q4_2 = 4, support has been removed
+        // GGML_TYPE_Q4_3 = 5, support has been removed
+        GGML_TYPE_Q5_0    = 6,
+        GGML_TYPE_Q5_1    = 7,
+        GGML_TYPE_Q8_0    = 8,
+        GGML_TYPE_Q8_1    = 9,
+        GGML_TYPE_Q2_K    = 10,
+        GGML_TYPE_Q3_K    = 11,
+        GGML_TYPE_Q4_K    = 12,
+        GGML_TYPE_Q5_K    = 13,
+        GGML_TYPE_Q6_K    = 14,
+        GGML_TYPE_Q8_K    = 15,
+        GGML_TYPE_IQ2_XXS = 16,
+        GGML_TYPE_IQ2_XS  = 17,
+        GGML_TYPE_IQ3_XXS = 18,
+        GGML_TYPE_IQ1_S   = 19,
+        GGML_TYPE_IQ4_NL  = 20,
+        GGML_TYPE_IQ3_S   = 21,
+        GGML_TYPE_IQ2_S   = 22,
+        GGML_TYPE_IQ4_XS  = 23,
+        GGML_TYPE_I8      = 24,
+        GGML_TYPE_I16     = 25,
+        GGML_TYPE_I32     = 26,
+        GGML_TYPE_I64     = 27,
+        GGML_TYPE_F64     = 28,
+        GGML_TYPE_IQ1_M   = 29,
+        GGML_TYPE_BF16    = 30,
+        GGML_TYPE_COUNT,
+    };
+
+    // precision
+    enum ggml_prec {
+        GGML_PREC_DEFAULT,
+        GGML_PREC_F32,
+    };
+
+    enum ggml_backend_type {
+        GGML_BACKEND_TYPE_CPU = 0,
+        GGML_BACKEND_TYPE_GPU = 10,
+        GGML_BACKEND_TYPE_GPU_SPLIT = 20,
+    };
+
+    // model file types
+    enum ggml_ftype {
+        GGML_FTYPE_UNKNOWN        = -1,
+        GGML_FTYPE_ALL_F32        = 0,
+        GGML_FTYPE_MOSTLY_F16     = 1,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_0    = 2,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_1    = 3,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
+        GGML_FTYPE_MOSTLY_Q8_0    = 7,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q5_0    = 8,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q5_1    = 9,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q2_K    = 10, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q3_K    = 11, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_K    = 12, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q5_K    = 13, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q6_K    = 14, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ2_XS  = 16, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ1_S   = 18, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ4_NL  = 19, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ3_S   = 20, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ2_S   = 21, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ4_XS  = 22, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ1_M   = 23, // except 1d tensors
+        GGML_FTYPE_MOSTLY_BF16    = 24, // except 1d tensors
+    };
+
+    // available tensor operations:
+    enum ggml_op {
+        GGML_OP_NONE = 0,
+
+        GGML_OP_DUP,
+        GGML_OP_ADD,
+        GGML_OP_ADD1,
+        GGML_OP_ACC,
+        GGML_OP_SUB,
+        GGML_OP_MUL,
+        GGML_OP_DIV,
+        GGML_OP_SQR,
+        GGML_OP_SQRT,
+        GGML_OP_LOG,
+        GGML_OP_SUM,
+        GGML_OP_SUM_ROWS,
+        GGML_OP_MEAN,
+        GGML_OP_ARGMAX,
+        GGML_OP_REPEAT,
+        GGML_OP_REPEAT_BACK,
+        GGML_OP_CONCAT,
+        GGML_OP_SILU_BACK,
+        GGML_OP_NORM, // normalize
+        GGML_OP_RMS_NORM,
+        GGML_OP_RMS_NORM_BACK,
+        GGML_OP_GROUP_NORM,
+
+        GGML_OP_MUL_MAT,
+        GGML_OP_MUL_MAT_ID,
+        GGML_OP_OUT_PROD,
+
+        GGML_OP_SCALE,
+        GGML_OP_SET,
+        GGML_OP_CPY,
+        GGML_OP_CONT,
+        GGML_OP_RESHAPE,
+        GGML_OP_VIEW,
+        GGML_OP_PERMUTE,
+        GGML_OP_TRANSPOSE,
+        GGML_OP_GET_ROWS,
+        GGML_OP_GET_ROWS_BACK,
+        GGML_OP_DIAG,
+        GGML_OP_DIAG_MASK_INF,
+        GGML_OP_DIAG_MASK_ZERO,
+        GGML_OP_SOFT_MAX,
+        GGML_OP_SOFT_MAX_BACK,
+        GGML_OP_ROPE,
+        GGML_OP_ROPE_BACK,
+        GGML_OP_CLAMP,
+        GGML_OP_CONV_TRANSPOSE_1D,
+        GGML_OP_IM2COL,
+        GGML_OP_CONV_TRANSPOSE_2D,
+        GGML_OP_POOL_1D,
+        GGML_OP_POOL_2D,
+        GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_PAD,
+        GGML_OP_ARANGE,
+        GGML_OP_TIMESTEP_EMBEDDING,
+        GGML_OP_ARGSORT,
+        GGML_OP_LEAKY_RELU,
+
+        GGML_OP_FLASH_ATTN_EXT,
+        GGML_OP_FLASH_ATTN_BACK,
+        GGML_OP_SSM_CONV,
+        GGML_OP_SSM_SCAN,
+        GGML_OP_WIN_PART,
+        GGML_OP_WIN_UNPART,
+        GGML_OP_GET_REL_POS,
+        GGML_OP_ADD_REL_POS,
+
+        GGML_OP_UNARY,
+
+        GGML_OP_MAP_UNARY,
+        GGML_OP_MAP_BINARY,
+
+        GGML_OP_MAP_CUSTOM1_F32,
+        GGML_OP_MAP_CUSTOM2_F32,
+        GGML_OP_MAP_CUSTOM3_F32,
+
+        GGML_OP_MAP_CUSTOM1,
+        GGML_OP_MAP_CUSTOM2,
+        GGML_OP_MAP_CUSTOM3,
+
+        GGML_OP_CROSS_ENTROPY_LOSS,
+        GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+
+        GGML_OP_COUNT,
+    };
+
+    enum ggml_unary_op {
+        GGML_UNARY_OP_ABS,
+        GGML_UNARY_OP_SGN,
+        GGML_UNARY_OP_NEG,
+        GGML_UNARY_OP_STEP,
+        GGML_UNARY_OP_TANH,
+        GGML_UNARY_OP_ELU,
+        GGML_UNARY_OP_RELU,
+        GGML_UNARY_OP_SIGMOID,
+        GGML_UNARY_OP_GELU,
+        GGML_UNARY_OP_GELU_QUICK,
+        GGML_UNARY_OP_SILU,
+        GGML_UNARY_OP_HARDSWISH,
+        GGML_UNARY_OP_HARDSIGMOID,
+
+        GGML_UNARY_OP_COUNT,
+    };
+
+    enum ggml_object_type {
+        GGML_OBJECT_TYPE_TENSOR,
+        GGML_OBJECT_TYPE_GRAPH,
+        GGML_OBJECT_TYPE_WORK_BUFFER
+    };
+
+    enum ggml_log_level {
+        GGML_LOG_LEVEL_ERROR = 2,
+        GGML_LOG_LEVEL_WARN  = 3,
+        GGML_LOG_LEVEL_INFO  = 4,
+        GGML_LOG_LEVEL_DEBUG = 5
+    };
+
+    enum ggml_tensor_flag {
+        GGML_TENSOR_FLAG_INPUT  = 1,
+        GGML_TENSOR_FLAG_OUTPUT = 2,
+        GGML_TENSOR_FLAG_PARAM  = 4,
+    };
+
+    // ggml object
+    struct ggml_object {
+        size_t offs;
+        size_t size;
+
+        struct ggml_object * next;
+
+        enum ggml_object_type type;
+
+        char padding[4];
+    };
+
+    static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+
+    // n-dimensional tensor
+    struct ggml_tensor {
+        enum ggml_type         type;
+
+        GGML_DEPRECATED(enum ggml_backend_type backend, "use the buffer type to find the storage location of the tensor");
+
+        struct ggml_backend_buffer * buffer;
+
+        int64_t ne[GGML_MAX_DIMS]; // number of elements
+        size_t  nb[GGML_MAX_DIMS]; // stride in bytes:
+                                   // nb[0] = ggml_type_size(type)
+                                   // nb[1] = nb[0]   * (ne[0] / ggml_blck_size(type)) + padding
+                                   // nb[i] = nb[i-1] * ne[i-1]
+
+        // compute data
+        enum ggml_op op;
+
+        // op params - allocated as int32_t for alignment
+        int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
+
+        int32_t flags;
+
+        struct ggml_tensor * grad;
+        struct ggml_tensor * src[GGML_MAX_SRC];
+
+        // source tensor and offset for views
+        struct ggml_tensor * view_src;
+        size_t               view_offs;
+
+        void * data;
+
+        char name[GGML_MAX_NAME];
+
+        void * extra; // extra things e.g. for ggml-cuda.cu
+
+        // char padding[4];
+    };
+
+    static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
+
+    // Abort callback
+    // If not NULL, called before ggml computation
+    // If it returns true, the computation is aborted
+    typedef bool (*ggml_abort_callback)(void * data);
+
+    // the compute plan that needs to be prepared for ggml_graph_compute()
+    // since https://github.com/ggerganov/ggml/issues/287
+    struct ggml_cplan {
+        size_t    work_size; // size of work buffer, calculated by `ggml_graph_plan()`
+        uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
+
+        int n_threads;
+
+        // abort ggml_graph_compute when true
+        ggml_abort_callback abort_callback;
+        void *              abort_callback_data;
+    };
+
+    enum ggml_cgraph_eval_order {
+        GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
+        GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
+        GGML_CGRAPH_EVAL_ORDER_COUNT
+    };
+
+    struct ggml_hash_set {
+        size_t size;
+        struct ggml_tensor ** keys;
+    };
+
+    // computation graph
+    struct ggml_cgraph {
+        int size;
+        int n_nodes;
+        int n_leafs;
+
+        struct ggml_tensor ** nodes;
+        struct ggml_tensor ** grads;
+        struct ggml_tensor ** leafs;
+
+        struct ggml_hash_set visited_hash_table;
+
+        enum ggml_cgraph_eval_order order;
+    };
+
+    // scratch buffer
+    struct ggml_scratch {
+        size_t offs;
+        size_t size;
+        void * data;
+    };
+
+    struct ggml_init_params {
+        // memory pool
+        size_t mem_size;   // bytes
+        void * mem_buffer; // if NULL, memory will be allocated internally
+        bool   no_alloc;   // don't allocate memory for the tensor data
+    };
+
+    // numa strategies
+    enum ggml_numa_strategy {
+        GGML_NUMA_STRATEGY_DISABLED   = 0,
+        GGML_NUMA_STRATEGY_DISTRIBUTE = 1,
+        GGML_NUMA_STRATEGY_ISOLATE    = 2,
+        GGML_NUMA_STRATEGY_NUMACTL    = 3,
+        GGML_NUMA_STRATEGY_MIRROR     = 4,
+        GGML_NUMA_STRATEGY_COUNT
+    };
+
+    //
+    // GUID
+    //
+
+    // GUID types
+    typedef uint8_t ggml_guid[16];
+    typedef ggml_guid * ggml_guid_t;
+
+    GGML_API bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b);
+
+    // misc
+
+    GGML_API void    ggml_time_init(void); // call this once at the beginning of the program
+    GGML_API int64_t ggml_time_ms(void);
+    GGML_API int64_t ggml_time_us(void);
+    GGML_API int64_t ggml_cycles(void);
+    GGML_API int64_t ggml_cycles_per_ms(void);
+
+    GGML_API void    ggml_print_backtrace(void);
+
+    // accepts a UTF-8 path, even on Windows
+    GGML_API FILE *  ggml_fopen(const char * fname, const char * mode);
+
+    GGML_API void    ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems
+    GGML_API bool    ggml_is_numa(void); // true if init detected that system has >1 NUMA node
+
+    GGML_API void    ggml_print_object (const struct ggml_object * obj);
+    GGML_API void    ggml_print_objects(const struct ggml_context * ctx);
+
+    GGML_API GGML_CALL int64_t ggml_nelements   (const struct ggml_tensor * tensor);
+    GGML_API GGML_CALL int64_t ggml_nrows       (const struct ggml_tensor * tensor);
+    GGML_API GGML_CALL size_t  ggml_nbytes      (const struct ggml_tensor * tensor);
+    GGML_API           size_t  ggml_nbytes_pad  (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
+
+    GGML_API GGML_CALL int    ggml_blck_size(enum ggml_type type);
+    GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type);             // size in bytes for all elements in a block
+    GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
+
+    GGML_DEPRECATED(
+    GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
+    "use ggml_row_size() instead");
+
+    GGML_API GGML_CALL const char * ggml_type_name(enum ggml_type type);
+    GGML_API GGML_CALL const char * ggml_op_name  (enum ggml_op   op);
+    GGML_API           const char * ggml_op_symbol(enum ggml_op   op);
+
+    GGML_API           const char * ggml_unary_op_name(enum ggml_unary_op op);
+    GGML_API GGML_CALL const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
+
+    GGML_API GGML_CALL size_t  ggml_element_size(const struct ggml_tensor * tensor);
+
+    GGML_API GGML_CALL bool    ggml_is_quantized(enum ggml_type type);
+
+    // TODO: temporary until model loading of ggml examples is refactored
+    GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
+
+    GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
+    GGML_API GGML_CALL bool ggml_is_permuted  (const struct ggml_tensor * tensor);
+    GGML_API GGML_CALL bool ggml_is_empty     (const struct ggml_tensor * tensor);
+    GGML_API           bool ggml_is_scalar    (const struct ggml_tensor * tensor);
+    GGML_API           bool ggml_is_vector    (const struct ggml_tensor * tensor);
+    GGML_API           bool ggml_is_matrix    (const struct ggml_tensor * tensor);
+    GGML_API           bool ggml_is_3d        (const struct ggml_tensor * tensor);
+    GGML_API           int  ggml_n_dims       (const struct ggml_tensor * tensor); // returns 1 for scalars
+
+    GGML_API GGML_CALL bool ggml_is_contiguous  (const struct ggml_tensor * tensor);
+    GGML_API GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
+    GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
+    GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
+
+    GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
+    GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
+
+    // use this to compute the memory overhead of a tensor
+    GGML_API size_t ggml_tensor_overhead(void);
+
+    GGML_API bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes);
+
+    // main
+
+    GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
+    GGML_API void                  ggml_free(struct ggml_context * ctx);
+
+    GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);
+
+    GGML_API size_t  ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
+    GGML_API bool    ggml_get_no_alloc(struct ggml_context * ctx);
+    GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
+
+    GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);
+    GGML_API size_t  ggml_get_mem_size       (const struct ggml_context * ctx);
+    GGML_API size_t  ggml_get_max_tensor_size(const struct ggml_context * ctx);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int    n_dims,
+            const int64_t *ne);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor_1d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor_2d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0,
+            int64_t ne1);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor_3d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0,
+            int64_t ne1,
+            int64_t ne2);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor_4d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0,
+            int64_t ne1,
+            int64_t ne2,
+            int64_t ne3);
+
+    GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
+    GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
+
+    GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
+    GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
+
+    // Context tensor enumeration and lookup
+    GGML_API struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx);
+    GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
+
+    GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
+    GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
+    GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
+
+    // Converts a flat index into coordinates
+    GGML_API void    ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
+
+    GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
+    GGML_API void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
+
+    GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+    GGML_API void    ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
+
+    GGML_API float   ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
+    GGML_API void    ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
+
+    GGML_API float   ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+    GGML_API void    ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
+
+    GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
+    GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
+
+    GGML_API GGML_CALL enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+
+    GGML_API const char *         ggml_get_name   (const struct ggml_tensor * tensor);
+    GGML_API struct ggml_tensor * ggml_set_name   (      struct ggml_tensor * tensor, const char * name);
+    GGML_ATTRIBUTE_FORMAT(2, 3)
+    GGML_API struct ggml_tensor * ggml_format_name(      struct ggml_tensor * tensor, const char * fmt, ...);
+
+    //
+    // operations on tensors with backpropagation
+    //
+
+    GGML_API struct ggml_tensor * ggml_dup(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_dup_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_add(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_add_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_add_cast(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            enum   ggml_type      type);
+
+    GGML_API struct ggml_tensor * ggml_add1(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_add1_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // dst = a
+    // view(dst, nb1, nb2, nb3, offset) += b
+    // return dst
+    GGML_API struct ggml_tensor * ggml_acc(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_acc_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_sub(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_sub_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_mul(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_mul_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_div(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_div_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_sqr(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sqr_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sqrt(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sqrt_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_log(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_log_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // return scalar
+    GGML_API struct ggml_tensor * ggml_sum(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d]
+    GGML_API struct ggml_tensor * ggml_sum_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // mean along rows
+    GGML_API struct ggml_tensor * ggml_mean(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // argmax along rows
+    GGML_API struct ggml_tensor * ggml_argmax(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // if a is the same shape as b, and a is not parameter, return a
+    // otherwise, return a new tensor: repeat(a) to fit in b
+    GGML_API struct ggml_tensor * ggml_repeat(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // sums repetitions in a into shape of b
+    GGML_API struct ggml_tensor * ggml_repeat_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // concat a and b along dim
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_concat(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   dim);
+
+    GGML_API struct ggml_tensor * ggml_abs(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_abs_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sgn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sgn_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_neg(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_neg_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_step(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_step_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_tanh(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_tanh_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_elu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_elu_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_relu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_leaky_relu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a, float negative_slope, bool inplace);
+
+    GGML_API struct ggml_tensor * ggml_relu_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sigmoid(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sigmoid_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu_quick(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu_quick_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_silu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_silu_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // a - x
+    // b - dy
+    GGML_API struct ggml_tensor * ggml_silu_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // hardswish(x) = x * relu6(x + 3) / 6
+    GGML_API struct ggml_tensor * ggml_hardswish(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // hardsigmoid(x) = relu6(x + 3) / 6
+    GGML_API struct ggml_tensor * ggml_hardsigmoid(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // normalize along rows
+    GGML_API struct ggml_tensor * ggml_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_rms_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    // group normalize along ne0*ne1*n_groups
+    // used in stable-diffusion
+    // TODO: eps is hardcoded to 1e-6 for now
+    GGML_API struct ggml_tensor * ggml_group_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_groups);
+
+    GGML_API struct ggml_tensor * ggml_group_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_groups);
+
+    // a - x
+    // b - dy
+    GGML_API struct ggml_tensor * ggml_rms_norm_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            float                 eps);
+
+    // A: k columns, n rows => [ne03, ne02, n, k]
+    // B: k columns, m rows  (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k]
+    // result is n columns, m rows => [ne03 * x, ne02 * y, m, n]
+    GGML_API struct ggml_tensor * ggml_mul_mat(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // change the precision of a matrix multiplication
+    // set to GGML_PREC_F32 for higher precision (useful for phi-2)
+    GGML_API void ggml_mul_mat_set_prec(
+            struct ggml_tensor * a,
+            enum ggml_prec       prec);
+
+    // indirect matrix multiplication
+    GGML_API struct ggml_tensor * ggml_mul_mat_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * as,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * ids);
+
+    // A: m columns, n rows,
+    // B: p columns, n rows,
+    // result is m columns, p rows
+    GGML_API struct ggml_tensor * ggml_out_prod(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    //
+    // operations on tensors without backpropagation
+    //
+
+    GGML_API struct ggml_tensor * ggml_scale(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_scale_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s);
+
+    // b -> view(a,offset,nb1,nb2,3), return modified a
+    GGML_API struct ggml_tensor * ggml_set(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset);
+
+    // b -> view(a,offset,nb1,nb2,3), return view(a)
+    GGML_API struct ggml_tensor * ggml_set_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_set_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_set_1d_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                offset);
+
+    // b -> view(a,offset,nb1,nb2,3), return modified a
+    GGML_API struct ggml_tensor * ggml_set_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                offset);
+
+    // b -> view(a,offset,nb1,nb2,3), return view(a)
+    GGML_API struct ggml_tensor * ggml_set_2d_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                offset);
+
+    // a -> b, return view(b)
+    GGML_API struct ggml_tensor * ggml_cpy(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_cast(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum   ggml_type      type);
+
+    // make contiguous
+    GGML_API struct ggml_tensor * ggml_cont(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // make contiguous, with new shape
+    GGML_API struct ggml_tensor * ggml_cont_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0);
+
+    GGML_API struct ggml_tensor * ggml_cont_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1);
+
+    GGML_API struct ggml_tensor * ggml_cont_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2);
+
+    GGML_API struct ggml_tensor * ggml_cont_4d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3);
+
+    // return view(a), b specifies the new shape
+    // TODO: when we start computing gradient, make a copy instead of view
+    GGML_API struct ggml_tensor * ggml_reshape(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // return view(a)
+    // TODO: when we start computing gradient, make a copy instead of view
+    GGML_API struct ggml_tensor * ggml_reshape_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0);
+
+    GGML_API struct ggml_tensor * ggml_reshape_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1);
+
+    // return view(a)
+    // TODO: when we start computing gradient, make a copy instead of view
+    GGML_API struct ggml_tensor * ggml_reshape_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2);
+
+    GGML_API struct ggml_tensor * ggml_reshape_4d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3);
+
+    // offset in bytes
+    GGML_API struct ggml_tensor * ggml_view_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_view_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            size_t                nb1, // row stride in bytes
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_view_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            size_t                nb1, // row   stride in bytes
+            size_t                nb2, // slice stride in bytes
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_view_4d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3,
+            size_t                nb1, // row   stride in bytes
+            size_t                nb2, // slice stride in bytes
+            size_t                nb3,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_permute(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   axis0,
+            int                   axis1,
+            int                   axis2,
+            int                   axis3);
+
+    // alias for ggml_permute(ctx, a, 1, 0, 2, 3)
+    GGML_API struct ggml_tensor * ggml_transpose(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // supports 3D: a->ne[2] == b->ne[1]
+    GGML_API struct ggml_tensor * ggml_get_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_get_rows_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c);
+
+    GGML_API struct ggml_tensor * ggml_diag(
+        struct ggml_context     * ctx,
+        struct ggml_tensor      * a);
+
+    // set elements above the diagonal to -INF
+    GGML_API struct ggml_tensor * ggml_diag_mask_inf(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
+    // set elements above the diagonal to 0
+    GGML_API struct ggml_tensor * ggml_diag_mask_zero(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
+    GGML_API struct ggml_tensor * ggml_soft_max(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_soft_max_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // fused soft_max(a*scale + mask*(ALiBi slope))
+    // mask is optional
+    // max_bias = 0.0f for no ALiBi
+    GGML_API struct ggml_tensor * ggml_soft_max_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * mask,
+            float                 scale,
+            float                 max_bias);
+
+    GGML_API struct ggml_tensor * ggml_soft_max_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // rotary position embedding
+    // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
+    // if mode & 2 == 1, GPT-NeoX style
+    //
+    // b is an int32 vector with size a->ne[2], it contains the positions
+    // c is freq factors (e.g. phi3-128k), (optional)
+    GGML_API struct ggml_tensor * ggml_rope(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   n_dims,
+            int                   mode);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   n_dims,
+            int                   mode);
+
+    // custom RoPE
+    GGML_API struct ggml_tensor * ggml_rope_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow),
+        "use ggml_rope_ext instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow),
+        "use ggml_rope_ext_inplace instead");
+
+    // compute correction dims for YaRN RoPE scaling
+    GGML_CALL void ggml_rope_yarn_corr_dims(
+        int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
+
+    // rotary position embedding backward, i.e compute dx from dy
+    // a - dy
+    GGML_API struct ggml_tensor * ggml_rope_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
+    // clamp
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_clamp(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 min,
+            float                 max);
+
+    GGML_API struct ggml_tensor * ggml_im2col(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                  s0,
+            int                  s1,
+            int                  p0,
+            int                  p1,
+            int                  d0,
+            int                  d1,
+            bool                 is_2D,
+            enum ggml_type       dst_type);
+
+    GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                  s0,
+            int                  s1,
+            int                  p0,
+            int                  p1,
+            int                  d0,
+            int                  d1);
+
+    GGML_API struct ggml_tensor * ggml_conv_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   s0,  // stride
+            int                   p0,  // padding
+            int                   d0); // dilation
+
+    // conv_1d with padding = half
+    // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
+    GGML_API struct ggml_tensor* ggml_conv_1d_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   s,
+            int                   d);
+
+    GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   s0,
+            int                   p0,
+            int                   d0);
+
+    GGML_API struct ggml_tensor * ggml_conv_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   s0,
+            int                   s1,
+            int                   p0,
+            int                   p1,
+            int                   d0,
+            int                   d1);
+
+
+    // kernel size is a->ne[0] x a->ne[1]
+    // stride is equal to kernel size
+    // padding is zero
+    // example:
+    // a:     16   16    3  768
+    // b:   1024 1024    3    1
+    // res:   64   64  768    1
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // kernel size is a->ne[0] x a->ne[1]
+    // stride is 1
+    // padding is half
+    // example:
+    // a:      3    3    256  256
+    // b:     64   64    256    1
+    // res:   64   64    256    1
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   stride);
+
+    enum ggml_op_pool {
+        GGML_OP_POOL_MAX,
+        GGML_OP_POOL_AVG,
+        GGML_OP_POOL_COUNT,
+    };
+
+    GGML_API struct ggml_tensor * ggml_pool_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_op_pool     op,
+            int                   k0, // kernel size
+            int                   s0, // stride
+            int                   p0); // padding
+
+    // the result will have 2*p0 padding for the first dimension
+    // and 2*p1 padding for the second dimension
+    GGML_API struct ggml_tensor * ggml_pool_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_op_pool     op,
+            int                   k0,
+            int                   k1,
+            int                   s0,
+            int                   s1,
+            float                 p0,
+            float                 p1);
+
+    // nearest interpolate
+    // multiplies ne0 and ne1 by scale factor
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_upscale(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   scale_factor);
+
+    // nearest interpolate
+    // nearest interpolate to specified dimensions
+    // used in tortoise.cpp
+    GGML_API struct ggml_tensor * ggml_upscale_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   ne0,
+            int                   ne1,
+            int                   ne2,
+            int                   ne3);
+
+    // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
+    GGML_API struct ggml_tensor * ggml_pad(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                  p0,
+            int                  p1,
+            int                  p2,
+            int                  p3);
+
+    // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
+    // timesteps: [N,]
+    // return: [N, dim]
+    GGML_API struct ggml_tensor * ggml_timestep_embedding(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * timesteps,
+            int                   dim,
+            int                   max_period);
+
+    // sort rows
+    enum ggml_sort_order {
+        GGML_SORT_ORDER_ASC,
+        GGML_SORT_ORDER_DESC,
+    };
+
+    GGML_API struct ggml_tensor * ggml_argsort(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_sort_order  order);
+
+    GGML_API struct ggml_tensor * ggml_arange(
+            struct ggml_context * ctx,
+            float                 start,
+            float                 stop,
+            float                 step);
+
+    // top k elements per row
+    GGML_API struct ggml_tensor * ggml_top_k(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   k);
+
+#define GGML_KQ_MASK_PAD 32
+
+    // q:    [n_embd, n_batch,     n_head,    1]
+    // k:    [n_embd, n_kv,        n_head_kv, 1]
+    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
+    // mask: [n_kv,   n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
+    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * mask,
+            float                 scale,
+            float                 max_bias);
+
+    GGML_API void ggml_flash_attn_ext_set_prec(
+            struct ggml_tensor * a,
+            enum ggml_prec       prec);
+
+    // TODO: needs to be adapted to ggml_flash_attn_ext
+    GGML_API struct ggml_tensor * ggml_flash_attn_back(
+           struct ggml_context * ctx,
+           struct ggml_tensor  * q,
+           struct ggml_tensor  * k,
+           struct ggml_tensor  * v,
+           struct ggml_tensor  * d,
+           bool                  masked);
+
+    GGML_API struct ggml_tensor * ggml_ssm_conv(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * s,
+            struct ggml_tensor  * x,
+            struct ggml_tensor  * c,
+            struct ggml_tensor  * sq);
+
+    GGML_API struct ggml_tensor * ggml_ssm_scan(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * s,
+            struct ggml_tensor  * x,
+            struct ggml_tensor  * dt,
+            struct ggml_tensor  * A,
+            struct ggml_tensor  * B,
+            struct ggml_tensor  * C,
+            struct ggml_tensor  * sq);
+
+    // partition into non-overlapping windows with padding if needed
+    // example:
+    // a:   768   64   64    1
+    // w:    14
+    // res: 768   14   14    25
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_win_part(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   w);
+
+    // reverse of ggml_win_part
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_win_unpart(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   w0,
+            int                   h0,
+            int                   w);
+
+    GGML_API struct ggml_tensor * ggml_unary(
+            struct ggml_context * ctx,
+             struct ggml_tensor * a,
+             enum ggml_unary_op op);
+
+    GGML_API struct ggml_tensor * ggml_unary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_unary_op op);
+
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_get_rel_pos(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   qh,
+            int                   kh);
+
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_add_rel_pos(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * pw,
+            struct ggml_tensor  * ph);
+
+    GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * pw,
+            struct ggml_tensor  * ph);
+
+    // custom operators
+
+    typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
+    typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
+
+    typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *);
+    typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
+    typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
+            struct ggml_context        * ctx,
+            struct ggml_tensor         * a,
+                   ggml_unary_op_f32_t   fun),
+        "use ggml_map_custom1 instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
+            struct ggml_context        * ctx,
+            struct ggml_tensor         * a,
+                   ggml_unary_op_f32_t   fun),
+        "use ggml_map_custom1_inplace instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
+            struct ggml_context         * ctx,
+            struct ggml_tensor          * a,
+            struct ggml_tensor          * b,
+                   ggml_binary_op_f32_t   fun),
+        "use ggml_map_custom2 instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
+            struct ggml_context         * ctx,
+            struct ggml_tensor          * a,
+            struct ggml_tensor          * b,
+                   ggml_binary_op_f32_t   fun),
+        "use ggml_map_custom2_inplace instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+                   ggml_custom1_op_f32_t   fun),
+        "use ggml_map_custom1 instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+                   ggml_custom1_op_f32_t   fun),
+        "use ggml_map_custom1_inplace instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+                   ggml_custom2_op_f32_t   fun),
+        "use ggml_map_custom2 instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+                   ggml_custom2_op_f32_t   fun),
+        "use ggml_map_custom2_inplace instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+            struct ggml_tensor           * c,
+                   ggml_custom3_op_f32_t   fun),
+        "use ggml_map_custom3 instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+            struct ggml_tensor           * c,
+                   ggml_custom3_op_f32_t   fun),
+        "use ggml_map_custom3_inplace instead");
+
+    // custom operators v2
+
+    typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
+    typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
+    typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
+
+    #define GGML_N_TASKS_MAX -1
+
+    GGML_API struct ggml_tensor * ggml_map_custom1(
+            struct ggml_context   * ctx,
+            struct ggml_tensor    * a,
+            ggml_custom1_op_t       fun,
+            int                     n_tasks,
+            void                  * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom1_inplace(
+            struct ggml_context   * ctx,
+            struct ggml_tensor    * a,
+            ggml_custom1_op_t       fun,
+            int                     n_tasks,
+            void                  * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom2(
+            struct ggml_context   * ctx,
+            struct ggml_tensor    * a,
+            struct ggml_tensor    * b,
+            ggml_custom2_op_t       fun,
+            int                     n_tasks,
+            void                  * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom2_inplace(
+            struct ggml_context   * ctx,
+            struct ggml_tensor    * a,
+            struct ggml_tensor    * b,
+            ggml_custom2_op_t       fun,
+            int                     n_tasks,
+            void                  * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom3(
+            struct ggml_context   * ctx,
+            struct ggml_tensor    * a,
+            struct ggml_tensor    * b,
+            struct ggml_tensor    * c,
+            ggml_custom3_op_t       fun,
+            int                     n_tasks,
+            void                  * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom3_inplace(
+            struct ggml_context   * ctx,
+            struct ggml_tensor    * a,
+            struct ggml_tensor    * b,
+            struct ggml_tensor    * c,
+            ggml_custom3_op_t       fun,
+            int                     n_tasks,
+            void                  * userdata);
+
+    // loss function
+
+    GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
+            struct ggml_context         * ctx,
+            struct ggml_tensor          * a,
+            struct ggml_tensor          * b);
+
+    GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
+            struct ggml_context         * ctx,
+            struct ggml_tensor          * a,
+            struct ggml_tensor          * b,
+            struct ggml_tensor          * c);
+
+    //
+    // automatic differentiation
+    //
+
+    GGML_API void ggml_set_param(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * tensor);
+
+
+    GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
+
+    // graph allocation in a context
+    GGML_API struct ggml_cgraph * ggml_new_graph         (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
+    GGML_API struct ggml_cgraph * ggml_new_graph_custom  (struct ggml_context * ctx, size_t size, bool grads);
+    GGML_API struct ggml_cgraph * ggml_graph_dup         (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+    GGML_API struct ggml_cgraph   ggml_graph_view        (struct ggml_cgraph * cgraph, int i0, int i1);
+    GGML_API void                 ggml_graph_cpy         (struct ggml_cgraph * src, struct ggml_cgraph * dst);
+    GGML_API void                 ggml_graph_reset       (struct ggml_cgraph * cgraph);  // zero grads
+    GGML_API void                 ggml_graph_clear       (struct ggml_cgraph * cgraph);
+
+    GGML_API size_t ggml_graph_overhead(void);
+    GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
+
+    // ggml_graph_plan() has to be called before ggml_graph_compute()
+    // when plan.work_size > 0, caller must allocate memory for plan.work_data
+    GGML_API struct ggml_cplan ggml_graph_plan            (const struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
+    GGML_API enum ggml_status  ggml_graph_compute         (      struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+    // same as ggml_graph_compute() but the work data is allocated as a part of the context
+    // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
+    GGML_API enum ggml_status  ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
+
+    GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
+
+    GGML_API void                 ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
+    GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
+
+    // print info and performance information for the graph
+    GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
+
+    // dump the graph into a file using the dot format
+    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
+
+    // build gradient checkpointing backward graph gb for gf using provided checkpoints
+    // gb_tmp will contain original backward graph with rewritten backward process nodes,
+    // but without the second forward pass nodes.
+    GGML_API void ggml_build_backward_gradient_checkpointing(
+            struct ggml_context   * ctx,
+            struct ggml_cgraph    * gf,
+            struct ggml_cgraph    * gb,
+            struct ggml_cgraph    * gb_tmp,
+            struct ggml_tensor  * * checkpoints,
+            int                     n_checkpoints);
+    //
+    // optimization
+    //
+
+    // optimization methods
+    enum ggml_opt_type {
+        GGML_OPT_TYPE_ADAM,
+        GGML_OPT_TYPE_LBFGS,
+    };
+
+    // linesearch methods
+    enum ggml_linesearch {
+        GGML_LINESEARCH_DEFAULT = 1,
+
+        GGML_LINESEARCH_BACKTRACKING_ARMIJO       = 0,
+        GGML_LINESEARCH_BACKTRACKING_WOLFE        = 1,
+        GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
+    };
+
+    // optimization return values
+    enum ggml_opt_result {
+        GGML_OPT_RESULT_OK = 0,
+        GGML_OPT_RESULT_DID_NOT_CONVERGE,
+        GGML_OPT_RESULT_NO_CONTEXT,
+        GGML_OPT_RESULT_INVALID_WOLFE,
+        GGML_OPT_RESULT_FAIL,
+        GGML_OPT_RESULT_CANCEL,
+
+        GGML_LINESEARCH_FAIL = -128,
+        GGML_LINESEARCH_MINIMUM_STEP,
+        GGML_LINESEARCH_MAXIMUM_STEP,
+        GGML_LINESEARCH_MAXIMUM_ITERATIONS,
+        GGML_LINESEARCH_INVALID_PARAMETERS,
+    };
+
+    typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
+    typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
+
+    // optimization parameters
+    //
+    //   see ggml.c (ggml_opt_default_params) for default values
+    //
+    struct ggml_opt_params {
+        enum ggml_opt_type type;
+
+        size_t graph_size;
+
+        int n_threads;
+
+        // delta-based convergence test
+        //
+        //   if past == 0 - disabled
+        //   if past > 0:
+        //     stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
+        //
+        int past;
+        float delta;
+
+        // maximum number of iterations without improvement
+        //
+        //   if 0 - disabled
+        //   if > 0:
+        //     assume convergence if no cost improvement in this number of iterations
+        //
+        int max_no_improvement;
+
+        bool print_forward_graph;
+        bool print_backward_graph;
+
+        int n_gradient_accumulation;
+
+        // ADAM parameters
+        struct {
+            int n_iter;
+
+            float sched; // schedule multiplier (fixed, decay or warmup)
+            float decay; // weight decay for AdamW, use 0.0f to disable
+            int   decay_min_ndim; // minimum number of tensor dimension to apply weight decay
+            float alpha; // learning rate
+            float beta1;
+            float beta2;
+            float eps;   // epsilon for numerical stability
+            float eps_f; // epsilon for convergence test
+            float eps_g; // epsilon for convergence test
+            float gclip; // gradient clipping
+        } adam;
+
+        // LBFGS parameters
+        struct {
+            int m; // number of corrections to approximate the inv. Hessian
+            int n_iter;
+            int max_linesearch;
+
+            float eps;      // convergence tolerance
+            float ftol;     // line search tolerance
+            float wolfe;
+            float min_step;
+            float max_step;
+
+            enum ggml_linesearch linesearch;
+        } lbfgs;
+    };
+
+    struct ggml_opt_context {
+        struct ggml_context * ctx;
+        struct ggml_opt_params params;
+
+        int iter;
+        int64_t nx; // number of parameter elements
+
+        bool just_initialized;
+
+        float loss_before;
+        float loss_after;
+
+        struct {
+            struct ggml_tensor * g;  // current gradient
+            struct ggml_tensor * m;  // first moment
+            struct ggml_tensor * v;  // second moment
+            struct ggml_tensor * pf; // past function values
+            float fx_best;
+            float fx_prev;
+            int n_no_improvement;
+        } adam;
+
+        struct {
+            struct ggml_tensor * x;    // current parameters
+            struct ggml_tensor * xp;   // previous parameters
+            struct ggml_tensor * g;    // current gradient
+            struct ggml_tensor * gp;   // previous gradient
+            struct ggml_tensor * d;    // search direction
+            struct ggml_tensor * pf;   // past function values
+            struct ggml_tensor * lmal; // the L-BFGS memory alpha
+            struct ggml_tensor * lmys; // the L-BFGS memory ys
+            struct ggml_tensor * lms;  // the L-BFGS memory s
+            struct ggml_tensor * lmy;  // the L-BFGS memory y
+            float fx_best;
+            float step;
+            int j;
+            int k;
+            int end;
+            int n_no_improvement;
+        } lbfgs;
+    };
+
+    GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
+
+    // optimize the function defined by the tensor f
+    GGML_API enum ggml_opt_result ggml_opt(
+            struct ggml_context * ctx,
+            struct ggml_opt_params params,
+            struct ggml_tensor * f);
+
+    // initialize optimizer context
+    GGML_API void ggml_opt_init(
+            struct ggml_context     * ctx,
+            struct ggml_opt_context * opt,
+            struct ggml_opt_params    params,
+            int64_t                   nx);
+
+    // continue optimizing the function defined by the tensor f
+    GGML_API enum ggml_opt_result ggml_opt_resume(
+            struct ggml_context * ctx,
+            struct ggml_opt_context * opt,
+            struct ggml_tensor * f);
+
+    // continue optimizing the function defined by the tensor f
+    GGML_API enum ggml_opt_result ggml_opt_resume_g(
+            struct ggml_context * ctx,
+            struct ggml_opt_context * opt,
+            struct ggml_tensor * f,
+            struct ggml_cgraph * gf,
+            struct ggml_cgraph * gb,
+            ggml_opt_callback callback,
+            void * callback_data);
+
+    //
+    // tensor flags
+    //
+    GGML_API void ggml_set_input(struct ggml_tensor * tensor);
+    GGML_API void ggml_set_output(struct ggml_tensor * tensor);
+
+    //
+    // quantization
+    //
+
+    // - ggml_quantize_init can be called multiple times with the same type
+    //   it will only initialize the quantization tables for the first call or after ggml_quantize_free
+    //   automatically called by ggml_quantize_chunk for convenience
+    //
+    // - ggml_quantize_free will free any memory allocated by ggml_quantize_init
+    //   call this at the end of the program to avoid memory leaks
+    //
+    // note: these are thread-safe
+    //
+    GGML_API void ggml_quantize_init(enum ggml_type type);
+    GGML_API void ggml_quantize_free(void);
+
+    // some quantization type cannot be used without an importance matrix
+    GGML_API bool ggml_quantize_requires_imatrix(enum ggml_type type);
+
+    // calls ggml_quantize_init internally (i.e. can allocate memory)
+    GGML_API size_t ggml_quantize_chunk(
+            enum ggml_type   type,
+               const float * src,
+                      void * dst,
+                   int64_t   start,
+                   int64_t   nrows,
+                   int64_t   n_per_row,
+               const float * imatrix);
+
+    //
+    // gguf
+    //
+
+    enum gguf_type {
+        GGUF_TYPE_UINT8   = 0,
+        GGUF_TYPE_INT8    = 1,
+        GGUF_TYPE_UINT16  = 2,
+        GGUF_TYPE_INT16   = 3,
+        GGUF_TYPE_UINT32  = 4,
+        GGUF_TYPE_INT32   = 5,
+        GGUF_TYPE_FLOAT32 = 6,
+        GGUF_TYPE_BOOL    = 7,
+        GGUF_TYPE_STRING  = 8,
+        GGUF_TYPE_ARRAY   = 9,
+        GGUF_TYPE_UINT64  = 10,
+        GGUF_TYPE_INT64   = 11,
+        GGUF_TYPE_FLOAT64 = 12,
+        GGUF_TYPE_COUNT,       // marks the end of the enum
+    };
+
+    struct gguf_context;
+
+    struct gguf_init_params {
+        bool no_alloc;
+
+        // if not NULL, create a ggml_context and allocate the tensor data in it
+        struct ggml_context ** ctx;
+    };
+
+    GGML_API struct gguf_context * gguf_init_empty(void);
+    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
+    //GGML_API struct gguf_context * gguf_init_from_buffer(..);
+
+    GGML_API void gguf_free(struct gguf_context * ctx);
+
+    GGML_API const char * gguf_type_name(enum gguf_type type);
+
+    GGML_API int    gguf_get_version    (const struct gguf_context * ctx);
+    GGML_API size_t gguf_get_alignment  (const struct gguf_context * ctx);
+    GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
+    GGML_API void * gguf_get_data       (const struct gguf_context * ctx);
+
+    GGML_API int          gguf_get_n_kv(const struct gguf_context * ctx);
+    GGML_API int          gguf_find_key(const struct gguf_context * ctx, const char * key);
+    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id);
+
+    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id);
+    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id);
+
+    // will abort if the wrong type is used for the key
+    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int key_id);
+    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int key_id);
+    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int key_id);
+    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int key_id);
+    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int key_id);
+    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int key_id);
+    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int key_id);
+    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int key_id);
+    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int key_id);
+    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
+    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
+    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
+    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
+    GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int key_id);
+    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
+    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
+
+    GGML_API int            gguf_get_n_tensors    (const struct gguf_context * ctx);
+    GGML_API int            gguf_find_tensor      (const struct gguf_context * ctx, const char * name);
+    GGML_API size_t         gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
+    GGML_API char *         gguf_get_tensor_name  (const struct gguf_context * ctx, int i);
+    GGML_API enum ggml_type gguf_get_tensor_type  (const struct gguf_context * ctx, int i);
+
+    // removes key if it exists
+    GGML_API void gguf_remove_key(struct gguf_context * ctx, const char * key);
+
+    // overrides existing values or adds a new one
+    GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t  val);
+    GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t   val);
+    GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
+    GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t  val);
+    GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
+    GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t  val);
+    GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float    val);
+    GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
+    GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t  val);
+    GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double   val);
+    GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool     val);
+    GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
+    GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
+    GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
+
+    // set or add KV pairs from another context
+    GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
+
+    // manage tensor info
+    GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
+    GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
+    GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
+
+    // writing gguf files can be done in 2 ways:
+    //
+    // - write the entire gguf_context to a binary file in a single pass:
+    //
+    //   gguf_write_to_file(ctx, fname);
+    //
+    // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
+    //
+    //   FILE * f = fopen(fname, "wb");
+    //   fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
+    //   fwrite(f, ...);
+    //   void * data = gguf_meta_get_meta_data(ctx);
+    //   fseek(f, 0, SEEK_SET);
+    //   fwrite(f, data, gguf_get_meta_size(ctx));
+    //   free(data);
+    //   fclose(f);
+    //
+
+    // write the entire context to a binary file
+    GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
+
+    // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
+    GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
+    GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);
+
+    //
+    // system info
+    //
+
+    GGML_API int ggml_cpu_has_avx        (void);
+    GGML_API int ggml_cpu_has_avx_vnni   (void);
+    GGML_API int ggml_cpu_has_avx2       (void);
+    GGML_API int ggml_cpu_has_avx512     (void);
+    GGML_API int ggml_cpu_has_avx512_vbmi(void);
+    GGML_API int ggml_cpu_has_avx512_vnni(void);
+    GGML_API int ggml_cpu_has_avx512_bf16(void);
+    GGML_API int ggml_cpu_has_fma        (void);
+    GGML_API int ggml_cpu_has_neon       (void);
+    GGML_API int ggml_cpu_has_sve        (void);
+    GGML_API int ggml_cpu_has_arm_fma    (void);
+    GGML_API int ggml_cpu_has_metal      (void);
+    GGML_API int ggml_cpu_has_f16c       (void);
+    GGML_API int ggml_cpu_has_fp16_va    (void);
+    GGML_API int ggml_cpu_has_wasm_simd  (void);
+    GGML_API int ggml_cpu_has_blas       (void);
+    GGML_API int ggml_cpu_has_cuda       (void);
+    GGML_API int ggml_cpu_has_vulkan     (void);
+    GGML_API int ggml_cpu_has_kompute    (void);
+    GGML_API int ggml_cpu_has_gpublas    (void);
+    GGML_API int ggml_cpu_has_sse3       (void);
+    GGML_API int ggml_cpu_has_ssse3      (void);
+    GGML_API int ggml_cpu_has_sycl       (void);
+    GGML_API int ggml_cpu_has_rpc        (void);
+    GGML_API int ggml_cpu_has_vsx        (void);
+    GGML_API int ggml_cpu_has_matmul_int8(void);
+
+    //
+    // Internal types and functions exposed for tests and benchmarks
+    //
+
+#ifdef  __cplusplus
+// restrict not standard in C++
+#define GGML_RESTRICT
+#else
+#define GGML_RESTRICT restrict
+#endif
+    typedef void (*ggml_to_float_t)  (const void  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+    typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void  * GGML_RESTRICT y, int64_t k);
+    typedef void (*ggml_vec_dot_t)   (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
+                                      const void * GGML_RESTRICT y, size_t by, int nrc);
+
+    typedef struct {
+        const char      * type_name;
+        int               blck_size;
+        size_t            type_size;
+        bool              is_quantized;
+        ggml_to_float_t   to_float;
+        ggml_from_float_t from_float;
+        ggml_from_float_t from_float_reference;
+        ggml_vec_dot_t    vec_dot;
+        enum ggml_type    vec_dot_type;
+        int64_t           nrows; // number of rows to process simultaneously;
+    } ggml_type_traits_t;
+
+    GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/include/ggml/ggml-alloc.h b/include/ggml/ggml-alloc.h
deleted file mode 100644 (file)
index 434c13b..0000000
+++ /dev/null
@@ -1,76 +0,0 @@
-#pragma once
-
-#include "ggml.h"
-
-#ifdef  __cplusplus
-extern "C" {
-#endif
-
-typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
-typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
-typedef struct ggml_backend * ggml_backend_t;
-
-// Tensor allocator
-struct ggml_tallocr {
-    ggml_backend_buffer_t buffer;
-    void * base;
-    size_t alignment;
-    size_t offset;
-};
-
-GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer);
-GGML_API void                ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);
-
-// Graph allocator
-/*
-  Example usage:
-    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_bacckend_cpu_buffer_type());
-
-    // optional: create a worst-case graph and reserve the buffers to avoid reallocations
-    ggml_gallocr_reserve(galloc, build_graph(max_batch));
-
-    // allocate the graph
-    struct ggml_cgraph * graph = build_graph(batch);
-    ggml_gallocr_alloc_graph(galloc, graph);
-
-    printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0));
-
-    // evaluate the graph
-    ggml_backend_graph_compute(backend, graph);
-*/
-
-// special tensor flags for use with the graph allocator:
-//   ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses
-//   ggml_set_output(): output tensors are never freed and never overwritten
-
-typedef struct ggml_gallocr * ggml_gallocr_t;
-
-GGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft);
-GGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs);
-GGML_API void           ggml_gallocr_free(ggml_gallocr_t galloc);
-
-// pre-allocate buffers from a measure graph - does not allocate or modify the graph
-// call with a worst-case graph to avoid buffer reallocations
-// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
-// returns false if the buffer allocation failed
-GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
-GGML_API bool ggml_gallocr_reserve_n(
-    ggml_gallocr_t galloc,
-    struct ggml_cgraph * graph,
-    const int * node_buffer_ids,
-    const int * leaf_buffer_ids);
-
-// automatic reallocation if the topology changes when using a single buffer
-// returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)
-GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
-
-GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
-
-// Utils
-// Create a buffer and allocate all the tensors in a ggml_context
-GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
-GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);
-
-#ifdef  __cplusplus
-}
-#endif
diff --git a/include/ggml/ggml-backend.h b/include/ggml/ggml-backend.h
deleted file mode 100644 (file)
index 4a38eeb..0000000
+++ /dev/null
@@ -1,236 +0,0 @@
-#pragma once
-
-#include "ggml.h"
-#include "ggml-alloc.h"
-
-#ifdef  __cplusplus
-extern "C" {
-#endif
-
-    typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
-    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
-    typedef struct ggml_backend_event * ggml_backend_event_t;
-    typedef struct ggml_backend * ggml_backend_t;
-    typedef void * ggml_backend_graph_plan_t;
-
-    //
-    // Backend buffer
-    //
-
-    // buffer type
-    GGML_API           const char *          ggml_backend_buft_name            (ggml_backend_buffer_type_t buft);
-    GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer    (ggml_backend_buffer_type_t buft, size_t size);
-    GGML_API           size_t                ggml_backend_buft_get_alignment   (ggml_backend_buffer_type_t buft);
-    GGML_API           size_t                ggml_backend_buft_get_max_size    (ggml_backend_buffer_type_t buft);
-    GGML_API GGML_CALL size_t                ggml_backend_buft_get_alloc_size  (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
-    GGML_API           bool                  ggml_backend_buft_is_host         (ggml_backend_buffer_type_t buft);
-
-    // buffer
-    enum ggml_backend_buffer_usage {
-        GGML_BACKEND_BUFFER_USAGE_ANY = 0,
-        GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,
-    };
-
-    GGML_API           const char *               ggml_backend_buffer_name          (ggml_backend_buffer_t buffer);
-    GGML_API           void                       ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
-    GGML_API           void *                     ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
-    GGML_API           size_t                     ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);
-    GGML_API GGML_CALL void                       ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-    GGML_API           size_t                     ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
-    GGML_API           size_t                     ggml_backend_buffer_get_max_size  (ggml_backend_buffer_t buffer);
-    GGML_API           size_t                     ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-    GGML_API           void                       ggml_backend_buffer_clear         (ggml_backend_buffer_t buffer, uint8_t value);
-    GGML_API           bool                       ggml_backend_buffer_is_host       (ggml_backend_buffer_t buffer);
-    GGML_API           void                       ggml_backend_buffer_set_usage     (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
-    GGML_API           ggml_backend_buffer_type_t ggml_backend_buffer_get_type      (ggml_backend_buffer_t buffer);
-    GGML_API           void                       ggml_backend_buffer_reset         (ggml_backend_buffer_t buffer);
-
-    //
-    // Backend
-    //
-
-    GGML_API ggml_guid_t  ggml_backend_guid(ggml_backend_t backend);
-    GGML_API const char * ggml_backend_name(ggml_backend_t backend);
-    GGML_API void         ggml_backend_free(ggml_backend_t backend);
-
-    GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
-    GGML_API ggml_backend_buffer_t      ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
-    GGML_API size_t                     ggml_backend_get_alignment(ggml_backend_t backend);
-    GGML_API size_t                     ggml_backend_get_max_size(ggml_backend_t backend);
-
-    GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-    GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
-
-    GGML_API GGML_CALL void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-    GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
-
-    GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
-
-    GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
-    GGML_API void                      ggml_backend_graph_plan_free  (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
-
-    GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
-    GGML_API enum ggml_status ggml_backend_graph_compute      (ggml_backend_t backend, struct ggml_cgraph * cgraph);
-    GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
-    GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
-    GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
-    GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op);
-
-    // tensor copy between different backends
-    GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
-
-    // asynchronous copy
-    // the copy is performed after all the currently queued operations in backend_src
-    // backend_dst will wait for the copy to complete before performing other operations
-    // automatic fallback to sync copy if async is not supported
-    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
-
-    // events
-    GGML_API ggml_backend_event_t   ggml_backend_event_new        (ggml_backend_t backend);
-    GGML_API void                   ggml_backend_event_free       (ggml_backend_event_t event);
-    GGML_API void                   ggml_backend_event_record     (ggml_backend_event_t event);
-    GGML_API void                   ggml_backend_event_synchronize(ggml_backend_event_t event);
-    GGML_API void                   ggml_backend_event_wait       (ggml_backend_t backend, ggml_backend_event_t event);
-
-    //
-    // CPU backend
-    //
-
-    GGML_API ggml_backend_t ggml_backend_cpu_init(void);
-
-    GGML_API GGML_CALL bool ggml_backend_is_cpu                (ggml_backend_t backend);
-    GGML_API           void ggml_backend_cpu_set_n_threads     (ggml_backend_t backend_cpu, int n_threads);
-    GGML_API           void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
-
-    // Create a backend buffer from an existing pointer
-    GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
-
-    GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
-
-#ifdef GGML_USE_CPU_HBM
-    GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
-#endif
-
-    //
-    // Backend registry
-    //
-
-    // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
-
-    GGML_API size_t                     ggml_backend_reg_get_count(void);
-    GGML_API size_t                     ggml_backend_reg_find_by_name(const char * name);
-    GGML_API ggml_backend_t             ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional)
-    GGML_API const char *               ggml_backend_reg_get_name(size_t i);
-    GGML_API ggml_backend_t             ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
-    GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
-    GGML_API ggml_backend_buffer_t      ggml_backend_reg_alloc_buffer(size_t i, size_t size);
-
-    //
-    // Backend scheduler
-    //
-
-    // The backend scheduler allows for multiple backends to be used together
-    // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
-    // The backends are selected based on:
-    // - the backend that supports the operation
-    // - the location of the pre-allocated tensors (e.g. the weights)
-    /*
-      Example usage:
-
-        // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned
-        // preferrably to run on the same backend as the buffer
-        ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
-
-        sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false);
-
-        // initialize buffers from a max size graph (optional)
-        reserve_graph = build_graph(sched, max_batch_size);
-
-        // manually assign nodes to a backend (optional, should not be needed in most cases)
-        struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
-        ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu);
-
-        ggml_backend_sched_reserve(sched, reserve_graph);
-
-        // compute
-        graph = build_graph(sched);
-        ggml_backend_sched_graph_compute(sched, graph);
-
-        // if there are graph inputs:
-        ggml_backend_sched_reset(sched);
-        ggml_backend_sched_alloc_graph(sched, graph);
-        ggml_backend_tensor_set(input_tensor, ...);
-        ggml_backend_sched_graph_compute(sched, graph);
-    }
-    */
-
-    struct ggml_backend_sched;
-    typedef struct ggml_backend_sched * ggml_backend_sched_t;
-
-    // when ask == true, the scheduler wants to know if the user wants to observe this node
-    // this allows the scheduler to batch nodes together in order to evaluate them in a single call
-    //
-    // when ask == false, the scheduler is passing the node tensor to the user for observation
-    // if the user returns false, the scheduler will cancel the graph compute
-    //
-    typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
-
-    // Initialize a backend scheduler
-    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
-    GGML_API void                 ggml_backend_sched_free(ggml_backend_sched_t sched);
-
-    // Initialize backend buffers from a measure graph
-    GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
-
-    GGML_API int                  ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
-    GGML_API ggml_backend_t       ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
-
-    // Get the number of splits of the last graph
-    GGML_API int                  ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
-    GGML_API int                  ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);
-
-    GGML_API size_t               ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
-
-    GGML_API void                 ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
-    GGML_API ggml_backend_t       ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
-
-    // Allocate and compute graph on the backend scheduler
-    GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
-    GGML_API enum ggml_status     ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
-    GGML_API enum ggml_status     ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
-    GGML_API void                 ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
-
-    // Reset all assignments and allocators - must be called before changing the node backends
-    GGML_API void                 ggml_backend_sched_reset(ggml_backend_sched_t sched);
-
-    // Set a callback to be called for each resulting node during graph compute
-    GGML_API void                 ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
-
-    //
-    // Utils
-    //
-
-    struct ggml_backend_graph_copy {
-        ggml_backend_buffer_t buffer;
-        struct ggml_context * ctx_allocated;
-        struct ggml_context * ctx_unallocated;
-        struct ggml_cgraph * graph;
-    };
-
-    // Copy a graph to a different backend
-    GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
-    GGML_API void                           ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
-
-    typedef bool (*GGML_CALL ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
-
-    // Compare the output of two backends
-    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
-
-    // Tensor initialization
-    GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
-    GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
-
-
-#ifdef  __cplusplus
-}
-#endif
diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h
deleted file mode 100644 (file)
index 13502a3..0000000
+++ /dev/null
@@ -1,2452 +0,0 @@
-#pragma once
-
-//
-// GGML Tensor Library
-//
-// This documentation is still a work in progress.
-// If you wish some specific topics to be covered, feel free to drop a comment:
-//
-//   https://github.com/ggerganov/whisper.cpp/issues/40
-//
-// ## Overview
-//
-// This library implements:
-//
-//  - a set of tensor operations
-//  - automatic differentiation
-//  - basic optimization algorithms
-//
-// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes,
-// but is not limited to, the following:
-//
-//  - linear regression
-//  - support vector machines
-//  - neural networks
-//
-// The library allows the user to define a certain function using the available tensor operations. This function
-// definition is represented internally via a computation graph. Each tensor operation in the function definition
-// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the
-// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized
-// using one of the available optimization algorithms.
-//
-// For example, here we define the function: f(x) = a*x^2 + b
-//
-//   {
-//       struct ggml_init_params params = {
-//           .mem_size   = 16*1024*1024,
-//           .mem_buffer = NULL,
-//       };
-//
-//       // memory allocation happens here
-//       struct ggml_context * ctx = ggml_init(params);
-//
-//       struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
-//
-//       ggml_set_param(ctx, x); // x is an input variable
-//
-//       struct ggml_tensor * a  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
-//       struct ggml_tensor * b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
-//       struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
-//       struct ggml_tensor * f  = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
-//
-//       ...
-//   }
-//
-// Notice that the function definition above does not involve any actual computation. The computation is performed only
-// when the user explicitly requests it. For example, to compute the function's value at x = 2.0:
-//
-//   {
-//       ...
-//
-//       struct ggml_cgraph * gf = ggml_new_graph(ctx);
-//       ggml_build_forward_expand(gf, f);
-//
-//       // set the input variable and parameter values
-//       ggml_set_f32(x, 2.0f);
-//       ggml_set_f32(a, 3.0f);
-//       ggml_set_f32(b, 4.0f);
-//
-//       ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
-//
-//       printf("f = %f\n", ggml_get_f32_1d(f, 0));
-//
-//       ...
-//   }
-//
-// The actual computation is performed in the ggml_graph_compute() function.
-//
-// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the
-// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know
-// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory
-// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was
-// actually needed.
-//
-// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic
-// differentiation and optimization algorithms.
-//
-// The described approach allows to define the function graph once and then compute its forward or backward graphs
-// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way
-// the user can avoid the memory allocation overhead at runtime.
-//
-// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class
-// citizens, but in theory the library can be extended to support FP8 and integer data types.
-//
-// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary
-// and binary operations. Most of the available operations fall into one of these two categories. With time, it became
-// clear that the library needs to support more complex operations. The way to support these operations is not clear
-// yet, but a few examples are demonstrated in the following operations:
-//
-//   - ggml_permute()
-//   - ggml_conv_1d_1s()
-//   - ggml_conv_1d_2s()
-//
-// For each tensor operator, the library implements a forward and backward computation function. The forward function
-// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the
-// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a
-// calculus class, or watch the following video:
-//
-//   What is Automatic Differentiation?
-//   https://www.youtube.com/watch?v=wG_nF1awSSY
-//
-//
-// ## Tensor data (struct ggml_tensor)
-//
-// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of
-// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains
-// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example:
-//
-//   {
-//       struct ggml_tensor * c = ggml_add(ctx, a, b);
-//
-//       assert(c->src[0] == a);
-//       assert(c->src[1] == b);
-//   }
-//
-// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the
-// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows
-// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and
-// permutation. All tensor operations have to take the stride into account and not assume that the tensor is
-// contiguous in memory.
-//
-// The data of the tensor is accessed via the "data" pointer. For example:
-//
-//   {
-//       const int nx = 2;
-//       const int ny = 3;
-//
-//       struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny);
-//
-//       for (int y = 0; y < ny; y++) {
-//           for (int x = 0; x < nx; x++) {
-//               *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y;
-//           }
-//       }
-//
-//       ...
-//   }
-//
-// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used.
-//
-// ## The matrix multiplication operator (ggml_mul_mat)
-//
-// TODO
-//
-//
-// ## Multi-threading
-//
-// TODO
-//
-//
-// ## Overview of ggml.c
-//
-// TODO
-//
-//
-// ## SIMD optimizations
-//
-// TODO
-//
-//
-// ## Debugging ggml
-//
-// TODO
-//
-//
-
-#ifdef GGML_SHARED
-#    if defined(_WIN32) && !defined(__MINGW32__)
-#        ifdef GGML_BUILD
-#            define GGML_API __declspec(dllexport)
-#        else
-#            define GGML_API __declspec(dllimport)
-#        endif
-#    else
-#        define GGML_API __attribute__ ((visibility ("default")))
-#    endif
-#else
-#    define GGML_API
-#endif
-
-#ifdef GGML_MULTIPLATFORM
-#    if defined(_WIN32)
-#        define GGML_CALL
-#    else
-#        define GGML_CALL __attribute__((__ms_abi__))
-#    endif
-#else
-#    define GGML_CALL
-#endif
-
-// TODO: support for clang
-#ifdef __GNUC__
-#    define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
-#elif defined(_MSC_VER)
-#    define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
-#else
-#    define GGML_DEPRECATED(func, hint) func
-#endif
-
-#ifndef __GNUC__
-#    define GGML_ATTRIBUTE_FORMAT(...)
-#elif defined(__MINGW32__)
-#    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
-#else
-#    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
-#endif
-
-#include <stdbool.h>
-#include <stddef.h>
-#include <stdint.h>
-#include <stdio.h>
-
-#define GGML_FILE_MAGIC   0x67676d6c // "ggml"
-#define GGML_FILE_VERSION 1
-
-#define GGML_QNT_VERSION        2    // bump this on quantization format changes
-#define GGML_QNT_VERSION_FACTOR 1000 // do not change this
-
-#define GGML_MAX_DIMS           4
-#define GGML_MAX_PARAMS         2048
-#define GGML_MAX_CONTEXTS       64
-#define GGML_MAX_SRC            10
-#ifndef GGML_MAX_NAME
-#define GGML_MAX_NAME           64
-#endif
-#define GGML_MAX_OP_PARAMS      64
-#define GGML_DEFAULT_N_THREADS  4
-#define GGML_DEFAULT_GRAPH_SIZE 2048
-#if UINTPTR_MAX == 0xFFFFFFFF
-    #define GGML_MEM_ALIGN 4
-#else
-    #define GGML_MEM_ALIGN 16
-#endif
-
-#define GGML_EXIT_SUCCESS 0
-#define GGML_EXIT_ABORTED 1
-
-#define GGUF_MAGIC "GGUF"
-
-#define GGUF_VERSION 3
-
-#define GGUF_DEFAULT_ALIGNMENT 32
-
-#define GGML_UNUSED(x) (void)(x)
-
-#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
-
-#define GGML_ASSERT(x) \
-    do { \
-        if (!(x)) { \
-            fflush(stdout); \
-            fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
-            ggml_print_backtrace(); \
-            abort(); \
-        } \
-    } while (0)
-
-#ifndef NDEBUG
-#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached")
-#elif defined(__GNUC__)
-#define GGML_UNREACHABLE() __builtin_unreachable()
-#elif defined(_MSC_VER)
-#define GGML_UNREACHABLE() __assume(0)
-#else
-#define GGML_UNREACHABLE() ((void) 0)
-#endif
-
-// used to copy the number of elements and stride in bytes of tensors into local variables.
-// main purpose is to reduce code duplication and improve readability.
-//
-// example:
-//
-//    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
-//    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb);
-//
-#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
-    const type prefix##0 = (pointer)->array[0]; \
-    GGML_UNUSED(prefix##0);
-#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
-    GGML_TENSOR_LOCALS_1    (type, prefix, pointer, array) \
-    const type prefix##1 = (pointer)->array[1]; \
-    GGML_UNUSED(prefix##1);
-#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
-    GGML_TENSOR_LOCALS_2    (type, prefix, pointer, array) \
-    const type prefix##2 = (pointer)->array[2]; \
-    GGML_UNUSED(prefix##2);
-#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
-    GGML_TENSOR_LOCALS_3  (type, prefix, pointer, array) \
-    const type prefix##3 = (pointer)->array[3]; \
-    GGML_UNUSED(prefix##3);
-
-#define GGML_TENSOR_UNARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
-
-#define GGML_TENSOR_BINARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
-
-#ifdef  __cplusplus
-extern "C" {
-#endif
-
-    enum ggml_status {
-        GGML_STATUS_ALLOC_FAILED = -2,
-        GGML_STATUS_FAILED = -1,
-        GGML_STATUS_SUCCESS = 0,
-        GGML_STATUS_ABORTED = 1,
-    };
-
-    // get ggml_status name string
-    GGML_API GGML_CALL const char * ggml_status_to_string(enum ggml_status status);
-
-    // ieee 754-2008 half-precision float16
-    // todo: make this not an integral type
-    typedef uint16_t ggml_fp16_t;
-    GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t);
-    GGML_API ggml_fp16_t ggml_fp32_to_fp16(float);
-    GGML_API void        ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t);
-    GGML_API void        ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t);
-
-    // google brain half-precision bfloat16
-    typedef struct { uint16_t bits; } ggml_bf16_t;
-    GGML_API ggml_bf16_t ggml_fp32_to_bf16(float);
-    GGML_API float       ggml_bf16_to_fp32(ggml_bf16_t);  // consider just doing << 16
-    GGML_API void        ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t);
-    GGML_API void        ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t);
-
-    struct ggml_object;
-    struct ggml_context;
-
-    // NOTE: always add types at the end of the enum to keep backward compatibility
-    enum ggml_type {
-        GGML_TYPE_F32     = 0,
-        GGML_TYPE_F16     = 1,
-        GGML_TYPE_Q4_0    = 2,
-        GGML_TYPE_Q4_1    = 3,
-        // GGML_TYPE_Q4_2 = 4, support has been removed
-        // GGML_TYPE_Q4_3 = 5, support has been removed
-        GGML_TYPE_Q5_0    = 6,
-        GGML_TYPE_Q5_1    = 7,
-        GGML_TYPE_Q8_0    = 8,
-        GGML_TYPE_Q8_1    = 9,
-        GGML_TYPE_Q2_K    = 10,
-        GGML_TYPE_Q3_K    = 11,
-        GGML_TYPE_Q4_K    = 12,
-        GGML_TYPE_Q5_K    = 13,
-        GGML_TYPE_Q6_K    = 14,
-        GGML_TYPE_Q8_K    = 15,
-        GGML_TYPE_IQ2_XXS = 16,
-        GGML_TYPE_IQ2_XS  = 17,
-        GGML_TYPE_IQ3_XXS = 18,
-        GGML_TYPE_IQ1_S   = 19,
-        GGML_TYPE_IQ4_NL  = 20,
-        GGML_TYPE_IQ3_S   = 21,
-        GGML_TYPE_IQ2_S   = 22,
-        GGML_TYPE_IQ4_XS  = 23,
-        GGML_TYPE_I8      = 24,
-        GGML_TYPE_I16     = 25,
-        GGML_TYPE_I32     = 26,
-        GGML_TYPE_I64     = 27,
-        GGML_TYPE_F64     = 28,
-        GGML_TYPE_IQ1_M   = 29,
-        GGML_TYPE_BF16    = 30,
-        GGML_TYPE_COUNT,
-    };
-
-    // precision
-    enum ggml_prec {
-        GGML_PREC_DEFAULT,
-        GGML_PREC_F32,
-    };
-
-    enum ggml_backend_type {
-        GGML_BACKEND_TYPE_CPU = 0,
-        GGML_BACKEND_TYPE_GPU = 10,
-        GGML_BACKEND_TYPE_GPU_SPLIT = 20,
-    };
-
-    // model file types
-    enum ggml_ftype {
-        GGML_FTYPE_UNKNOWN        = -1,
-        GGML_FTYPE_ALL_F32        = 0,
-        GGML_FTYPE_MOSTLY_F16     = 1,  // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q4_0    = 2,  // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q4_1    = 3,  // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
-        GGML_FTYPE_MOSTLY_Q8_0    = 7,  // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q5_0    = 8,  // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q5_1    = 9,  // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q2_K    = 10, // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q3_K    = 11, // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q4_K    = 12, // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q5_K    = 13, // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q6_K    = 14, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ2_XS  = 16, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ1_S   = 18, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ4_NL  = 19, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ3_S   = 20, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ2_S   = 21, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ4_XS  = 22, // except 1d tensors
-        GGML_FTYPE_MOSTLY_IQ1_M   = 23, // except 1d tensors
-        GGML_FTYPE_MOSTLY_BF16    = 24, // except 1d tensors
-    };
-
-    // available tensor operations:
-    enum ggml_op {
-        GGML_OP_NONE = 0,
-
-        GGML_OP_DUP,
-        GGML_OP_ADD,
-        GGML_OP_ADD1,
-        GGML_OP_ACC,
-        GGML_OP_SUB,
-        GGML_OP_MUL,
-        GGML_OP_DIV,
-        GGML_OP_SQR,
-        GGML_OP_SQRT,
-        GGML_OP_LOG,
-        GGML_OP_SUM,
-        GGML_OP_SUM_ROWS,
-        GGML_OP_MEAN,
-        GGML_OP_ARGMAX,
-        GGML_OP_REPEAT,
-        GGML_OP_REPEAT_BACK,
-        GGML_OP_CONCAT,
-        GGML_OP_SILU_BACK,
-        GGML_OP_NORM, // normalize
-        GGML_OP_RMS_NORM,
-        GGML_OP_RMS_NORM_BACK,
-        GGML_OP_GROUP_NORM,
-
-        GGML_OP_MUL_MAT,
-        GGML_OP_MUL_MAT_ID,
-        GGML_OP_OUT_PROD,
-
-        GGML_OP_SCALE,
-        GGML_OP_SET,
-        GGML_OP_CPY,
-        GGML_OP_CONT,
-        GGML_OP_RESHAPE,
-        GGML_OP_VIEW,
-        GGML_OP_PERMUTE,
-        GGML_OP_TRANSPOSE,
-        GGML_OP_GET_ROWS,
-        GGML_OP_GET_ROWS_BACK,
-        GGML_OP_DIAG,
-        GGML_OP_DIAG_MASK_INF,
-        GGML_OP_DIAG_MASK_ZERO,
-        GGML_OP_SOFT_MAX,
-        GGML_OP_SOFT_MAX_BACK,
-        GGML_OP_ROPE,
-        GGML_OP_ROPE_BACK,
-        GGML_OP_CLAMP,
-        GGML_OP_CONV_TRANSPOSE_1D,
-        GGML_OP_IM2COL,
-        GGML_OP_CONV_TRANSPOSE_2D,
-        GGML_OP_POOL_1D,
-        GGML_OP_POOL_2D,
-        GGML_OP_UPSCALE, // nearest interpolate
-        GGML_OP_PAD,
-        GGML_OP_ARANGE,
-        GGML_OP_TIMESTEP_EMBEDDING,
-        GGML_OP_ARGSORT,
-        GGML_OP_LEAKY_RELU,
-
-        GGML_OP_FLASH_ATTN_EXT,
-        GGML_OP_FLASH_ATTN_BACK,
-        GGML_OP_SSM_CONV,
-        GGML_OP_SSM_SCAN,
-        GGML_OP_WIN_PART,
-        GGML_OP_WIN_UNPART,
-        GGML_OP_GET_REL_POS,
-        GGML_OP_ADD_REL_POS,
-
-        GGML_OP_UNARY,
-
-        GGML_OP_MAP_UNARY,
-        GGML_OP_MAP_BINARY,
-
-        GGML_OP_MAP_CUSTOM1_F32,
-        GGML_OP_MAP_CUSTOM2_F32,
-        GGML_OP_MAP_CUSTOM3_F32,
-
-        GGML_OP_MAP_CUSTOM1,
-        GGML_OP_MAP_CUSTOM2,
-        GGML_OP_MAP_CUSTOM3,
-
-        GGML_OP_CROSS_ENTROPY_LOSS,
-        GGML_OP_CROSS_ENTROPY_LOSS_BACK,
-
-        GGML_OP_COUNT,
-    };
-
-    enum ggml_unary_op {
-        GGML_UNARY_OP_ABS,
-        GGML_UNARY_OP_SGN,
-        GGML_UNARY_OP_NEG,
-        GGML_UNARY_OP_STEP,
-        GGML_UNARY_OP_TANH,
-        GGML_UNARY_OP_ELU,
-        GGML_UNARY_OP_RELU,
-        GGML_UNARY_OP_SIGMOID,
-        GGML_UNARY_OP_GELU,
-        GGML_UNARY_OP_GELU_QUICK,
-        GGML_UNARY_OP_SILU,
-        GGML_UNARY_OP_HARDSWISH,
-        GGML_UNARY_OP_HARDSIGMOID,
-
-        GGML_UNARY_OP_COUNT,
-    };
-
-    enum ggml_object_type {
-        GGML_OBJECT_TYPE_TENSOR,
-        GGML_OBJECT_TYPE_GRAPH,
-        GGML_OBJECT_TYPE_WORK_BUFFER
-    };
-
-    enum ggml_log_level {
-        GGML_LOG_LEVEL_ERROR = 2,
-        GGML_LOG_LEVEL_WARN  = 3,
-        GGML_LOG_LEVEL_INFO  = 4,
-        GGML_LOG_LEVEL_DEBUG = 5
-    };
-
-    enum ggml_tensor_flag {
-        GGML_TENSOR_FLAG_INPUT  = 1,
-        GGML_TENSOR_FLAG_OUTPUT = 2,
-        GGML_TENSOR_FLAG_PARAM  = 4,
-    };
-
-    // ggml object
-    struct ggml_object {
-        size_t offs;
-        size_t size;
-
-        struct ggml_object * next;
-
-        enum ggml_object_type type;
-
-        char padding[4];
-    };
-
-    static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
-
-    // n-dimensional tensor
-    struct ggml_tensor {
-        enum ggml_type         type;
-
-        GGML_DEPRECATED(enum ggml_backend_type backend, "use the buffer type to find the storage location of the tensor");
-
-        struct ggml_backend_buffer * buffer;
-
-        int64_t ne[GGML_MAX_DIMS]; // number of elements
-        size_t  nb[GGML_MAX_DIMS]; // stride in bytes:
-                                   // nb[0] = ggml_type_size(type)
-                                   // nb[1] = nb[0]   * (ne[0] / ggml_blck_size(type)) + padding
-                                   // nb[i] = nb[i-1] * ne[i-1]
-
-        // compute data
-        enum ggml_op op;
-
-        // op params - allocated as int32_t for alignment
-        int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
-
-        int32_t flags;
-
-        struct ggml_tensor * grad;
-        struct ggml_tensor * src[GGML_MAX_SRC];
-
-        // performance
-        int     perf_runs;
-        int64_t perf_cycles;
-        int64_t perf_time_us;
-
-        struct ggml_tensor * view_src;
-        size_t               view_offs;
-
-        void * data;
-
-        char name[GGML_MAX_NAME];
-
-        void * extra; // extra things e.g. for ggml-cuda.cu
-
-        char padding[8];
-    };
-
-    static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
-
-    // Abort callback
-    // If not NULL, called before ggml computation
-    // If it returns true, the computation is aborted
-    typedef bool (*ggml_abort_callback)(void * data);
-
-    // the compute plan that needs to be prepared for ggml_graph_compute()
-    // since https://github.com/ggerganov/ggml/issues/287
-    struct ggml_cplan {
-        size_t    work_size; // size of work buffer, calculated by `ggml_graph_plan()`
-        uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
-
-        int n_threads;
-
-        // abort ggml_graph_compute when true
-        ggml_abort_callback abort_callback;
-        void *              abort_callback_data;
-    };
-
-    enum ggml_cgraph_eval_order {
-        GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
-        GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
-        GGML_CGRAPH_EVAL_ORDER_COUNT
-    };
-
-    struct ggml_hash_set {
-        size_t size;
-        struct ggml_tensor ** keys;
-    };
-
-    // computation graph
-    struct ggml_cgraph {
-        int size;
-        int n_nodes;
-        int n_leafs;
-
-        struct ggml_tensor ** nodes;
-        struct ggml_tensor ** grads;
-        struct ggml_tensor ** leafs;
-
-        struct ggml_hash_set visited_hash_table;
-
-        enum ggml_cgraph_eval_order order;
-
-        // performance
-        int     perf_runs;
-        int64_t perf_cycles;
-        int64_t perf_time_us;
-    };
-
-    // scratch buffer
-    struct ggml_scratch {
-        size_t offs;
-        size_t size;
-        void * data;
-    };
-
-    struct ggml_init_params {
-        // memory pool
-        size_t mem_size;   // bytes
-        void * mem_buffer; // if NULL, memory will be allocated internally
-        bool   no_alloc;   // don't allocate memory for the tensor data
-    };
-
-
-    // compute types
-
-    // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
-    // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
-    enum ggml_task_type {
-        GGML_TASK_TYPE_INIT = 0,
-        GGML_TASK_TYPE_COMPUTE,
-        GGML_TASK_TYPE_FINALIZE,
-    };
-
-    struct ggml_compute_params {
-        enum ggml_task_type type;
-
-        // ith = thread index, nth = number of threads
-        int ith, nth;
-
-        // work buffer for all threads
-        size_t wsize;
-        void * wdata;
-    };
-
-    // numa strategies
-    enum ggml_numa_strategy {
-        GGML_NUMA_STRATEGY_DISABLED   = 0,
-        GGML_NUMA_STRATEGY_DISTRIBUTE = 1,
-        GGML_NUMA_STRATEGY_ISOLATE    = 2,
-        GGML_NUMA_STRATEGY_NUMACTL    = 3,
-        GGML_NUMA_STRATEGY_MIRROR     = 4,
-        GGML_NUMA_STRATEGY_COUNT
-    };
-
-    //
-    // GUID
-    //
-
-    // GUID types
-    typedef uint8_t ggml_guid[16];
-    typedef ggml_guid * ggml_guid_t;
-
-    GGML_API bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b);
-
-    // misc
-
-    GGML_API void    ggml_time_init(void); // call this once at the beginning of the program
-    GGML_API int64_t ggml_time_ms(void);
-    GGML_API int64_t ggml_time_us(void);
-    GGML_API int64_t ggml_cycles(void);
-    GGML_API int64_t ggml_cycles_per_ms(void);
-
-    GGML_API void    ggml_print_backtrace(void);
-
-    // accepts a UTF-8 path, even on Windows
-    GGML_API FILE *  ggml_fopen(const char * fname, const char * mode);
-
-    GGML_API void    ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems
-    GGML_API bool    ggml_is_numa(void); // true if init detected that system has >1 NUMA node
-
-    GGML_API void    ggml_print_object (const struct ggml_object * obj);
-    GGML_API void    ggml_print_objects(const struct ggml_context * ctx);
-
-    GGML_API GGML_CALL int64_t ggml_nelements   (const struct ggml_tensor * tensor);
-    GGML_API GGML_CALL int64_t ggml_nrows       (const struct ggml_tensor * tensor);
-    GGML_API GGML_CALL size_t  ggml_nbytes      (const struct ggml_tensor * tensor);
-    GGML_API           size_t  ggml_nbytes_pad  (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
-
-    GGML_API GGML_CALL int    ggml_blck_size(enum ggml_type type);
-    GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type);             // size in bytes for all elements in a block
-    GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
-
-    GGML_DEPRECATED(
-    GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
-    "use ggml_row_size() instead");
-
-    GGML_API GGML_CALL const char * ggml_type_name(enum ggml_type type);
-    GGML_API GGML_CALL const char * ggml_op_name  (enum ggml_op   op);
-    GGML_API           const char * ggml_op_symbol(enum ggml_op   op);
-
-    GGML_API           const char * ggml_unary_op_name(enum ggml_unary_op op);
-    GGML_API GGML_CALL const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
-
-    GGML_API GGML_CALL size_t  ggml_element_size(const struct ggml_tensor * tensor);
-
-    GGML_API GGML_CALL bool    ggml_is_quantized(enum ggml_type type);
-
-    // TODO: temporary until model loading of ggml examples is refactored
-    GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
-
-    GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
-    GGML_API GGML_CALL bool ggml_is_permuted  (const struct ggml_tensor * tensor);
-    GGML_API GGML_CALL bool ggml_is_empty     (const struct ggml_tensor * tensor);
-    GGML_API           bool ggml_is_scalar    (const struct ggml_tensor * tensor);
-    GGML_API           bool ggml_is_vector    (const struct ggml_tensor * tensor);
-    GGML_API           bool ggml_is_matrix    (const struct ggml_tensor * tensor);
-    GGML_API           bool ggml_is_3d        (const struct ggml_tensor * tensor);
-    GGML_API           int  ggml_n_dims       (const struct ggml_tensor * tensor); // returns 1 for scalars
-
-    GGML_API GGML_CALL bool ggml_is_contiguous  (const struct ggml_tensor * tensor);
-    GGML_API GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
-    GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
-    GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
-
-    GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
-    GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
-
-    // use this to compute the memory overhead of a tensor
-    GGML_API size_t ggml_tensor_overhead(void);
-
-    GGML_API bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes);
-
-    // main
-
-    GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
-    GGML_API void                  ggml_free(struct ggml_context * ctx);
-
-    GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);
-
-    GGML_API size_t  ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
-    GGML_API bool    ggml_get_no_alloc(struct ggml_context * ctx);
-    GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
-
-    GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);
-    GGML_API size_t  ggml_get_mem_size       (const struct ggml_context * ctx);
-    GGML_API size_t  ggml_get_max_tensor_size(const struct ggml_context * ctx);
-
-    GGML_API struct ggml_tensor * ggml_new_tensor(
-            struct ggml_context * ctx,
-            enum   ggml_type type,
-            int    n_dims,
-            const int64_t *ne);
-
-    GGML_API struct ggml_tensor * ggml_new_tensor_1d(
-            struct ggml_context * ctx,
-            enum   ggml_type type,
-            int64_t ne0);
-
-    GGML_API struct ggml_tensor * ggml_new_tensor_2d(
-            struct ggml_context * ctx,
-            enum   ggml_type type,
-            int64_t ne0,
-            int64_t ne1);
-
-    GGML_API struct ggml_tensor * ggml_new_tensor_3d(
-            struct ggml_context * ctx,
-            enum   ggml_type type,
-            int64_t ne0,
-            int64_t ne1,
-            int64_t ne2);
-
-    GGML_API struct ggml_tensor * ggml_new_tensor_4d(
-            struct ggml_context * ctx,
-            enum   ggml_type type,
-            int64_t ne0,
-            int64_t ne1,
-            int64_t ne2,
-            int64_t ne3);
-
-    GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
-    GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
-
-    GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
-    GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
-
-    // Context tensor enumeration and lookup
-    GGML_API struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx);
-    GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor);
-    GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
-
-    GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
-    GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
-    GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
-
-    // Converts a flat index into coordinates
-    GGML_API void    ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
-
-    GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
-    GGML_API void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
-
-    GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
-    GGML_API void    ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
-
-    GGML_API float   ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
-    GGML_API void    ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
-
-    GGML_API float   ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
-    GGML_API void    ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
-
-    GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
-    GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
-
-    GGML_API GGML_CALL enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
-
-    GGML_API const char *         ggml_get_name   (const struct ggml_tensor * tensor);
-    GGML_API struct ggml_tensor * ggml_set_name   (      struct ggml_tensor * tensor, const char * name);
-    GGML_ATTRIBUTE_FORMAT(2, 3)
-    GGML_API struct ggml_tensor * ggml_format_name(      struct ggml_tensor * tensor, const char * fmt, ...);
-
-    //
-    // operations on tensors with backpropagation
-    //
-
-    GGML_API struct ggml_tensor * ggml_dup(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_dup_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_add(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    GGML_API struct ggml_tensor * ggml_add_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    GGML_API struct ggml_tensor * ggml_add_cast(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            enum   ggml_type      type);
-
-    GGML_API struct ggml_tensor * ggml_add1(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    GGML_API struct ggml_tensor * ggml_add1_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    // dst = a
-    // view(dst, nb1, nb2, nb3, offset) += b
-    // return dst
-    GGML_API struct ggml_tensor * ggml_acc(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            size_t                nb1,
-            size_t                nb2,
-            size_t                nb3,
-            size_t                offset);
-
-    GGML_API struct ggml_tensor * ggml_acc_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            size_t                nb1,
-            size_t                nb2,
-            size_t                nb3,
-            size_t                offset);
-
-    GGML_API struct ggml_tensor * ggml_sub(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    GGML_API struct ggml_tensor * ggml_sub_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    GGML_API struct ggml_tensor * ggml_mul(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    GGML_API struct ggml_tensor * ggml_mul_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    GGML_API struct ggml_tensor * ggml_div(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    GGML_API struct ggml_tensor * ggml_div_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    GGML_API struct ggml_tensor * ggml_sqr(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_sqr_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_sqrt(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_sqrt_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_log(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_log_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    // return scalar
-    GGML_API struct ggml_tensor * ggml_sum(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d]
-    GGML_API struct ggml_tensor * ggml_sum_rows(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    // mean along rows
-    GGML_API struct ggml_tensor * ggml_mean(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    // argmax along rows
-    GGML_API struct ggml_tensor * ggml_argmax(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    // if a is the same shape as b, and a is not parameter, return a
-    // otherwise, return a new tensor: repeat(a) to fit in b
-    GGML_API struct ggml_tensor * ggml_repeat(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    // sums repetitions in a into shape of b
-    GGML_API struct ggml_tensor * ggml_repeat_back(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    // concat a and b along dim
-    // used in stable-diffusion
-    GGML_API struct ggml_tensor * ggml_concat(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                   dim);
-
-    GGML_API struct ggml_tensor * ggml_abs(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_abs_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_sgn(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_sgn_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_neg(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_neg_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_step(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_step_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_tanh(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_tanh_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_elu(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_elu_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_relu(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_leaky_relu(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a, float negative_slope, bool inplace);
-
-    GGML_API struct ggml_tensor * ggml_relu_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_sigmoid(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_sigmoid_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_gelu(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_gelu_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_gelu_quick(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_gelu_quick_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_silu(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    GGML_API struct ggml_tensor * ggml_silu_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    // a - x
-    // b - dy
-    GGML_API struct ggml_tensor * ggml_silu_back(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    // hardswish(x) = x * relu6(x + 3) / 6
-    GGML_API struct ggml_tensor * ggml_hardswish(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    // hardsigmoid(x) = relu6(x + 3) / 6
-    GGML_API struct ggml_tensor * ggml_hardsigmoid(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    // normalize along rows
-    GGML_API struct ggml_tensor * ggml_norm(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            float                 eps);
-
-    GGML_API struct ggml_tensor * ggml_norm_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            float                 eps);
-
-    GGML_API struct ggml_tensor * ggml_rms_norm(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            float                 eps);
-
-    GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            float                 eps);
-
-    // group normalize along ne0*ne1*n_groups
-    // used in stable-diffusion
-    // TODO: eps is hardcoded to 1e-6 for now
-    GGML_API struct ggml_tensor * ggml_group_norm(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   n_groups);
-
-    GGML_API struct ggml_tensor * ggml_group_norm_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   n_groups);
-
-    // a - x
-    // b - dy
-    GGML_API struct ggml_tensor * ggml_rms_norm_back(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            float                 eps);
-
-    // A: k columns, n rows => [ne03, ne02, n, k]
-    // B: k columns, m rows  (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k]
-    // result is n columns, m rows => [ne03 * x, ne02 * y, m, n]
-    GGML_API struct ggml_tensor * ggml_mul_mat(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    // change the precision of a matrix multiplication
-    // set to GGML_PREC_F32 for higher precision (useful for phi-2)
-    GGML_API void ggml_mul_mat_set_prec(
-            struct ggml_tensor * a,
-            enum ggml_prec       prec);
-
-    // indirect matrix multiplication
-    GGML_API struct ggml_tensor * ggml_mul_mat_id(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * as,
-            struct ggml_tensor  * b,
-            struct ggml_tensor  * ids);
-
-    // A: m columns, n rows,
-    // B: p columns, n rows,
-    // result is m columns, p rows
-    GGML_API struct ggml_tensor * ggml_out_prod(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    //
-    // operations on tensors without backpropagation
-    //
-
-    GGML_API struct ggml_tensor * ggml_scale(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            float                 s);
-
-    // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_scale_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            float                 s);
-
-    // b -> view(a,offset,nb1,nb2,3), return modified a
-    GGML_API struct ggml_tensor * ggml_set(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            size_t                nb1,
-            size_t                nb2,
-            size_t                nb3,
-            size_t                offset);
-
-    // b -> view(a,offset,nb1,nb2,3), return view(a)
-    GGML_API struct ggml_tensor * ggml_set_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            size_t                nb1,
-            size_t                nb2,
-            size_t                nb3,
-            size_t                offset);
-
-    GGML_API struct ggml_tensor * ggml_set_1d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            size_t                offset);
-
-    GGML_API struct ggml_tensor * ggml_set_1d_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            size_t                offset);
-
-    // b -> view(a,offset,nb1,nb2,3), return modified a
-    GGML_API struct ggml_tensor * ggml_set_2d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            size_t                nb1,
-            size_t                offset);
-
-    // b -> view(a,offset,nb1,nb2,3), return view(a)
-    GGML_API struct ggml_tensor * ggml_set_2d_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            size_t                nb1,
-            size_t                offset);
-
-    // a -> b, return view(b)
-    GGML_API struct ggml_tensor * ggml_cpy(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    GGML_API struct ggml_tensor * ggml_cast(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            enum   ggml_type      type);
-
-    // make contiguous
-    GGML_API struct ggml_tensor * ggml_cont(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    // make contiguous, with new shape
-    GGML_API struct ggml_tensor * ggml_cont_1d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int64_t               ne0);
-
-    GGML_API struct ggml_tensor * ggml_cont_2d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int64_t               ne0,
-            int64_t               ne1);
-
-    GGML_API struct ggml_tensor * ggml_cont_3d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int64_t               ne0,
-            int64_t               ne1,
-            int64_t               ne2);
-
-    GGML_API struct ggml_tensor * ggml_cont_4d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int64_t               ne0,
-            int64_t               ne1,
-            int64_t               ne2,
-            int64_t               ne3);
-
-    // return view(a), b specifies the new shape
-    // TODO: when we start computing gradient, make a copy instead of view
-    GGML_API struct ggml_tensor * ggml_reshape(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    // return view(a)
-    // TODO: when we start computing gradient, make a copy instead of view
-    GGML_API struct ggml_tensor * ggml_reshape_1d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int64_t               ne0);
-
-    GGML_API struct ggml_tensor * ggml_reshape_2d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int64_t               ne0,
-            int64_t               ne1);
-
-    // return view(a)
-    // TODO: when we start computing gradient, make a copy instead of view
-    GGML_API struct ggml_tensor * ggml_reshape_3d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int64_t               ne0,
-            int64_t               ne1,
-            int64_t               ne2);
-
-    GGML_API struct ggml_tensor * ggml_reshape_4d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int64_t               ne0,
-            int64_t               ne1,
-            int64_t               ne2,
-            int64_t               ne3);
-
-    // offset in bytes
-    GGML_API struct ggml_tensor * ggml_view_1d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int64_t               ne0,
-            size_t                offset);
-
-    GGML_API struct ggml_tensor * ggml_view_2d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int64_t               ne0,
-            int64_t               ne1,
-            size_t                nb1, // row stride in bytes
-            size_t                offset);
-
-    GGML_API struct ggml_tensor * ggml_view_3d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int64_t               ne0,
-            int64_t               ne1,
-            int64_t               ne2,
-            size_t                nb1, // row   stride in bytes
-            size_t                nb2, // slice stride in bytes
-            size_t                offset);
-
-    GGML_API struct ggml_tensor * ggml_view_4d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int64_t               ne0,
-            int64_t               ne1,
-            int64_t               ne2,
-            int64_t               ne3,
-            size_t                nb1, // row   stride in bytes
-            size_t                nb2, // slice stride in bytes
-            size_t                nb3,
-            size_t                offset);
-
-    GGML_API struct ggml_tensor * ggml_permute(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   axis0,
-            int                   axis1,
-            int                   axis2,
-            int                   axis3);
-
-    // alias for ggml_permute(ctx, a, 1, 0, 2, 3)
-    GGML_API struct ggml_tensor * ggml_transpose(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    // supports 3D: a->ne[2] == b->ne[1]
-    GGML_API struct ggml_tensor * ggml_get_rows(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    GGML_API struct ggml_tensor * ggml_get_rows_back(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            struct ggml_tensor  * c);
-
-    GGML_API struct ggml_tensor * ggml_diag(
-        struct ggml_context     * ctx,
-        struct ggml_tensor      * a);
-
-    // set elements above the diagonal to -INF
-    GGML_API struct ggml_tensor * ggml_diag_mask_inf(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   n_past);
-
-    // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   n_past);
-
-    // set elements above the diagonal to 0
-    GGML_API struct ggml_tensor * ggml_diag_mask_zero(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   n_past);
-
-    // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   n_past);
-
-    GGML_API struct ggml_tensor * ggml_soft_max(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_soft_max_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a);
-
-    // fused soft_max(a*scale + mask*(ALiBi slope))
-    // mask is optional
-    // max_bias = 0.0f for no ALiBi
-    GGML_API struct ggml_tensor * ggml_soft_max_ext(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * mask,
-            float                 scale,
-            float                 max_bias);
-
-    GGML_API struct ggml_tensor * ggml_soft_max_back(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    // rotary position embedding
-    // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
-    // if mode & 2 == 1, GPT-NeoX style
-    //
-    // b is an int32 vector with size a->ne[2], it contains the positions
-    // c is freq factors (e.g. phi3-128k), (optional)
-    GGML_API struct ggml_tensor * ggml_rope(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                   n_dims,
-            int                   mode);
-
-    // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_rope_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                   n_dims,
-            int                   mode);
-
-    // custom RoPE
-    GGML_API struct ggml_tensor * ggml_rope_ext(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            struct ggml_tensor  * c,
-            int                   n_dims,
-            int                   mode,
-            int                   n_ctx_orig,
-            float                 freq_base,
-            float                 freq_scale,
-            float                 ext_factor,
-            float                 attn_factor,
-            float                 beta_fast,
-            float                 beta_slow);
-
-    // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            struct ggml_tensor  * c,
-            int                   n_dims,
-            int                   mode,
-            int                   n_ctx_orig,
-            float                 freq_base,
-            float                 freq_scale,
-            float                 ext_factor,
-            float                 attn_factor,
-            float                 beta_fast,
-            float                 beta_slow);
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                   n_dims,
-            int                   mode,
-            int                   n_ctx_orig,
-            float                 freq_base,
-            float                 freq_scale,
-            float                 ext_factor,
-            float                 attn_factor,
-            float                 beta_fast,
-            float                 beta_slow),
-        "use ggml_rope_ext instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                   n_dims,
-            int                   mode,
-            int                   n_ctx_orig,
-            float                 freq_base,
-            float                 freq_scale,
-            float                 ext_factor,
-            float                 attn_factor,
-            float                 beta_fast,
-            float                 beta_slow),
-        "use ggml_rope_ext_inplace instead");
-
-    // compute correction dims for YaRN RoPE scaling
-    GGML_CALL void ggml_rope_yarn_corr_dims(
-        int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
-
-    // rotary position embedding backward, i.e compute dx from dy
-    // a - dy
-    GGML_API struct ggml_tensor * ggml_rope_back(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            struct ggml_tensor  * c,
-            int                   n_dims,
-            int                   mode,
-            int                   n_ctx_orig,
-            float                 freq_base,
-            float                 freq_scale,
-            float                 ext_factor,
-            float                 attn_factor,
-            float                 beta_fast,
-            float                 beta_slow);
-
-    // clamp
-    // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_clamp(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            float                 min,
-            float                 max);
-
-    GGML_API struct ggml_tensor * ggml_im2col(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                  s0,
-            int                  s1,
-            int                  p0,
-            int                  p1,
-            int                  d0,
-            int                  d1,
-            bool                 is_2D,
-            enum ggml_type       dst_type);
-
-    GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                  s0,
-            int                  s1,
-            int                  p0,
-            int                  p1,
-            int                  d0,
-            int                  d1);
-
-    GGML_API struct ggml_tensor * ggml_conv_1d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                   s0,  // stride
-            int                   p0,  // padding
-            int                   d0); // dilation
-
-    // conv_1d with padding = half
-    // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
-    GGML_API struct ggml_tensor* ggml_conv_1d_ph(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                   s,
-            int                   d);
-
-    GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                   s0,
-            int                   p0,
-            int                   d0);
-
-    GGML_API struct ggml_tensor * ggml_conv_2d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                   s0,
-            int                   s1,
-            int                   p0,
-            int                   p1,
-            int                   d0,
-            int                   d1);
-
-
-    // kernel size is a->ne[0] x a->ne[1]
-    // stride is equal to kernel size
-    // padding is zero
-    // example:
-    // a:     16   16    3  768
-    // b:   1024 1024    3    1
-    // res:   64   64  768    1
-    // used in sam
-    GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    // kernel size is a->ne[0] x a->ne[1]
-    // stride is 1
-    // padding is half
-    // example:
-    // a:      3    3    256  256
-    // b:     64   64    256    1
-    // res:   64   64    256    1
-    // used in sam
-    GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
-
-    GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                   stride);
-
-    enum ggml_op_pool {
-        GGML_OP_POOL_MAX,
-        GGML_OP_POOL_AVG,
-        GGML_OP_POOL_COUNT,
-    };
-
-    GGML_API struct ggml_tensor * ggml_pool_1d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            enum ggml_op_pool     op,
-            int                   k0, // kernel size
-            int                   s0, // stride
-            int                   p0); // padding
-
-    // the result will have 2*p0 padding for the first dimension
-    // and 2*p1 padding for the second dimension
-    GGML_API struct ggml_tensor * ggml_pool_2d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            enum ggml_op_pool     op,
-            int                   k0,
-            int                   k1,
-            int                   s0,
-            int                   s1,
-            float                 p0,
-            float                 p1);
-
-    // nearest interpolate
-    // multiplies ne0 and ne1 by scale factor
-    // used in stable-diffusion
-    GGML_API struct ggml_tensor * ggml_upscale(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   scale_factor);
-
-    // nearest interpolate
-    // nearest interpolate to specified dimensions
-    // used in tortoise.cpp
-    GGML_API struct ggml_tensor * ggml_upscale_ext(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   ne0,
-            int                   ne1,
-            int                   ne2,
-            int                   ne3);
-
-    // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
-    GGML_API struct ggml_tensor * ggml_pad(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                  p0,
-            int                  p1,
-            int                  p2,
-            int                  p3);
-
-    // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
-    // timesteps: [N,]
-    // return: [N, dim]
-    GGML_API struct ggml_tensor * ggml_timestep_embedding(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * timesteps,
-            int                   dim,
-            int                   max_period);
-
-    // sort rows
-    enum ggml_sort_order {
-        GGML_SORT_ORDER_ASC,
-        GGML_SORT_ORDER_DESC,
-    };
-
-    GGML_API struct ggml_tensor * ggml_argsort(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            enum ggml_sort_order  order);
-
-    GGML_API struct ggml_tensor * ggml_arange(
-            struct ggml_context * ctx,
-            float                 start,
-            float                 stop,
-            float                 step);
-
-    // top k elements per row
-    GGML_API struct ggml_tensor * ggml_top_k(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   k);
-
-#define GGML_KQ_MASK_PAD 32
-
-    // q:    [n_embd, n_batch,     n_head,    1]
-    // k:    [n_embd, n_kv,        n_head_kv, 1]
-    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
-    // mask: [n_kv,   n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-    // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
-    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * q,
-            struct ggml_tensor  * k,
-            struct ggml_tensor  * v,
-            struct ggml_tensor  * mask,
-            float                 scale,
-            float                 max_bias);
-
-    GGML_API void ggml_flash_attn_ext_set_prec(
-            struct ggml_tensor * a,
-            enum ggml_prec       prec);
-
-    // TODO: needs to be adapted to ggml_flash_attn_ext
-    GGML_API struct ggml_tensor * ggml_flash_attn_back(
-           struct ggml_context * ctx,
-           struct ggml_tensor  * q,
-           struct ggml_tensor  * k,
-           struct ggml_tensor  * v,
-           struct ggml_tensor  * d,
-           bool                  masked);
-
-    GGML_API struct ggml_tensor * ggml_ssm_conv(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * s,
-            struct ggml_tensor  * x,
-            struct ggml_tensor  * c,
-            struct ggml_tensor  * sq);
-
-    GGML_API struct ggml_tensor * ggml_ssm_scan(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * s,
-            struct ggml_tensor  * x,
-            struct ggml_tensor  * dt,
-            struct ggml_tensor  * A,
-            struct ggml_tensor  * B,
-            struct ggml_tensor  * C,
-            struct ggml_tensor  * sq);
-
-    // partition into non-overlapping windows with padding if needed
-    // example:
-    // a:   768   64   64    1
-    // w:    14
-    // res: 768   14   14    25
-    // used in sam
-    GGML_API struct ggml_tensor * ggml_win_part(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   w);
-
-    // reverse of ggml_win_part
-    // used in sam
-    GGML_API struct ggml_tensor * ggml_win_unpart(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   w0,
-            int                   h0,
-            int                   w);
-
-    GGML_API struct ggml_tensor * ggml_unary(
-            struct ggml_context * ctx,
-             struct ggml_tensor * a,
-             enum ggml_unary_op op);
-
-    GGML_API struct ggml_tensor * ggml_unary_inplace(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        enum ggml_unary_op op);
-
-    // used in sam
-    GGML_API struct ggml_tensor * ggml_get_rel_pos(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   qh,
-            int                   kh);
-
-    // used in sam
-    GGML_API struct ggml_tensor * ggml_add_rel_pos(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * pw,
-            struct ggml_tensor  * ph);
-
-    GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * pw,
-            struct ggml_tensor  * ph);
-
-    // custom operators
-
-    typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
-    typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
-
-    typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *);
-    typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
-    typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
-            struct ggml_context        * ctx,
-            struct ggml_tensor         * a,
-                   ggml_unary_op_f32_t   fun),
-        "use ggml_map_custom1 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
-            struct ggml_context        * ctx,
-            struct ggml_tensor         * a,
-                   ggml_unary_op_f32_t   fun),
-        "use ggml_map_custom1_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
-            struct ggml_context         * ctx,
-            struct ggml_tensor          * a,
-            struct ggml_tensor          * b,
-                   ggml_binary_op_f32_t   fun),
-        "use ggml_map_custom2 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
-            struct ggml_context         * ctx,
-            struct ggml_tensor          * a,
-            struct ggml_tensor          * b,
-                   ggml_binary_op_f32_t   fun),
-        "use ggml_map_custom2_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-                   ggml_custom1_op_f32_t   fun),
-        "use ggml_map_custom1 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-                   ggml_custom1_op_f32_t   fun),
-        "use ggml_map_custom1_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-                   ggml_custom2_op_f32_t   fun),
-        "use ggml_map_custom2 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-                   ggml_custom2_op_f32_t   fun),
-        "use ggml_map_custom2_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-            struct ggml_tensor           * c,
-                   ggml_custom3_op_f32_t   fun),
-        "use ggml_map_custom3 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-            struct ggml_tensor           * c,
-                   ggml_custom3_op_f32_t   fun),
-        "use ggml_map_custom3_inplace instead");
-
-    // custom operators v2
-
-    typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
-    typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
-    typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
-
-    #define GGML_N_TASKS_MAX -1
-
-    GGML_API struct ggml_tensor * ggml_map_custom1(
-            struct ggml_context   * ctx,
-            struct ggml_tensor    * a,
-            ggml_custom1_op_t       fun,
-            int                     n_tasks,
-            void                  * userdata);
-
-    GGML_API struct ggml_tensor * ggml_map_custom1_inplace(
-            struct ggml_context   * ctx,
-            struct ggml_tensor    * a,
-            ggml_custom1_op_t       fun,
-            int                     n_tasks,
-            void                  * userdata);
-
-    GGML_API struct ggml_tensor * ggml_map_custom2(
-            struct ggml_context   * ctx,
-            struct ggml_tensor    * a,
-            struct ggml_tensor    * b,
-            ggml_custom2_op_t       fun,
-            int                     n_tasks,
-            void                  * userdata);
-
-    GGML_API struct ggml_tensor * ggml_map_custom2_inplace(
-            struct ggml_context   * ctx,
-            struct ggml_tensor    * a,
-            struct ggml_tensor    * b,
-            ggml_custom2_op_t       fun,
-            int                     n_tasks,
-            void                  * userdata);
-
-    GGML_API struct ggml_tensor * ggml_map_custom3(
-            struct ggml_context   * ctx,
-            struct ggml_tensor    * a,
-            struct ggml_tensor    * b,
-            struct ggml_tensor    * c,
-            ggml_custom3_op_t       fun,
-            int                     n_tasks,
-            void                  * userdata);
-
-    GGML_API struct ggml_tensor * ggml_map_custom3_inplace(
-            struct ggml_context   * ctx,
-            struct ggml_tensor    * a,
-            struct ggml_tensor    * b,
-            struct ggml_tensor    * c,
-            ggml_custom3_op_t       fun,
-            int                     n_tasks,
-            void                  * userdata);
-
-    // loss function
-
-    GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
-            struct ggml_context         * ctx,
-            struct ggml_tensor          * a,
-            struct ggml_tensor          * b);
-
-    GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
-            struct ggml_context         * ctx,
-            struct ggml_tensor          * a,
-            struct ggml_tensor          * b,
-            struct ggml_tensor          * c);
-
-    //
-    // automatic differentiation
-    //
-
-    GGML_API void ggml_set_param(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * tensor);
-
-
-    GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
-    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
-
-    // graph allocation in a context
-    GGML_API struct ggml_cgraph * ggml_new_graph         (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
-    GGML_API struct ggml_cgraph * ggml_new_graph_custom  (struct ggml_context * ctx, size_t size, bool grads);
-    GGML_API struct ggml_cgraph * ggml_graph_dup         (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
-    GGML_API struct ggml_cgraph   ggml_graph_view        (struct ggml_cgraph * cgraph, int i0, int i1);
-    GGML_API void                 ggml_graph_cpy         (struct ggml_cgraph * src, struct ggml_cgraph * dst);
-    GGML_API void                 ggml_graph_reset       (struct ggml_cgraph * cgraph);  // zero grads
-    GGML_API void                 ggml_graph_clear       (struct ggml_cgraph * cgraph);
-
-    GGML_API size_t ggml_graph_overhead(void);
-    GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
-
-    // ggml_graph_plan() has to be called before ggml_graph_compute()
-    // when plan.work_size > 0, caller must allocate memory for plan.work_data
-    GGML_API struct ggml_cplan ggml_graph_plan            (const struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
-    GGML_API enum ggml_status  ggml_graph_compute         (      struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
-    // same as ggml_graph_compute() but the work data is allocated as a part of the context
-    // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
-    GGML_API enum ggml_status  ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
-
-    GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
-
-    GGML_API void                 ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
-    GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
-
-    // print info and performance information for the graph
-    GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
-
-    // dump the graph into a file using the dot format
-    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
-
-    // build gradient checkpointing backward graph gb for gf using provided checkpoints
-    // gb_tmp will contain original backward graph with rewritten backward process nodes,
-    // but without the second forward pass nodes.
-    GGML_API void ggml_build_backward_gradient_checkpointing(
-            struct ggml_context   * ctx,
-            struct ggml_cgraph    * gf,
-            struct ggml_cgraph    * gb,
-            struct ggml_cgraph    * gb_tmp,
-            struct ggml_tensor  * * checkpoints,
-            int                     n_checkpoints);
-    //
-    // optimization
-    //
-
-    // optimization methods
-    enum ggml_opt_type {
-        GGML_OPT_TYPE_ADAM,
-        GGML_OPT_TYPE_LBFGS,
-    };
-
-    // linesearch methods
-    enum ggml_linesearch {
-        GGML_LINESEARCH_DEFAULT = 1,
-
-        GGML_LINESEARCH_BACKTRACKING_ARMIJO       = 0,
-        GGML_LINESEARCH_BACKTRACKING_WOLFE        = 1,
-        GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
-    };
-
-    // optimization return values
-    enum ggml_opt_result {
-        GGML_OPT_RESULT_OK = 0,
-        GGML_OPT_RESULT_DID_NOT_CONVERGE,
-        GGML_OPT_RESULT_NO_CONTEXT,
-        GGML_OPT_RESULT_INVALID_WOLFE,
-        GGML_OPT_RESULT_FAIL,
-        GGML_OPT_RESULT_CANCEL,
-
-        GGML_LINESEARCH_FAIL = -128,
-        GGML_LINESEARCH_MINIMUM_STEP,
-        GGML_LINESEARCH_MAXIMUM_STEP,
-        GGML_LINESEARCH_MAXIMUM_ITERATIONS,
-        GGML_LINESEARCH_INVALID_PARAMETERS,
-    };
-
-    typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
-    typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
-
-    // optimization parameters
-    //
-    //   see ggml.c (ggml_opt_default_params) for default values
-    //
-    struct ggml_opt_params {
-        enum ggml_opt_type type;
-
-        size_t graph_size;
-
-        int n_threads;
-
-        // delta-based convergence test
-        //
-        //   if past == 0 - disabled
-        //   if past > 0:
-        //     stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
-        //
-        int past;
-        float delta;
-
-        // maximum number of iterations without improvement
-        //
-        //   if 0 - disabled
-        //   if > 0:
-        //     assume convergence if no cost improvement in this number of iterations
-        //
-        int max_no_improvement;
-
-        bool print_forward_graph;
-        bool print_backward_graph;
-
-        int n_gradient_accumulation;
-
-        // ADAM parameters
-        struct {
-            int n_iter;
-
-            float sched; // schedule multiplier (fixed, decay or warmup)
-            float decay; // weight decay for AdamW, use 0.0f to disable
-            int   decay_min_ndim; // minimum number of tensor dimension to apply weight decay
-            float alpha; // learning rate
-            float beta1;
-            float beta2;
-            float eps;   // epsilon for numerical stability
-            float eps_f; // epsilon for convergence test
-            float eps_g; // epsilon for convergence test
-            float gclip; // gradient clipping
-        } adam;
-
-        // LBFGS parameters
-        struct {
-            int m; // number of corrections to approximate the inv. Hessian
-            int n_iter;
-            int max_linesearch;
-
-            float eps;      // convergence tolerance
-            float ftol;     // line search tolerance
-            float wolfe;
-            float min_step;
-            float max_step;
-
-            enum ggml_linesearch linesearch;
-        } lbfgs;
-    };
-
-    struct ggml_opt_context {
-        struct ggml_context * ctx;
-        struct ggml_opt_params params;
-
-        int iter;
-        int64_t nx; // number of parameter elements
-
-        bool just_initialized;
-
-        float loss_before;
-        float loss_after;
-
-        struct {
-            struct ggml_tensor * g;  // current gradient
-            struct ggml_tensor * m;  // first moment
-            struct ggml_tensor * v;  // second moment
-            struct ggml_tensor * pf; // past function values
-            float fx_best;
-            float fx_prev;
-            int n_no_improvement;
-        } adam;
-
-        struct {
-            struct ggml_tensor * x;    // current parameters
-            struct ggml_tensor * xp;   // previous parameters
-            struct ggml_tensor * g;    // current gradient
-            struct ggml_tensor * gp;   // previous gradient
-            struct ggml_tensor * d;    // search direction
-            struct ggml_tensor * pf;   // past function values
-            struct ggml_tensor * lmal; // the L-BFGS memory alpha
-            struct ggml_tensor * lmys; // the L-BFGS memory ys
-            struct ggml_tensor * lms;  // the L-BFGS memory s
-            struct ggml_tensor * lmy;  // the L-BFGS memory y
-            float fx_best;
-            float step;
-            int j;
-            int k;
-            int end;
-            int n_no_improvement;
-        } lbfgs;
-    };
-
-    GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
-
-    // optimize the function defined by the tensor f
-    GGML_API enum ggml_opt_result ggml_opt(
-            struct ggml_context * ctx,
-            struct ggml_opt_params params,
-            struct ggml_tensor * f);
-
-    // initialize optimizer context
-    GGML_API void ggml_opt_init(
-            struct ggml_context     * ctx,
-            struct ggml_opt_context * opt,
-            struct ggml_opt_params    params,
-            int64_t                   nx);
-
-    // continue optimizing the function defined by the tensor f
-    GGML_API enum ggml_opt_result ggml_opt_resume(
-            struct ggml_context * ctx,
-            struct ggml_opt_context * opt,
-            struct ggml_tensor * f);
-
-    // continue optimizing the function defined by the tensor f
-    GGML_API enum ggml_opt_result ggml_opt_resume_g(
-            struct ggml_context * ctx,
-            struct ggml_opt_context * opt,
-            struct ggml_tensor * f,
-            struct ggml_cgraph * gf,
-            struct ggml_cgraph * gb,
-            ggml_opt_callback callback,
-            void * callback_data);
-
-    //
-    // tensor flags
-    //
-    GGML_API void ggml_set_input(struct ggml_tensor * tensor);
-    GGML_API void ggml_set_output(struct ggml_tensor * tensor);
-
-    //
-    // quantization
-    //
-
-    // - ggml_quantize_init can be called multiple times with the same type
-    //   it will only initialize the quantization tables for the first call or after ggml_quantize_free
-    //   automatically called by ggml_quantize_chunk for convenience
-    //
-    // - ggml_quantize_free will free any memory allocated by ggml_quantize_init
-    //   call this at the end of the program to avoid memory leaks
-    //
-    // note: these are thread-safe
-    //
-    GGML_API void ggml_quantize_init(enum ggml_type type);
-    GGML_API void ggml_quantize_free(void);
-
-    // some quantization type cannot be used without an importance matrix
-    GGML_API bool ggml_quantize_requires_imatrix(enum ggml_type type);
-
-    // calls ggml_quantize_init internally (i.e. can allocate memory)
-    GGML_API size_t ggml_quantize_chunk(
-            enum ggml_type   type,
-               const float * src,
-                      void * dst,
-                   int64_t   start,
-                   int64_t   nrows,
-                   int64_t   n_per_row,
-               const float * imatrix);
-
-    //
-    // gguf
-    //
-
-    enum gguf_type {
-        GGUF_TYPE_UINT8   = 0,
-        GGUF_TYPE_INT8    = 1,
-        GGUF_TYPE_UINT16  = 2,
-        GGUF_TYPE_INT16   = 3,
-        GGUF_TYPE_UINT32  = 4,
-        GGUF_TYPE_INT32   = 5,
-        GGUF_TYPE_FLOAT32 = 6,
-        GGUF_TYPE_BOOL    = 7,
-        GGUF_TYPE_STRING  = 8,
-        GGUF_TYPE_ARRAY   = 9,
-        GGUF_TYPE_UINT64  = 10,
-        GGUF_TYPE_INT64   = 11,
-        GGUF_TYPE_FLOAT64 = 12,
-        GGUF_TYPE_COUNT,       // marks the end of the enum
-    };
-
-    struct gguf_context;
-
-    struct gguf_init_params {
-        bool no_alloc;
-
-        // if not NULL, create a ggml_context and allocate the tensor data in it
-        struct ggml_context ** ctx;
-    };
-
-    GGML_API struct gguf_context * gguf_init_empty(void);
-    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
-    //GGML_API struct gguf_context * gguf_init_from_buffer(..);
-
-    GGML_API void gguf_free(struct gguf_context * ctx);
-
-    GGML_API const char * gguf_type_name(enum gguf_type type);
-
-    GGML_API int    gguf_get_version    (const struct gguf_context * ctx);
-    GGML_API size_t gguf_get_alignment  (const struct gguf_context * ctx);
-    GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
-    GGML_API void * gguf_get_data       (const struct gguf_context * ctx);
-
-    GGML_API int          gguf_get_n_kv(const struct gguf_context * ctx);
-    GGML_API int          gguf_find_key(const struct gguf_context * ctx, const char * key);
-    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id);
-
-    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id);
-    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id);
-
-    // will abort if the wrong type is used for the key
-    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int key_id);
-    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int key_id);
-    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int key_id);
-    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int key_id);
-    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int key_id);
-    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int key_id);
-    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int key_id);
-    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int key_id);
-    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int key_id);
-    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
-    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
-    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
-    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
-    GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int key_id);
-    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
-    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
-
-    GGML_API int            gguf_get_n_tensors    (const struct gguf_context * ctx);
-    GGML_API int            gguf_find_tensor      (const struct gguf_context * ctx, const char * name);
-    GGML_API size_t         gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
-    GGML_API char *         gguf_get_tensor_name  (const struct gguf_context * ctx, int i);
-    GGML_API enum ggml_type gguf_get_tensor_type  (const struct gguf_context * ctx, int i);
-
-    // removes key if it exists
-    GGML_API void gguf_remove_key(struct gguf_context * ctx, const char * key);
-
-    // overrides existing values or adds a new one
-    GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t  val);
-    GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t   val);
-    GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
-    GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t  val);
-    GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
-    GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t  val);
-    GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float    val);
-    GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
-    GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t  val);
-    GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double   val);
-    GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool     val);
-    GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
-    GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
-    GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
-
-    // set or add KV pairs from another context
-    GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
-
-    // manage tensor info
-    GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
-    GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
-    GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
-
-    // writing gguf files can be done in 2 ways:
-    //
-    // - write the entire gguf_context to a binary file in a single pass:
-    //
-    //   gguf_write_to_file(ctx, fname);
-    //
-    // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
-    //
-    //   FILE * f = fopen(fname, "wb");
-    //   fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
-    //   fwrite(f, ...);
-    //   void * data = gguf_meta_get_meta_data(ctx);
-    //   fseek(f, 0, SEEK_SET);
-    //   fwrite(f, data, gguf_get_meta_size(ctx));
-    //   free(data);
-    //   fclose(f);
-    //
-
-    // write the entire context to a binary file
-    GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
-
-    // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
-    GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
-    GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);
-
-    //
-    // system info
-    //
-
-    GGML_API int ggml_cpu_has_avx        (void);
-    GGML_API int ggml_cpu_has_avx_vnni   (void);
-    GGML_API int ggml_cpu_has_avx2       (void);
-    GGML_API int ggml_cpu_has_avx512     (void);
-    GGML_API int ggml_cpu_has_avx512_vbmi(void);
-    GGML_API int ggml_cpu_has_avx512_vnni(void);
-    GGML_API int ggml_cpu_has_avx512_bf16(void);
-    GGML_API int ggml_cpu_has_fma        (void);
-    GGML_API int ggml_cpu_has_neon       (void);
-    GGML_API int ggml_cpu_has_sve        (void);
-    GGML_API int ggml_cpu_has_arm_fma    (void);
-    GGML_API int ggml_cpu_has_metal      (void);
-    GGML_API int ggml_cpu_has_f16c       (void);
-    GGML_API int ggml_cpu_has_fp16_va    (void);
-    GGML_API int ggml_cpu_has_wasm_simd  (void);
-    GGML_API int ggml_cpu_has_blas       (void);
-    GGML_API int ggml_cpu_has_cuda       (void);
-    GGML_API int ggml_cpu_has_vulkan     (void);
-    GGML_API int ggml_cpu_has_kompute    (void);
-    GGML_API int ggml_cpu_has_gpublas    (void);
-    GGML_API int ggml_cpu_has_sse3       (void);
-    GGML_API int ggml_cpu_has_ssse3      (void);
-    GGML_API int ggml_cpu_has_sycl       (void);
-    GGML_API int ggml_cpu_has_rpc        (void);
-    GGML_API int ggml_cpu_has_vsx        (void);
-    GGML_API int ggml_cpu_has_matmul_int8(void);
-
-    //
-    // Internal types and functions exposed for tests and benchmarks
-    //
-
-#ifdef  __cplusplus
-// restrict not standard in C++
-#define GGML_RESTRICT
-#else
-#define GGML_RESTRICT restrict
-#endif
-    typedef void (*ggml_to_float_t)  (const void  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
-    typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void  * GGML_RESTRICT y, int64_t k);
-    typedef void (*ggml_vec_dot_t)   (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
-                                      const void * GGML_RESTRICT y, size_t by, int nrc);
-
-    typedef struct {
-        const char      * type_name;
-        int               blck_size;
-        size_t            type_size;
-        bool              is_quantized;
-        ggml_to_float_t   to_float;
-        ggml_from_float_t from_float;
-        ggml_from_float_t from_float_reference;
-        ggml_vec_dot_t    vec_dot;
-        enum ggml_type    vec_dot_type;
-        int64_t           nrows; // number of rows to process simultaneously;
-    } ggml_type_traits_t;
-
-    GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
-
-#ifdef  __cplusplus
-}
-#endif
index b75ed34bce81659aab2378c3b72c119d2c76b0b8..ea00c745583b0a840236aca633d2c3d7f3b80a5c 100755 (executable)
@@ -53,14 +53,18 @@ while read c; do
     fi
 
     git format-patch -k $c~1..$c --stdout -- \
-        ggml*.h \
-        ggml*.c \
-        ggml*.cpp \
-        ggml*.m \
-        ggml*.metal \
-        ggml*.cu \
-        ggml-cuda/* \
-        ggml-sycl/* \
+        ggml/CMakeLists.txt \
+        ggml/src/CMakeLists.txt \
+        ggml/cmake/FindSIMD.cmake \
+        ggml/src/ggml*.h \
+        ggml/src/ggml*.c \
+        ggml/src/ggml*.cpp \
+        ggml/src/ggml*.m \
+        ggml/src/ggml*.metal \
+        ggml/src/ggml*.cu \
+        ggml/src/ggml-cuda/* \
+        ggml/src/ggml-sycl/* \
+        ggml/include/ggml*.h \
         tests/test-opt.cpp \
         tests/test-grad0.cpp \
         tests/test-quantize-fns.cpp \
@@ -93,33 +97,39 @@ if [ -f $SRC_GGML/llama-src.patch ]; then
 
     # replace filenames:
     #
-    # ggml.c              -> src/ggml.c
-    # ggml-alloc.c        -> src/ggml-alloc.c
-    # ggml-backend-impl.h -> src/ggml-backend-impl.h
-    # ggml-backend.c      -> src/ggml-backend.c
-    # ggml-blas.cpp       -> src/ggml-blas.cpp
-    # ggml-blas.h         -> src/ggml-blas.h
-    # ggml-common.h       -> src/ggml-common.h
-    # ggml-cuda/*         -> src/ggml-cuda/*
-    # ggml-cuda.cu        -> src/ggml-cuda.cu
-    # ggml-cuda.h         -> src/ggml-cuda.h
-    # ggml-impl.h         -> src/ggml-impl.h
-    # ggml-kompute.cpp    -> src/ggml-kompute.cpp
-    # ggml-kompute.h      -> src/ggml-kompute.h
-    # ggml-metal.h        -> src/ggml-metal.h
-    # ggml-metal.m        -> src/ggml-metal.m
-    # ggml-quants.c       -> src/ggml-quants.c
-    # ggml-quants.h       -> src/ggml-quants.h
-    # ggml-rpc.cpp        -> src/ggml-rpc.cpp
-    # ggml-rpc.h          -> src/ggml-rpc.h
-    # ggml-sycl/*         -> src/ggml-sycl/*
-    # ggml-sycl.cpp       -> src/ggml-sycl.cpp
-    # ggml-sycl.h         -> src/ggml-sycl.h
-    # ggml-vulkan.cpp     -> src/ggml-vulkan.cpp
-    # ggml-vulkan.h       -> src/ggml-vulkan.h
-    # ggml.h              -> include/ggml/ggml.h
-    # ggml-alloc.h        -> include/ggml/ggml-alloc.h
-    # ggml-backend.h      -> include/ggml/ggml-backend.h
+    # ggml/CMakelists.txt       -> CMakeLists.txt
+    # ggml/src/CMakelists.txt   -> src/CMakeLists.txt
+    # ggml/cmake/FindSIMD.cmake -> cmake/FindSIMD.cmake
+    #
+    # ggml/src/ggml.c              -> src/ggml.c
+    # ggml/src/ggml-alloc.c        -> src/ggml-alloc.c
+    # ggml/src/ggml-backend-impl.h -> src/ggml-backend-impl.h
+    # ggml/src/ggml-backend.c      -> src/ggml-backend.c
+    # ggml/src/ggml-blas.cpp       -> src/ggml-blas.cpp
+    # ggml/src/ggml-blas.h         -> src/ggml-blas.h
+    # ggml/src/ggml-common.h       -> src/ggml-common.h
+    # ggml/src/ggml-cuda/*         -> src/ggml-cuda/*
+    # ggml/src/ggml-cuda.cu        -> src/ggml-cuda.cu
+    # ggml/src/ggml-impl.h         -> src/ggml-impl.h
+    # ggml/src/ggml-kompute.cpp    -> src/ggml-kompute.cpp
+    # ggml/src/ggml-metal.m        -> src/ggml-metal.m
+    # ggml/src/ggml-quants.c       -> src/ggml-quants.c
+    # ggml/src/ggml-quants.h       -> src/ggml-quants.h
+    # ggml/src/ggml-rpc.cpp        -> src/ggml-rpc.cpp
+    # ggml/src/ggml-sycl/*         -> src/ggml-sycl/*
+    # ggml/src/ggml-sycl.cpp       -> src/ggml-sycl.cpp
+    # ggml/src/ggml-vulkan.cpp     -> src/ggml-vulkan.cpp
+    #
+    # ggml/include/ggml.h         -> include/ggml.h
+    # ggml/include/ggml-alloc.h   -> include/ggml-alloc.h
+    # ggml/include/ggml-backend.h -> include/ggml-backend.h
+    # ggml/include/ggml-blas.h    -> include/ggml-blas.h
+    # ggml/include/ggml-cuda.h    -> include/ggml-cuda.h
+    # ggml/include/ggml-kompute.h -> include/ggml-kompute.h
+    # ggml/include/ggml-metal.h   -> include/ggml-metal.h
+    # ggml/include/ggml-rpc.h     -> include/ggml-rpc.h
+    # ggml/include/ggml-sycl.h    -> include/ggml-sycl.h
+    # ggml/include/ggml-vulkan.h  -> include/ggml-vulkan.h
     #
     # tests/test-opt.cpp           -> tests/test-opt.cpp
     # tests/test-grad0.cpp         -> tests/test-grad0.cpp
@@ -131,34 +141,37 @@ if [ -f $SRC_GGML/llama-src.patch ]; then
     # scripts/gen-authors.sh -> scripts/gen-authors.sh
 
     cat llama-src.patch | sed \
-        -e 's/\/ggml\.c/\/src\/ggml.c/g' \
-        -e 's/\/ggml-alloc\.c/\/src\/ggml-alloc.c/g' \
-        -e 's/\/ggml-backend-impl\.h/\/src\/ggml-backend-impl.h/g' \
-        -e 's/\/ggml-backend\.c/\/src\/ggml-backend.c/g' \
-        -e 's/\/ggml-blas\.cpp/\/src\/ggml-blas.cpp/g' \
-        -e 's/\/ggml-blas\.h/\/src\/ggml-blas.h/g' \
-        -e 's/\/ggml-common\.h/\/src\/ggml-common.h/g' \
-        -e 's/\/ggml-cuda\//\/src\/ggml-cuda\//g' \
-        -e 's/\/ggml-cuda\.cu/\/src\/ggml-cuda.cu/g' \
-        -e 's/\/ggml-cuda\.h/\/src\/ggml-cuda.h/g' \
-        -e 's/\/ggml-impl\.h/\/src\/ggml-impl.h/g' \
-        -e 's/\/ggml-kompute\.cpp/\/src\/ggml-kompute.cpp/g' \
-        -e 's/\/ggml-kompute\.h/\/src\/ggml-kompute.h/g' \
-        -e 's/\/ggml-metal\.h/\/src\/ggml-metal.h/g' \
-        -e 's/\/ggml-metal\.m/\/src\/ggml-metal.m/g' \
-        -e 's/\/ggml-quants\.c/\/src\/ggml-quants.c/g' \
-        -e 's/\/ggml-quants\.h/\/src\/ggml-quants.h/g' \
-        -e 's/\/ggml-rpc\.cpp/\/src\/ggml-rpc.cpp/g' \
-        -e 's/\/ggml-rpc\.h/\/src\/ggml-rpc.h/g' \
-        -e 's/\/ggml-sycl\//\/src\/ggml-sycl\//g' \
-        -e 's/\/ggml-sycl\.cpp/\/src\/ggml-sycl.cpp/g' \
-        -e 's/\/ggml-sycl\.h/\/src\/ggml-sycl.h/g' \
-        -e 's/\/ggml-vulkan\.cpp/\/src\/ggml-vulkan.cpp/g' \
-        -e 's/\/ggml-vulkan\.h/\/src\/ggml-vulkan.h/g' \
-        -e 's/\/ggml_vk_generate_shaders\.py/\/src\/ggml_vk_generate_shaders.py/g' \
-        -e 's/\/ggml\.h/\/include\/ggml\/ggml.h/g' \
-        -e 's/\/ggml-alloc\.h/\/include\/ggml\/ggml-alloc.h/g' \
-        -e 's/\/ggml-backend\.h/\/include\/ggml\/ggml-backend.h/g' \
+        -e 's/\/ggml\/CMakeLists\.txt/\/CMakeLists.txt/g' \
+        -e 's/\/ggml\/src\/CMakeLists\.txt/\/src\/CMakeLists.txt/g' \
+        -e 's/\/ggml\/cmake\/FindSIMD\.cmake/\/cmake\/FindSIMD.cmake/g' \
+        -e 's/\/ggml\/src\/ggml\.c/\/src\/ggml.c/g' \
+        -e 's/\/ggml\/src\/ggml-alloc\.c/\/src\/ggml-alloc.c/g' \
+        -e 's/\/ggml\/src\/ggml-backend-impl\.h/\/src\/ggml-backend-impl.h/g' \
+        -e 's/\/ggml\/src\/ggml-backend\.c/\/src\/ggml-backend.c/g' \
+        -e 's/\/ggml\/src\/ggml-blas\.cpp/\/src\/ggml-blas.cpp/g' \
+        -e 's/\/ggml\/src\/ggml-blas\.h/\/src\/ggml-blas.h/g' \
+        -e 's/\/ggml\/src\/ggml-common\.h/\/src\/ggml-common.h/g' \
+        -e 's/\/ggml\/src\/ggml-cuda\//\/src\/ggml-cuda\//g' \
+        -e 's/\/ggml\/src\/ggml-cuda\.cu/\/src\/ggml-cuda.cu/g' \
+        -e 's/\/ggml\/src\/ggml-impl\.h/\/src\/ggml-impl.h/g' \
+        -e 's/\/ggml\/src\/ggml-kompute\.cpp/\/src\/ggml-kompute.cpp/g' \
+        -e 's/\/ggml\/src\/ggml-metal\.m/\/src\/ggml-metal.m/g' \
+        -e 's/\/ggml\/src\/ggml-quants\.c/\/src\/ggml-quants.c/g' \
+        -e 's/\/ggml\/src\/ggml-quants\.h/\/src\/ggml-quants.h/g' \
+        -e 's/\/ggml\/src\/ggml-rpc\.cpp/\/src\/ggml-rpc.cpp/g' \
+        -e 's/\/ggml\/src\/ggml-sycl\//\/src\/ggml-sycl\//g' \
+        -e 's/\/ggml\/src\/ggml-sycl\.cpp/\/src\/ggml-sycl.cpp/g' \
+        -e 's/\/ggml\/src\/ggml-vulkan\.cpp/\/src\/ggml-vulkan.cpp/g' \
+        -e 's/\/ggml\/include\/ggml\.h/\/include\/ggml.h/g' \
+        -e 's/\/ggml\/include\/ggml-alloc\.h/\/include\/ggml-alloc.h/g' \
+        -e 's/\/ggml\/include\/ggml-backend\.h/\/include\/ggml-backend.h/g' \
+        -e 's/\/ggml\/include\/ggml-blas\.h/\/include\/ggml-blas.h/g' \
+        -e 's/\/ggml\/include\/ggml-cuda\.h/\/include\/ggml-cuda.h/g' \
+        -e 's/\/ggml\/include\/ggml-kompute\.h/\/include\/ggml-kompute.h/g' \
+        -e 's/\/ggml\/include\/ggml-metal\.h/\/include\/ggml-metal.h/g' \
+        -e 's/\/ggml\/include\/ggml-rpc\.h/\/include\/ggml-rpc.h/g' \
+        -e 's/\/ggml\/include\/ggml-sycl\.h/\/include\/ggml-sycl.h/g' \
+        -e 's/\/ggml\/include\/ggml-vulkan\.h/\/include\/ggml-vulkan.h/g' \
         -e 's/\/tests\/test-opt\.cpp/\/tests\/test-opt.cpp/g' \
         -e 's/\/tests\/test-grad0\.cpp/\/tests\/test-grad0.cpp/g' \
         -e 's/\/tests\/test-quantize-fns\.cpp/\/tests\/test-quantize-fns.cpp/g' \
index 330573abbf4fd9f40561250eb6022c2acdcc7a37..c4c17433bcf06f4b12d34fd93b1e5fdae3e5616b 100755 (executable)
@@ -1,34 +1,39 @@
 #!/bin/bash
 
-cp -rpv ../llama.cpp/ggml.c              src/ggml.c
-cp -rpv ../llama.cpp/ggml-alloc.c        src/ggml-alloc.c
-cp -rpv ../llama.cpp/ggml-backend-impl.h src/ggml-backend-impl.h
-cp -rpv ../llama.cpp/ggml-backend.c      src/ggml-backend.c
-cp -rpv ../llama.cpp/ggml-blas.cpp       src/ggml-blas.cpp
-cp -rpv ../llama.cpp/ggml-blas.h         src/ggml-blas.h
-cp -rpv ../llama.cpp/ggml-common.h       src/ggml-common.h
-cp -rpv ../llama.cpp/ggml-cuda/*         src/ggml-cuda/
-cp -rpv ../llama.cpp/ggml-cuda.cu        src/ggml-cuda.cu
-cp -rpv ../llama.cpp/ggml-cuda.h         src/ggml-cuda.h
-cp -rpv ../llama.cpp/ggml-impl.h         src/ggml-impl.h
-cp -rpv ../llama.cpp/ggml-kompute.cpp    src/ggml-kompute.cpp
-cp -rpv ../llama.cpp/ggml-kompute.h      src/ggml-kompute.h
-cp -rpv ../llama.cpp/ggml-metal.h        src/ggml-metal.h
-cp -rpv ../llama.cpp/ggml-metal.m        src/ggml-metal.m
-cp -rpv ../llama.cpp/ggml-metal.metal    src/ggml-metal.metal
-cp -rpv ../llama.cpp/ggml-quants.c       src/ggml-quants.c
-cp -rpv ../llama.cpp/ggml-quants.h       src/ggml-quants.h
-cp -rpv ../llama.cpp/ggml-rpc.cpp        src/ggml-rpc.cpp
-cp -rpv ../llama.cpp/ggml-rpc.h          src/ggml-rpc.h
-cp -rpv ../llama.cpp/ggml-sycl/*         src/ggml-sycl/
-cp -rpv ../llama.cpp/ggml-sycl.cpp       src/ggml-sycl.cpp
-cp -rpv ../llama.cpp/ggml-sycl.h         src/ggml-sycl.h
-cp -rpv ../llama.cpp/ggml-vulkan.cpp     src/ggml-vulkan.cpp
-cp -rpv ../llama.cpp/ggml-vulkan.h       src/ggml-vulkan.h
-cp -rpv ../llama.cpp/ggml_vk_generate_shaders.py       src/ggml_vk_generate_shaders.py
-cp -rpv ../llama.cpp/ggml.h              include/ggml/ggml.h
-cp -rpv ../llama.cpp/ggml-alloc.h        include/ggml/ggml-alloc.h
-cp -rpv ../llama.cpp/ggml-backend.h      include/ggml/ggml-backend.h
+cp -rpv ../llama.cpp/ggml/CMakeLists.txt       CMakeLists.txt
+cp -rpv ../llama.cpp/ggml/src/CMakeLists.txt   src/CMakeLists.txt
+cp -rpv ../llama.cpp/ggml/cmake/FindSIMD.cmake cmake/FindSIMD.cmake
+
+cp -rpv ../llama.cpp/ggml/src/ggml.c              src/ggml.c
+cp -rpv ../llama.cpp/ggml/src/ggml-alloc.c        src/ggml-alloc.c
+cp -rpv ../llama.cpp/ggml/src/ggml-backend-impl.h src/ggml-backend-impl.h
+cp -rpv ../llama.cpp/ggml/src/ggml-backend.c      src/ggml-backend.c
+cp -rpv ../llama.cpp/ggml/src/ggml-blas.cpp       src/ggml-blas.cpp
+cp -rpv ../llama.cpp/ggml/src/ggml-blas.h         src/ggml-blas.h
+cp -rpv ../llama.cpp/ggml/src/ggml-common.h       src/ggml-common.h
+cp -rpv ../llama.cpp/ggml/src/ggml-cuda/*         src/ggml-cuda/
+cp -rpv ../llama.cpp/ggml/src/ggml-cuda.cu        src/ggml-cuda.cu
+cp -rpv ../llama.cpp/ggml/src/ggml-impl.h         src/ggml-impl.h
+cp -rpv ../llama.cpp/ggml/src/ggml-kompute.cpp    src/ggml-kompute.cpp
+cp -rpv ../llama.cpp/ggml/src/ggml-metal.m        src/ggml-metal.m
+cp -rpv ../llama.cpp/ggml/src/ggml-metal.metal    src/ggml-metal.metal
+cp -rpv ../llama.cpp/ggml/src/ggml-quants.c       src/ggml-quants.c
+cp -rpv ../llama.cpp/ggml/src/ggml-quants.h       src/ggml-quants.h
+cp -rpv ../llama.cpp/ggml/src/ggml-rpc.cpp        src/ggml-rpc.cpp
+cp -rpv ../llama.cpp/ggml/src/ggml-sycl/*         src/ggml-sycl/
+cp -rpv ../llama.cpp/ggml/src/ggml-sycl.cpp       src/ggml-sycl.cpp
+cp -rpv ../llama.cpp/ggml/src/ggml-vulkan.cpp     src/ggml-vulkan.cpp
+
+cp -rpv ../llama.cpp/ggml/include/ggml.h         include/ggml.h
+cp -rpv ../llama.cpp/ggml/include/ggml-alloc.h   include/ggml-alloc.h
+cp -rpv ../llama.cpp/ggml/include/ggml-backend.h include/ggml-backend.h
+cp -rpv ../llama.cpp/ggml/include/ggml-blas.h    include/ggml-blas.h
+cp -rpv ../llama.cpp/ggml/include/ggml-cuda.h    include/ggml-cuda.h
+cp -rpv ../llama.cpp/ggml/include/ggml-kompute.h include/ggml-kompute.h
+cp -rpv ../llama.cpp/ggml/include/ggml-metal.h   include/ggml-metal.h
+cp -rpv ../llama.cpp/ggml/include/ggml-rpc.h     include/ggml-rpc.h
+cp -rpv ../llama.cpp/ggml/include/ggml-sycl.h    include/ggml-sycl.h
+cp -rpv ../llama.cpp/ggml/include/ggml-vulkan.h  include/ggml-vulkan.h
 
 cp -rpv ../llama.cpp/tests/test-opt.cpp           tests/test-opt.cpp
 cp -rpv ../llama.cpp/tests/test-grad0.cpp         tests/test-grad0.cpp
index de3ac876dc8f3839e48dd1fbf44d32adad834eeb..c46790d444a3fd21aafa9934ca08f13cd49076e5 100755 (executable)
@@ -53,20 +53,24 @@ while read c; do
     fi
 
     git format-patch -k $c~1..$c --stdout -- \
-    ggml*.h \
-    ggml*.c \
-    ggml*.cpp \
-    ggml*.m \
-    ggml*.metal \
-    ggml*.cu \
-    ggml-cuda/* \
-    examples/common.h \
-    examples/common.cpp \
-    examples/common-ggml.h \
-    examples/common-ggml.cpp \
-    LICENSE \
-    scripts/gen-authors.sh \
-    >> $SRC_GGML/whisper-src.patch
+        ggml/CMakeLists.txt \
+        ggml/src/CMakeLists.txt \
+        ggml/cmake/FindSIMD.cmake \
+        ggml/src/ggml*.h \
+        ggml/src/ggml*.c \
+        ggml/src/ggml*.cpp \
+        ggml/src/ggml*.m \
+        ggml/src/ggml*.metal \
+        ggml/src/ggml*.cu \
+        ggml/src/ggml-cuda/* \
+        ggml/include/ggml*.h \
+        examples/common.h \
+        examples/common.cpp \
+        examples/common-ggml.h \
+        examples/common-ggml.cpp \
+        LICENSE \
+        scripts/gen-authors.sh \
+        >> $SRC_GGML/whisper-src.patch
 done < $SRC_GGML/whisper-commits
 
 rm -v $SRC_GGML/whisper-commits
@@ -91,70 +95,80 @@ if [ -f $SRC_GGML/whisper-src.patch ]; then
 
     # replace filenames:
     #
-    # ggml.c              -> src/ggml.c
-    # ggml-alloc.c        -> src/ggml-alloc.c
-    # ggml-backend-impl.h -> src/ggml-backend-impl.h
-    # ggml-backend.c      -> src/ggml-backend.c
-    # ggml-blas.cpp       -> src/ggml-blas.cpp
-    # ggml-blas.h         -> src/ggml-blas.h
-    # ggml-common.h       -> src/ggml-common.h
-    # ggml-cuda/*         -> src/ggml-cuda/
-    # ggml-cuda.cu        -> src/ggml-cuda.cu
-    # ggml-cuda.h         -> src/ggml-cuda.h
-    # ggml-impl.h         -> src/ggml-impl.h
-    # ggml-kompute.cpp    -> src/ggml-kompute.cpp
-    # ggml-kompute.h      -> src/ggml-kompute.h
-    # ggml-metal.h        -> src/ggml-metal.h
-    # ggml-metal.m        -> src/ggml-metal.m
-    # ggml-quants.c       -> src/ggml-quants.c
-    # ggml-quants.h       -> src/ggml-quants.h
-    # ggml-rpc.cpp        -> src/ggml-rpc.cpp
-    # ggml-rpc.h          -> src/ggml-rpc.h
-    # ggml-sycl/*         -> src/ggml-sycl/*
-    # ggml-sycl.cpp       -> src/ggml-sycl.cpp
-    # ggml-sycl.h         -> src/ggml-sycl.h
-    # ggml-vulkan.cpp     -> src/ggml-vulkan.cpp
-    # ggml-vulkan.h       -> src/ggml-vulkan.h
-    # ggml.h              -> include/ggml/ggml.h
-    # ggml-alloc.h        -> include/ggml/ggml-alloc.h
-    # ggml-backend.h      -> include/ggml/ggml-backend.h
+    # ggml/CMakelists.txt       -> CMakeLists.txt
+    # ggml/src/CMakelists.txt   -> src/CMakeLists.txt
+    # ggml/cmake/FindSIMD.cmake -> cmake/FindSIMD.cmake
     #
-    # examples/common.h              -> examples/common.h
-    # examples/common.cpp            -> examples/common.cpp
-    # examples/common-ggml.h         -> examples/common-ggml.h
-    # examples/common-ggml.cpp       -> examples/common-ggml.cpp
+    # ggml/src/ggml.c              -> src/ggml.c
+    # ggml/src/ggml-alloc.c        -> src/ggml-alloc.c
+    # ggml/src/ggml-backend-impl.h -> src/ggml-backend-impl.h
+    # ggml/src/ggml-backend.c      -> src/ggml-backend.c
+    # ggml/src/ggml-blas.cpp       -> src/ggml-blas.cpp
+    # ggml/src/ggml-blas.h         -> src/ggml-blas.h
+    # ggml/src/ggml-common.h       -> src/ggml-common.h
+    # ggml/src/ggml-cuda/*         -> src/ggml-cuda/*
+    # ggml/src/ggml-cuda.cu        -> src/ggml-cuda.cu
+    # ggml/src/ggml-impl.h         -> src/ggml-impl.h
+    # ggml/src/ggml-kompute.cpp    -> src/ggml-kompute.cpp
+    # ggml/src/ggml-metal.m        -> src/ggml-metal.m
+    # ggml/src/ggml-quants.c       -> src/ggml-quants.c
+    # ggml/src/ggml-quants.h       -> src/ggml-quants.h
+    # ggml/src/ggml-rpc.cpp        -> src/ggml-rpc.cpp
+    # ggml/src/ggml-sycl/*         -> src/ggml-sycl/*
+    # ggml/src/ggml-sycl.cpp       -> src/ggml-sycl.cpp
+    # ggml/src/ggml-vulkan.cpp     -> src/ggml-vulkan.cpp
+    #
+    # ggml/include/ggml.h         -> include/ggml.h
+    # ggml/include/ggml-alloc.h   -> include/ggml-alloc.h
+    # ggml/include/ggml-backend.h -> include/ggml-backend.h
+    # ggml/include/ggml-blas.h    -> include/ggml-blas.h
+    # ggml/include/ggml-cuda.h    -> include/ggml-cuda.h
+    # ggml/include/ggml-kompute.h -> include/ggml-kompute.h
+    # ggml/include/ggml-metal.h   -> include/ggml-metal.h
+    # ggml/include/ggml-rpc.h     -> include/ggml-rpc.h
+    # ggml/include/ggml-sycl.h    -> include/ggml-sycl.h
+    # ggml/include/ggml-vulkan.h  -> include/ggml-vulkan.h
+    #
+    # examples/common.h        -> examples/common.h
+    # examples/common.cpp      -> examples/common.cpp
+    # examples/common-ggml.h   -> examples/common-ggml.h
+    # examples/common-ggml.cpp -> examples/common-ggml.cpp
     #
     # LICENSE                -> LICENSE
     # scripts/gen-authors.sh -> scripts/gen-authors.sh
 
     cat whisper-src.patch | sed \
-        -e 's/\/ggml\.c/\/src\/ggml.c/g' \
-        -e 's/\/ggml-alloc\.c/\/src\/ggml-alloc.c/g' \
-        -e 's/\/ggml-backend-impl\.h/\/src\/ggml-backend-impl.h/g' \
-        -e 's/\/ggml-backend\.c/\/src\/ggml-backend.c/g' \
-        -e 's/\/ggml-blas\.cpp/\/src\/ggml-blas.cpp/g' \
-        -e 's/\/ggml-blas\.h/\/src\/ggml-blas.h/g' \
-        -e 's/\/ggml-common\.h/\/src\/ggml-common.h/g' \
-        -e 's/\/ggml-cuda\//\/src\/ggml-cuda\//g' \
-        -e 's/\/ggml-cuda\.cu/\/src\/ggml-cuda.cu/g' \
-        -e 's/\/ggml-cuda\.h/\/src\/ggml-cuda.h/g' \
-        -e 's/\/ggml-impl\.h/\/src\/ggml-impl.h/g' \
-        -e 's/\/ggml-kompute\.cpp/\/src\/ggml-kompute.cpp/g' \
-        -e 's/\/ggml-kompute\.h/\/src\/ggml-kompute.h/g' \
-        -e 's/\/ggml-metal\.h/\/src\/ggml-metal.h/g' \
-        -e 's/\/ggml-metal\.m/\/src\/ggml-metal.m/g' \
-        -e 's/\/ggml-quants\.c/\/src\/ggml-quants.c/g' \
-        -e 's/\/ggml-quants\.h/\/src\/ggml-quants.h/g' \
-        -e 's/\/ggml-rpc\.cpp/\/src\/ggml-rpc.cpp/g' \
-        -e 's/\/ggml-rpc\.h/\/src\/ggml-rpc.h/g' \
-        -e 's/\/ggml-sycl\//\/src\/ggml-sycl\//g' \
-        -e 's/\/ggml-sycl\.cpp/\/src\/ggml-sycl.cpp/g' \
-        -e 's/\/ggml-sycl\.h/\/src\/ggml-sycl.h/g' \
-        -e 's/\/ggml-vulkan\.cpp/\/src\/ggml-vulkan.cpp/g' \
-        -e 's/\/ggml-vulkan\.h/\/src\/ggml-vulkan.h/g' \
-        -e 's/\/ggml\.h/\/include\/ggml\/ggml.h/g' \
-        -e 's/\/ggml-alloc\.h/\/include\/ggml\/ggml-alloc.h/g' \
-        -e 's/\/ggml-backend\.h/\/include\/ggml\/ggml-backend.h/g' \
+        -e 's/\/ggml\/CMakeLists\.txt/\/CMakeLists.txt/g' \
+        -e 's/\/ggml\/src\/CMakeLists\.txt/\/src\/CMakeLists.txt/g' \
+        -e 's/\/ggml\/cmake\/FindSIMD\.cmake/\/cmake\/FindSIMD.cmake/g' \
+        -e 's/\/ggml\/src\/ggml\.c/\/src\/ggml.c/g' \
+        -e 's/\/ggml\/src\/ggml-alloc\.c/\/src\/ggml-alloc.c/g' \
+        -e 's/\/ggml\/src\/ggml-backend-impl\.h/\/src\/ggml-backend-impl.h/g' \
+        -e 's/\/ggml\/src\/ggml-backend\.c/\/src\/ggml-backend.c/g' \
+        -e 's/\/ggml\/src\/ggml-blas\.cpp/\/src\/ggml-blas.cpp/g' \
+        -e 's/\/ggml\/src\/ggml-blas\.h/\/src\/ggml-blas.h/g' \
+        -e 's/\/ggml\/src\/ggml-common\.h/\/src\/ggml-common.h/g' \
+        -e 's/\/ggml\/src\/ggml-cuda\//\/src\/ggml-cuda\//g' \
+        -e 's/\/ggml\/src\/ggml-cuda\.cu/\/src\/ggml-cuda.cu/g' \
+        -e 's/\/ggml\/src\/ggml-impl\.h/\/src\/ggml-impl.h/g' \
+        -e 's/\/ggml\/src\/ggml-kompute\.cpp/\/src\/ggml-kompute.cpp/g' \
+        -e 's/\/ggml\/src\/ggml-metal\.m/\/src\/ggml-metal.m/g' \
+        -e 's/\/ggml\/src\/ggml-quants\.c/\/src\/ggml-quants.c/g' \
+        -e 's/\/ggml\/src\/ggml-quants\.h/\/src\/ggml-quants.h/g' \
+        -e 's/\/ggml\/src\/ggml-rpc\.cpp/\/src\/ggml-rpc.cpp/g' \
+        -e 's/\/ggml\/src\/ggml-sycl\//\/src\/ggml-sycl\//g' \
+        -e 's/\/ggml\/src\/ggml-sycl\.cpp/\/src\/ggml-sycl.cpp/g' \
+        -e 's/\/ggml\/src\/ggml-vulkan\.cpp/\/src\/ggml-vulkan.cpp/g' \
+        -e 's/\/ggml\/include\/ggml\.h/\/include\/ggml.h/g' \
+        -e 's/\/ggml\/include\/ggml-alloc\.h/\/include\/ggml-alloc.h/g' \
+        -e 's/\/ggml\/include\/ggml-backend\.h/\/include\/ggml-backend.h/g' \
+        -e 's/\/ggml\/include\/ggml-blas\.h/\/include\/ggml-blas.h/g' \
+        -e 's/\/ggml\/include\/ggml-cuda\.h/\/include\/ggml-cuda.h/g' \
+        -e 's/\/ggml\/include\/ggml-kompute\.h/\/include\/ggml-kompute.h/g' \
+        -e 's/\/ggml\/include\/ggml-metal\.h/\/include\/ggml-metal.h/g' \
+        -e 's/\/ggml\/include\/ggml-rpc\.h/\/include\/ggml-rpc.h/g' \
+        -e 's/\/ggml\/include\/ggml-sycl\.h/\/include\/ggml-sycl.h/g' \
+        -e 's/\/ggml\/include\/ggml-vulkan\.h/\/include\/ggml-vulkan.h/g' \
         -e 's/\/examples\/common\.h/\/examples\/common.h/g' \
         -e 's/\/examples\/common\.cpp/\/examples\/common.cpp/g' \
         -e 's/\/examples\/common-ggml\.h/\/examples\/common-ggml.h/g' \
index 3100ba4e1527491c6021c246fe6cef926723403a..b78e66ada3dbfa76e2feb85e9d0d066d8ee7775a 100755 (executable)
@@ -1,39 +1,44 @@
 #!/bin/bash
 
-cp -rpv ../whisper.cpp/ggml.c                         src/ggml.c
-cp -rpv ../whisper.cpp/ggml-impl.h                    src/ggml-impl.h
-cp -rpv ../whisper.cpp/ggml-alloc.c                   src/ggml-alloc.c
-cp -rpv ../whisper.cpp/ggml-backend-impl.h            src/ggml-backend-impl.h
-cp -rpv ../whisper.cpp/ggml-backend.c                 src/ggml-backend.c
-cp -rpv ../whisper.cpp/ggml-blas.cpp                  src/ggml-blas.cpp
-cp -rpv ../whisper.cpp/ggml-blas.h                    src/ggml-blas.h
-cp -rpv ../whisper.cpp/ggml-common.h                  src/ggml-common.h
-cp -rpv ../whisper.cpp/ggml-cuda/*                    src/ggml-cuda/
-cp -rpv ../whisper.cpp/ggml-cuda.cu                   src/ggml-cuda.cu
-cp -rpv ../whisper.cpp/ggml-cuda.h                    src/ggml-cuda.h
-cp -rpv ../whisper.cpp/ggml-kompute.cpp               src/ggml-kompute.cpp
-cp -rpv ../whisper.cpp/ggml-kompute.h                 src/ggml-kompute.h
-cp -rpv ../whisper.cpp/ggml-metal.h                   src/ggml-metal.h
-cp -rpv ../whisper.cpp/ggml-metal.m                   src/ggml-metal.m
-cp -rpv ../whisper.cpp/ggml-metal.metal               src/ggml-metal.metal
-cp -rpv ../whisper.cpp/ggml-quants.c                  src/ggml-quants.c
-cp -rpv ../whisper.cpp/ggml-quants.h                  src/ggml-quants.h
-cp -rpv ../whisper.cpp/ggml-rpc.cpp                   src/ggml-rpc.cpp
-cp -rpv ../whisper.cpp/ggml-rpc.h                     src/ggml-rpc.h
-cp -rpv ../whisper.cpp/ggml-sycl/*                    src/ggml-sycl/
-cp -rpv ../whisper.cpp/ggml-sycl.cpp                  src/ggml-sycl.cpp
-cp -rpv ../whisper.cpp/ggml-sycl.h                    src/ggml-sycl.h
-cp -rpv ../whisper.cpp/ggml-vulkan.cpp                src/ggml-vulkan.cpp
-cp -rpv ../whisper.cpp/ggml-vulkan.h                  src/ggml-vulkan.h
+cp -rpv ../whisper.cpp/ggml/CMakeLists.txt       CMakeLists.txt
+cp -rpv ../whisper.cpp/ggml/src/CMakeLists.txt   src/CMakeLists.txt
+cp -rpv ../whisper.cpp/ggml/cmake/FindSIMD.cmake cmake/FindSIMD.cmake
 
-cp -rpv ../whisper.cpp/ggml.h                         include/ggml/ggml.h
-cp -rpv ../whisper.cpp/ggml-alloc.h                   include/ggml/ggml-alloc.h
-cp -rpv ../whisper.cpp/ggml-backend.h                 include/ggml/ggml-backend.h
+cp -rpv ../whisper.cpp/ggml/src/ggml.c              src/ggml.c
+cp -rpv ../whisper.cpp/ggml/src/ggml-alloc.c        src/ggml-alloc.c
+cp -rpv ../whisper.cpp/ggml/src/ggml-backend-impl.h src/ggml-backend-impl.h
+cp -rpv ../whisper.cpp/ggml/src/ggml-backend.c      src/ggml-backend.c
+cp -rpv ../whisper.cpp/ggml/src/ggml-blas.cpp       src/ggml-blas.cpp
+cp -rpv ../whisper.cpp/ggml/src/ggml-blas.h         src/ggml-blas.h
+cp -rpv ../whisper.cpp/ggml/src/ggml-common.h       src/ggml-common.h
+cp -rpv ../whisper.cpp/ggml/src/ggml-cuda/*         src/ggml-cuda/
+cp -rpv ../whisper.cpp/ggml/src/ggml-cuda.cu        src/ggml-cuda.cu
+cp -rpv ../whisper.cpp/ggml/src/ggml-impl.h         src/ggml-impl.h
+cp -rpv ../whisper.cpp/ggml/src/ggml-kompute.cpp    src/ggml-kompute.cpp
+cp -rpv ../whisper.cpp/ggml/src/ggml-metal.m        src/ggml-metal.m
+cp -rpv ../whisper.cpp/ggml/src/ggml-metal.metal    src/ggml-metal.metal
+cp -rpv ../whisper.cpp/ggml/src/ggml-quants.c       src/ggml-quants.c
+cp -rpv ../whisper.cpp/ggml/src/ggml-quants.h       src/ggml-quants.h
+cp -rpv ../whisper.cpp/ggml/src/ggml-rpc.cpp        src/ggml-rpc.cpp
+cp -rpv ../whisper.cpp/ggml/src/ggml-sycl/*         src/ggml-sycl/
+cp -rpv ../whisper.cpp/ggml/src/ggml-sycl.cpp       src/ggml-sycl.cpp
+cp -rpv ../whisper.cpp/ggml/src/ggml-vulkan.cpp     src/ggml-vulkan.cpp
 
-cp -rpv ../whisper.cpp/examples/common.h              examples/common.h
-cp -rpv ../whisper.cpp/examples/common.cpp            examples/common.cpp
-cp -rpv ../whisper.cpp/examples/common-ggml.h         examples/common-ggml.h
-cp -rpv ../whisper.cpp/examples/common-ggml.cpp       examples/common-ggml.cpp
+cp -rpv ../whisper.cpp/ggml/include/ggml.h         include/ggml.h
+cp -rpv ../whisper.cpp/ggml/include/ggml-alloc.h   include/ggml-alloc.h
+cp -rpv ../whisper.cpp/ggml/include/ggml-backend.h include/ggml-backend.h
+cp -rpv ../whisper.cpp/ggml/include/ggml-blas.h    include/ggml-blas.h
+cp -rpv ../whisper.cpp/ggml/include/ggml-cuda.h    include/ggml-cuda.h
+cp -rpv ../whisper.cpp/ggml/include/ggml-kompute.h include/ggml-kompute.h
+cp -rpv ../whisper.cpp/ggml/include/ggml-metal.h   include/ggml-metal.h
+cp -rpv ../whisper.cpp/ggml/include/ggml-rpc.h     include/ggml-rpc.h
+cp -rpv ../whisper.cpp/ggml/include/ggml-sycl.h    include/ggml-sycl.h
+cp -rpv ../whisper.cpp/ggml/include/ggml-vulkan.h  include/ggml-vulkan.h
 
-cp -rpv ../whisper.cpp/LICENSE                        ./LICENSE
-cp -rpv ../whisper.cpp/scripts/gen-authors.sh         ./scripts/gen-authors.sh
+cp -rpv ../whisper.cpp/examples/common.h        examples/common.h
+cp -rpv ../whisper.cpp/examples/common.cpp      examples/common.cpp
+cp -rpv ../whisper.cpp/examples/common-ggml.h   examples/common-ggml.h
+cp -rpv ../whisper.cpp/examples/common-ggml.cpp examples/common-ggml.cpp
+
+cp -rpv ../whisper.cpp/LICENSE                ./LICENSE
+cp -rpv ../whisper.cpp/scripts/gen-authors.sh ./scripts/gen-authors.sh
index 407fadaf8504d0d43717e2589e3af606a91b0755..284bda015a62af1873e78408a4473cc467e5cf17 120000 (symlink)
@@ -1 +1 @@
-../include/ggml/ggml-alloc.h
\ No newline at end of file
+../include/ggml-alloc.h
\ No newline at end of file
index a69e9b54e4992609a4fcb0d31da2220b58374280..fe8a239b1d63197326985874c4f0df2213618768 120000 (symlink)
@@ -1 +1 @@
-../include/ggml/ggml-backend.h
\ No newline at end of file
+../include/ggml-backend.h
\ No newline at end of file
diff --git a/spm-headers/ggml-metal.h b/spm-headers/ggml-metal.h
new file mode 120000 (symlink)
index 0000000..b5e5471
--- /dev/null
@@ -0,0 +1 @@
+../include/ggml-metal.h
\ No newline at end of file
index 245bb981e44e25510e2a327605148ec6cea1bcd6..9de84567aa57187a625db32e7f1175b31f25512f 120000 (symlink)
@@ -1 +1 @@
-../include/ggml/ggml.h
\ No newline at end of file
+../include/ggml.h
\ No newline at end of file
index 0330c9b36e026a2f5ecbb925bae8a9206b905621..ba341d3749050bab42c782777ea87a40e9e7c782 100644 (file)
-if (GGML_ALL_WARNINGS)
-    if (NOT MSVC)
-        add_compile_options(-Wunused -Wextra -Wcast-qual -Wdouble-promotion)
-        add_compile_options("$<$<COMPILE_LANGUAGE:C>:-Wshadow;-Wno-unused-function;-Wmissing-prototypes>")
-    else()
-        # todo : windows
-    endif()
-endif()
-
-# compiler flags
+include(CheckCXXCompilerFlag)
 
-if (NOT MSVC)
-    #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fno-math-errno -ffinite-math-only -funsafe-math-optimizations")
-endif()
+unset(GGML_CDEF_PUBLIC)
 
-message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
+add_compile_definitions(GGML_SCHED_MAX_COPIES=${GGML_SCHED_MAX_COPIES})
 
-if (NOT UNAME_S)
-    execute_process(COMMAND uname -s OUTPUT_VARIABLE UNAME_S)
-endif()
-if (NOT UNAME_P)
-    execute_process(COMMAND uname -p OUTPUT_VARIABLE UNAME_P)
-endif()
-if (NOT UNAME_M)
-    execute_process(COMMAND uname -m OUTPUT_VARIABLE UNAME_M)
+# enable libstdc++ assertions for debug builds
+if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+    add_compile_definitions($<$<CONFIG:Debug>:_GLIBCXX_ASSERTIONS>)
 endif()
-#message(STATUS "UNAME_S: ${UNAME_S}  UNAME_P: ${UNAME_P}  UNAME_M: ${UNAME_M}")
 
-# this version of Apple ld64 is buggy
-execute_process(
-    COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v
-    ERROR_VARIABLE output
-)
-if (output MATCHES "dyld-1015\.7")
-    add_compile_definitions(HAVE_BUGGY_APPLE_LINKER)
-endif()
+if (NOT MSVC)
+    if (GGML_SANITIZE_THREAD)
+        add_compile_options(-fsanitize=thread)
+        link_libraries     (-fsanitize=thread)
+    endif()
 
-# Mac OS + Arm can report x86_64
-# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
-if (UNAME_S MATCHES "Darwin")
-    if (NOT UNAME_P MATCHES "arm")
-        execute_process(COMMAND sysctl -n hw.optional.arm64 OUTPUT_VARIABLE SYSCTL_M)
-       if (SYSCTL_M MATCHES "1")
-            #set(UNAME_P "arm")
-            #set(UNAME_M "arm64")
-           message(WARNING "Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-#1282546789")
-       endif()
+    if (GGML_SANITIZE_ADDRESS)
+        add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
+        link_libraries     (-fsanitize=address)
     endif()
-endif()
 
-if (${CMAKE_SYSTEM_NAME} STREQUAL "Emscripten")
-    message(STATUS "Emscripten detected")
-elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
-    message(STATUS "ARM detected")
-    #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mcpu=apple-m1")
-elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
-    message(STATUS "PPC64 detected")
-    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mpower9-vector")
-else()
-    message(STATUS "x86 detected")
-    #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -mf16c")
-    if (UNAME_S MATCHES "Darwin")
-        execute_process(COMMAND sysctl machdep.cpu.features OUTPUT_VARIABLE AVX1_M)
-        if (AVX1_M MATCHES "AVX1.0")
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
-        endif()
-        execute_process(COMMAND sysctl machdep.cpu.leaf7_features OUTPUT_VARIABLE AVX2_M)
-        if (AVX2_M MATCHES "AVX2")
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
-        endif()
-        if (AVX1_M MATCHES "FMA")
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
-        endif()
-        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
-    elseif (UNAME_S MATCHES "Linux")
-        message(STATUS "Linux detected")
-        execute_process(COMMAND grep "avx " /proc/cpuinfo OUTPUT_VARIABLE AVX1_M)
-        if (AVX1_M MATCHES "avx")
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
-        endif()
-        execute_process(COMMAND grep "avx2 " /proc/cpuinfo OUTPUT_VARIABLE AVX2_M)
-        if (AVX2_M MATCHES "avx2")
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
-        endif()
-        execute_process(COMMAND grep "fma " /proc/cpuinfo OUTPUT_VARIABLE FMA_M)
-        if (FMA_M MATCHES "fma")
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
-        endif()
-        execute_process(COMMAND grep "f16c " /proc/cpuinfo OUTPUT_VARIABLE F16C_M)
-        if (F16C_M MATCHES "f16c")
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
-        endif()
-        execute_process(COMMAND grep "sse3 " /proc/cpuinfo OUTPUT_VARIABLE SSE3_M)
-        if (SSE3_M MATCHES "sse3")
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse3")
-        endif()
-    elseif (UNAME_S MATCHES "Haiku")
-        message(STATUS "Haiku detected")
-        execute_process(COMMAND sysinfo -cpu COMMAND grep "AVX " OUTPUT_VARIABLE AVX1_M)
-        if (AVX1_M MATCHES "avx")
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
-        endif()
-        execute_process(COMMAND sysinfo -cpu COMMAND grep "AVX2 " OUTPUT_VARIABLE AVX2_M)
-        if (AVX2_M MATCHES "avx2")
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
-        endif()
-        execute_process(COMMAND sysinfo -cpu COMMAND grep "FMA " OUTPUT_VARIABLE FMA_M)
-        if (FMA_M MATCHES "fma")
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
-        endif()
-        execute_process(COMMAND sysinfo -cpu COMMAND grep "F16C " OUTPUT_VARIABLE F16C_M)
-        if (F16C_M MATCHES "f16c")
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
-        endif()
-    elseif (MSVC)
-        if (GGML_AVX512)
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX512")
-            # MSVC has no compile-time flags enabling specific
-            # AVX512 extensions, neither it defines the
-            # macros corresponding to the extensions.
-            # Do it manually.
-            if (GGML_AVX512_VBMI)
-                add_compile_definitions(__AVX512VBMI__)
-            endif()
-            if (GGML_AVX512_VNNI)
-                add_compile_definitions(__AVX512VNNI__)
-            endif()
-        elseif (GGML_AVX2)
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
-        elseif (GGML_AVX)
-            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX")
-        endif()
-    else()
-        set(CMAKE_C_FLAGS  "${CMAKE_C_FLAGS} -mfma -mf16c -mavx -mavx2")
+    if (GGML_SANITIZE_UNDEFINED)
+        add_compile_options(-fsanitize=undefined)
+        link_libraries     (-fsanitize=undefined)
     endif()
 endif()
 
-# ggml
-
-set(TARGET ggml)
-
-# on APPLE - include Accelerate framework
-if (APPLE AND NOT GGML_NO_ACCELERATE)
+if (APPLE AND GGML_ACCELERATE)
     find_library(ACCELERATE_FRAMEWORK Accelerate)
     if (ACCELERATE_FRAMEWORK)
         message(STATUS "Accelerate framework found")
 
-        set(GGML_EXTRA_LIBS  ${GGML_EXTRA_LIBS}  ${ACCELERATE_FRAMEWORK})
-        set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64)
+        add_compile_definitions(GGML_USE_ACCELERATE)
+        add_compile_definitions(ACCELERATE_NEW_LAPACK)
+        add_compile_definitions(ACCELERATE_LAPACK_ILP64)
+
+        set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
     else()
         message(WARNING "Accelerate framework not found")
     endif()
 endif()
 
+if (GGML_METAL)
+    find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
+    find_library(METAL_FRAMEWORK    Metal      REQUIRED)
+    find_library(METALKIT_FRAMEWORK MetalKit   REQUIRED)
+
+    message(STATUS "Metal framework found")
+    set(GGML_HEADERS_METAL ../include/ggml-metal.h)
+    set(GGML_SOURCES_METAL ggml-metal.m)
+
+    list(APPEND GGML_CDEF_PUBLIC GGML_USE_METAL)
+    if (GGML_METAL_NDEBUG)
+        add_compile_definitions(GGML_METAL_NDEBUG)
+    endif()
+
+    # copy ggml-common.h and ggml-metal.metal to bin directory
+    configure_file(ggml-common.h    ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h    COPYONLY)
+    configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
+
+    if (GGML_METAL_EMBED_LIBRARY)
+        enable_language(ASM)
+
+        add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
+
+        set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/ggml-common.h")
+        set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
+
+        file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated")
+
+        # merge ggml-common.h and ggml-metal.metal into a single file
+        set(METALLIB_EMBED_ASM    "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
+        set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
+
+        add_custom_command(
+            OUTPUT ${METALLIB_EMBED_ASM}
+            COMMAND echo "Embedding Metal library"
+            COMMAND sed -e '/\#include \"ggml-common.h\"/r ${METALLIB_COMMON}' -e '/\#include \"ggml-common.h\"/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED}
+            COMMAND echo ".section __DATA,__ggml_metallib"          >  ${METALLIB_EMBED_ASM}
+            COMMAND echo ".globl _ggml_metallib_start"              >> ${METALLIB_EMBED_ASM}
+            COMMAND echo "_ggml_metallib_start:"                    >> ${METALLIB_EMBED_ASM}
+            COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM}
+            COMMAND echo ".globl _ggml_metallib_end"                >> ${METALLIB_EMBED_ASM}
+            COMMAND echo "_ggml_metallib_end:"                      >> ${METALLIB_EMBED_ASM}
+            DEPENDS ggml-metal.metal ggml-common.h
+            COMMENT "Generate assembly for embedded Metal library"
+        )
+
+        set(GGML_SOURCES_METAL ${GGML_SOURCES_METAL} ${METALLIB_EMBED_ASM})
+    else()
+        if (GGML_METAL_SHADER_DEBUG)
+            # custom command to do the following:
+            #   xcrun -sdk macosx metal    -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
+            #   xcrun -sdk macosx metallib                   ggml-metal.air   -o default.metallib
+            #
+            # note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works
+            #       disabling fast math is needed in order to pass tests/test-backend-ops
+            # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
+            # note: unfortunately, we have to call it default.metallib instead of ggml.metallib
+            #       ref: https://github.com/ggerganov/whisper.cpp/issues/1720
+            set(XC_FLAGS -fno-fast-math -fno-inline -g)
+        else()
+            set(XC_FLAGS -O3)
+        endif()
+
+        # Append macOS metal versioning flags
+        if (GGML_METAL_MACOSX_VERSION_MIN)
+            message(STATUS "Adding  -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation")
+            list   (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN})
+        endif()
+
+        if (GGML_METAL_STD)
+            message(STATUS "Adding  -std=${GGML_METAL_STD} flag to metal compilation")
+            list   (APPEND XC_FLAGS -std=${GGML_METAL_STD})
+        endif()
+
+        add_custom_command(
+            OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
+            COMMAND xcrun -sdk macosx metal    ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air
+            COMMAND xcrun -sdk macosx metallib                ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air   -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
+            COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air
+            COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
+            COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
+            DEPENDS ggml-metal.metal ggml-common.h
+            COMMENT "Compiling Metal kernels"
+            )
+
+        add_custom_target(
+            ggml-metal ALL
+            DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
+            )
+    endif() # GGML_METAL_EMBED_LIBRARY
+
+    set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS}
+        ${FOUNDATION_LIBRARY}
+        ${METAL_FRAMEWORK}
+        ${METALKIT_FRAMEWORK}
+        )
+endif()
+
+if (GGML_OPENMP)
+    find_package(OpenMP)
+    if (OpenMP_FOUND)
+        message(STATUS "OpenMP found")
+
+        add_compile_definitions(GGML_USE_OPENMP)
+
+        set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
+    else()
+        message(WARNING "OpenMP not found")
+    endif()
+endif()
+
 if (GGML_BLAS)
     if (GGML_STATIC)
         set(BLA_STATIC ON)
@@ -219,18 +219,17 @@ if (GGML_BLAS)
 
         add_compile_options(${BLAS_LINKER_FLAGS})
 
-        add_compile_definitions(GGML_USE_BLAS)
+        list(APPEND GGML_CDEF_PUBLIC GGML_USE_BLAS)
 
         if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel"))
             add_compile_definitions(GGML_BLAS_USE_MKL)
         endif()
 
-        set(GGML_HEADERS_BLAS ggml-blas.h)
+        set(GGML_HEADERS_BLAS ../include/ggml-blas.h)
         set(GGML_SOURCES_BLAS ggml-blas.cpp)
 
         set(GGML_EXTRA_LIBS     ${GGML_EXTRA_LIBS}     ${BLAS_LIBRARIES})
         set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS})
-        set(GGML_EXTRA_FLAGS    ${GGML_EXTRA_FLAGS} -DGGML_USE_BLAS)
     else()
         message(WARNING "BLAS not found, please refer to "
         "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
@@ -238,72 +237,101 @@ if (GGML_BLAS)
     endif()
 endif()
 
-if (GGML_CUBLAS)
-    message(WARNING "GGML_CUBLAS is deprecated and will be removed in the future.\nUse GGML_CUDA instead")
-    set(GGML_CUDA ON)
+if (GGML_LLAMAFILE)
+    message(STATUS "Using ggml SGEMM")
+
+    add_compile_definitions(GGML_USE_LLAMAFILE)
+
+    set(GGML_HEADERS_LLAMAFILE sgemm.h)
+    set(GGML_SOURCES_LLAMAFILE sgemm.cpp)
 endif()
 
 if (GGML_CUDA)
-    cmake_minimum_required(VERSION 3.17)
+    cmake_minimum_required(VERSION 3.18)  # for CMAKE_CUDA_ARCHITECTURES
 
     find_package(CUDAToolkit)
+
     if (CUDAToolkit_FOUND)
         message(STATUS "CUDA found")
 
         if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
             # 52 == lowest CUDA 12 standard
-            # 60 == f16 CUDA intrinsics
+            # 60 == FP16 CUDA intrinsics
             # 61 == integer CUDA intrinsics
             # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
             if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
-                set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
+                set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
             else()
-                set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
+                set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75")
                 #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work
             endif()
         endif()
-
         message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
 
         enable_language(CUDA)
 
-        file(GLOB   GGML_SOURCES_CUDA "ggml-cuda/*.cu")
-        list(APPEND GGML_SOURCES_CUDA  ggml-cuda.h)
-        list(APPEND GGML_SOURCES_CUDA  ggml-cuda.cu)
+        file(GLOB   GGML_HEADERS_CUDA "ggml-cuda/*.cuh")
+        list(APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h")
 
-        file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
+        file(GLOB   GGML_SOURCES_CUDA "ggml-cuda/*.cu")
+        list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu")
+        file(GLOB   SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
         list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu")
+        file(GLOB   SRCS "ggml-cuda/template-instances/mmq*.cu")
         list(APPEND GGML_SOURCES_CUDA ${SRCS})
 
         if (GGML_CUDA_FA_ALL_QUANTS)
-            file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*.cu")
+            file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*.cu")
             list(APPEND GGML_SOURCES_CUDA ${SRCS})
             add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
         else()
-            file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
+            file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
             list(APPEND GGML_SOURCES_CUDA ${SRCS})
-            file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
+            file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
             list(APPEND GGML_SOURCES_CUDA ${SRCS})
-            file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
+            file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
             list(APPEND GGML_SOURCES_CUDA ${SRCS})
         endif()
 
-        set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_CUDA)
+        list(APPEND GGML_CDEF_PUBLIC GGML_USE_CUDA)
+
+        add_compile_definitions(GGML_CUDA_USE_GRAPHS)
+        add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
+        add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
+        add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
+        add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
 
         if (GGML_CUDA_FORCE_DMMV)
             add_compile_definitions(GGML_CUDA_FORCE_DMMV)
         endif()
+
         if (GGML_CUDA_FORCE_MMQ)
             add_compile_definitions(GGML_CUDA_FORCE_MMQ)
         endif()
 
-        # required for dynamic parallelism
-        # set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
+        if (GGML_CUDA_FORCE_CUBLAS)
+            add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
+        endif()
+
+        if (GGML_CUDA_NO_VMM)
+            add_compile_definitions(GGML_CUDA_NO_VMM)
+        endif()
+
+        if (DEFINED GGML_CUDA_DMMV_Y)
+            add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility
+        endif()
+
+        if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
+            add_compile_definitions(GGML_CUDA_F16)
+        endif()
+
+        if (GGML_CUDA_NO_PEER_COPY)
+            add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
+        endif()
 
         if (GGML_STATIC)
             if (WIN32)
-                # As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
+                # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
                 set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
             else ()
                 set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
@@ -312,17 +340,16 @@ if (GGML_CUDA)
             set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
         endif()
 
-        set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver)
-
-        if (CMAKE_BUILD_TYPE MATCHES Debug)
-            set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo")
+        if (GGML_CUDA_NO_VMM)
+            # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
+        else()
+            set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
         endif()
     else()
         message(WARNING "CUDA not found")
     endif()
 endif()
 
-# TODO: do not build separate ggml-rocm target (see CUDA build above, or llama.cpp for reference)
 if (GGML_HIPBLAS)
     if (NOT EXISTS $ENV{ROCM_PATH})
         if (NOT EXISTS /opt/rocm)
@@ -333,18 +360,19 @@ if (GGML_HIPBLAS)
     else()
         set(ROCM_PATH $ENV{ROCM_PATH})
     endif()
-    list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
+
+    list(APPEND CMAKE_PREFIX_PATH  ${ROCM_PATH})
     list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake")
 
     # CMake on Windows doesn't support the HIP language yet
-    if(WIN32)
+    if (WIN32)
         set(CXX_IS_HIPCC TRUE)
     else()
         string(REGEX MATCH "hipcc(\.bat)?$" CXX_IS_HIPCC "${CMAKE_CXX_COMPILER}")
     endif()
 
     if (CXX_IS_HIPCC)
-        if(LINUX)
+        if (LINUX)
             if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
                 message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
             endif()
@@ -354,7 +382,7 @@ if (GGML_HIPBLAS)
         endif()
     else()
         # Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
-        if(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
+        if (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
             set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
         endif()
         cmake_minimum_required(VERSION 3.21)
@@ -367,33 +395,39 @@ if (GGML_HIPBLAS)
 
     message(STATUS "HIP and hipBLAS found")
 
-    add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUDA)
+    file(GLOB   GGML_HEADERS_ROCM "ggml-cuda/*.cuh")
+    list(APPEND GGML_HEADERS_ROCM "../include/ggml-cuda.h")
 
-    set(GGML_HEADERS_ROCM ggml-cuda.h)
-
-    file(GLOB GGML_SOURCES_ROCM "ggml-cuda/*.cu")
+    file(GLOB   GGML_SOURCES_ROCM "ggml-cuda/*.cu")
     list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu")
+    file(GLOB   SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
     list(APPEND GGML_SOURCES_ROCM ${SRCS})
-
-    file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
-    list(APPEND GGML_SOURCES_ROCM ${SRCS})
-    file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu")
+    file(GLOB   SRCS "ggml-cuda/template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_ROCM ${SRCS})
 
     if (GGML_CUDA_FA_ALL_QUANTS)
-        file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*.cu")
+        file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*.cu")
         list(APPEND GGML_SOURCES_ROCM ${SRCS})
         add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
     else()
-        file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
+        file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
         list(APPEND GGML_SOURCES_ROCM ${SRCS})
-        file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
+        file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
         list(APPEND GGML_SOURCES_ROCM ${SRCS})
-        file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
+        file(GLOB   SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
         list(APPEND GGML_SOURCES_ROCM ${SRCS})
     endif()
 
-    add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUDA)
+    list(APPEND GGML_CDEF_PUBLIC GGML_USE_CUDA)
+
+    add_compile_definitions(GGML_USE_HIPBLAS)
+    add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
+    add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
+    add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
+
+    if (GGML_HIP_UMA)
+        add_compile_definitions(GGML_HIP_UMA)
+    endif()
 
     if (GGML_CUDA_FORCE_DMMV)
         add_compile_definitions(GGML_CUDA_FORCE_DMMV)
@@ -403,253 +437,735 @@ if (GGML_HIPBLAS)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
 
-    add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
-    add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
-    add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
-
-    add_library(ggml-rocm OBJECT ${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM})
+    if (GGML_CUDA_NO_PEER_COPY)
+        add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
+    endif()
 
     if (CXX_IS_HIPCC)
         set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
-        target_link_libraries(ggml-rocm PRIVATE hip::device)
+        set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} hip::device)
     else()
         set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP)
     endif()
 
-    target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
-    target_include_directories(ggml-rocm PRIVATE . ../include ../include/ggml)
-
-    if (BUILD_SHARED_LIBS)
-        set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON)
-    endif()
-
     if (GGML_STATIC)
         message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
     endif()
-    set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ggml-rocm)
+
+    set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} PUBLIC hip::host roc::rocblas roc::hipblas)
 endif()
 
-if (GGML_METAL)
-    find_library(FOUNDATION_LIBRARY         Foundation              REQUIRED)
-    find_library(METAL_FRAMEWORK            Metal                   REQUIRED)
-    find_library(METALKIT_FRAMEWORK         MetalKit                REQUIRED)
-    find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
+if (GGML_SYCL)
+    if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA)$")
+        message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL or NVIDIA")
+    endif()
 
-    set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
+    if ( NOT DEFINED ENV{ONEAPI_ROOT})
+        message(FATAL_ERROR "Not detect ENV {ONEAPI_ROOT}, please install oneAPI & source it, like: source /opt/intel/oneapi/setvars.sh")
+    endif()
+    #todo: AOT
 
-    set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_METAL)
+    find_package(IntelSYCL REQUIRED)
+    find_package(MKL REQUIRED)
 
-    #add_compile_definitions(GGML_METAL_NDEBUG)
+    message(STATUS "SYCL found")
 
-    # get full path to the file
-    #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/")
+    list(APPEND GGML_CDEF_PUBLIC GGML_USE_SYCL)
 
-    # copy ggml-common.h and ggml-metal.metal to bin directory
-    configure_file(ggml-common.h    ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h    COPYONLY)
-    configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
+    if (GGML_SYCL_F16)
+        add_compile_definitions(GGML_SYCL_F16)
+    endif()
 
-    if (GGML_METAL_EMBED_LIBRARY)
-        enable_language(ASM)
-        add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
+    if (GGML_CUDA_FORCE_MMQ)
+        add_compile_definitions(GGML_SYCL_FORCE_MMQ)
+    endif()
 
-        set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/ggml-common.h")
-        set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
+    add_compile_options(-I./) #include DPCT
 
-        file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
+    if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
+    endif()
 
-        # merge ggml-common.h and ggml-metal.metal into a single file
-        set(METALLIB_EMBED_ASM    "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
-        set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
+    file(GLOB   GGML_HEADERS_SYCL "ggml-sycl/*.hpp")
+    list(APPEND GGML_HEADERS_SYCL "../include/ggml-sycl.h")
 
-        add_custom_command(
-            OUTPUT ${METALLIB_EMBED_ASM}
-            COMMAND echo "Embedding Metal library"
-            COMMAND sed -e '/\#include \"ggml-common.h\"/r ${METALLIB_COMMON}' -e '/\#include \"ggml-common.h\"/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED}
-            COMMAND echo ".section __DATA,__ggml_metallib"          >  ${METALLIB_EMBED_ASM}
-            COMMAND echo ".globl _ggml_metallib_start"              >> ${METALLIB_EMBED_ASM}
-            COMMAND echo "_ggml_metallib_start:"                    >> ${METALLIB_EMBED_ASM}
-            COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM}
-            COMMAND echo ".globl _ggml_metallib_end"                >> ${METALLIB_EMBED_ASM}
-            COMMAND echo "_ggml_metallib_end:"                      >> ${METALLIB_EMBED_ASM}
-            DEPENDS ggml-metal.metal ggml-common.h
-            COMMENT "Generate assembly for embedded Metal library"
-        )
+    file(GLOB   GGML_SOURCES_SYCL "ggml-sycl/*.cpp")
+    list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp")
 
-    set(GGML_SOURCES_METAL ${GGML_SOURCES_METAL} ${METALLIB_EMBED_ASM})
+    if (WIN32)
+        set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
     else()
-        if (GGML_METAL_SHADER_DEBUG)
-            # custom command to do the following:
-            #   xcrun -sdk macosx metal    -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
-            #   xcrun -sdk macosx metallib                   ggml-metal.air   -o default.metallib
-            #
-            # note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works
-            #       disabling fast math is needed in order to pass tests/test-backend-ops
-            # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
-            # note: unfortunately, we have to call it default.metallib instead of ggml.metallib
-            #       ref: https://github.com/ggerganov/whisper.cpp/issues/1720
-            set(XC_FLAGS -fno-fast-math -fno-inline -g)
-        else()
-            set(XC_FLAGS -O3)
-        endif()
+        add_compile_options(-I/${SYCL_INCLUDE_DIR})
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib")
 
-        # Append macOS metal versioning flags
-        if (GGML_METAL_MACOSX_VERSION_MIN)
-            message(STATUS "Adding -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation")
-            list(APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN})
-        endif()
-        if (GGML_METAL_STD)
-            message(STATUS "Adding -std=${GGML_METAL_STD} flag to metal compilation")
-            list(APPEND XC_FLAGS -std=${GGML_METAL_STD})
+        if (GGML_SYCL_TARGET STREQUAL "INTEL")
+            set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
+        elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
+            set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl)
         endif()
-
-        add_custom_command(
-            OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
-            COMMAND xcrun -sdk macosx metal    ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air
-            COMMAND xcrun -sdk macosx metallib                ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air   -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
-            COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air
-            COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
-            COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
-            DEPENDS ggml-metal.metal ggml-common.h
-            COMMENT "Compiling Metal kernels"
-            )
-
-        add_custom_target(
-            ggml-metal ALL
-            DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
-            )
-    endif() # GGML_METAL_EMBED_LIBRARY
-
-
-    set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS}
-        ${FOUNDATION_LIBRARY}
-        ${METAL_FRAMEWORK}
-        ${METALKIT_FRAMEWORK}
-        ${METALPERFORMANCE_FRAMEWORK}
-        )
+    endif()
 endif()
 
 if (GGML_RPC)
-    add_compile_definitions(GGML_USE_RPC)
+    message(STATUS "RPC found")
+
+    list(APPEND GGML_CDEF_PUBLIC GGML_USE_RPC)
 
     if (WIN32)
         set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ws2_32)
     endif()
 
+    set(GGML_HEADERS_RPC ../include/ggml-rpc.h)
     set(GGML_SOURCES_RPC ggml-rpc.cpp)
 endif()
 
 if (GGML_VULKAN)
     find_package(Vulkan)
+
     if (Vulkan_FOUND)
         message(STATUS "Vulkan found")
 
-        set(GGML_VULKAN_SOURCES ggml-vulkan.cpp ggml-vulkan.h)
+        set(GGML_HEADERS_VULKAN ../include/ggml-vulkan.h)
+        set(GGML_SOURCES_VULKAN ggml-vulkan.cpp)
 
-        add_library(ggml-vulkan OBJECT ggml-vulkan.cpp ggml-vulkan.h)
-        if (BUILD_SHARED_LIBS)
-            set_target_properties(ggml-vulkan PROPERTIES POSITION_INDEPENDENT_CODE ON)
+        list(APPEND GGML_CDEF_PUBLIC GGML_USE_VULKAN)
+
+        # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build
+        # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector
+        if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
+            add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0)
+        endif()
+
+        if (GGML_VULKAN_CHECK_RESULTS)
+            add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
         endif()
-        target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
-        target_include_directories(ggml-vulkan PRIVATE . ../include ../include/ggml)
 
-        add_compile_definitions(GGML_USE_VULKAN)
+        if (GGML_VULKAN_DEBUG)
+            add_compile_definitions(GGML_VULKAN_DEBUG)
+        endif()
 
-        set(GGML_EXTRA_LIBS ${GGMl_EXTRA_LIBS} ggml-vulkan)
+        if (GGML_VULKAN_MEMORY_DEBUG)
+            add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG)
+        endif()
+
+        if (GGML_VULKAN_VALIDATE)
+            add_compile_definitions(GGML_VULKAN_VALIDATE)
+        endif()
+
+        if (GGML_VULKAN_RUN_TESTS)
+            add_compile_definitions(GGML_VULKAN_RUN_TESTS)
+        endif()
+
+        set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} Vulkan::Vulkan)
     else()
         message(WARNING "Vulkan not found")
     endif()
 endif()
 
-if (GGML_PERF)
-    set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_PERF)
+if (GGML_KOMPUTE)
+    add_compile_definitions(VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1)
+
+    find_package(Vulkan COMPONENTS glslc REQUIRED)
+    find_program(glslc_executable NAMES glslc HINTS Vulkan::glslc)
+
+    if (NOT glslc_executable)
+        message(FATAL_ERROR "glslc not found")
+    endif()
+
+    function(compile_shader)
+        set(options)
+        set(oneValueArgs)
+        set(multiValueArgs SOURCES)
+        cmake_parse_arguments(compile_shader "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+        foreach(source ${compile_shader_SOURCES})
+            get_filename_component(filename ${source} NAME)
+            set(spv_file ${filename}.spv)
+            add_custom_command(
+                OUTPUT ${spv_file}
+                DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${source}
+                ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/common.comp
+                ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp
+                ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp
+                ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp
+                COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source}
+                COMMENT "Compiling ${source} to ${spv_file}"
+                )
+
+            get_filename_component(RAW_FILE_NAME ${spv_file} NAME)
+            set(FILE_NAME "shader${RAW_FILE_NAME}")
+            string(REPLACE ".comp.spv" ".h" HEADER_FILE ${FILE_NAME})
+            string(TOUPPER ${HEADER_FILE} HEADER_FILE_DEFINE)
+            string(REPLACE "." "_" HEADER_FILE_DEFINE "${HEADER_FILE_DEFINE}")
+            set(OUTPUT_HEADER_FILE "${HEADER_FILE}")
+            message(STATUS "${HEADER_FILE} generating ${HEADER_FILE_DEFINE}")
+            if(CMAKE_GENERATOR MATCHES "Visual Studio")
+                add_custom_command(
+                    OUTPUT ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+                    DEPENDS ${spv_file} xxd
+                    COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd"
+                    )
+            else()
+                add_custom_command(
+                    OUTPUT ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_BINARY_DIR}/bin/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
+                    COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+                    DEPENDS ${spv_file} xxd
+                    COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/xxd"
+                    )
+            endif()
+        endforeach()
+    endfunction()
+
+    if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
+        message(STATUS "Kompute found")
+        set(KOMPUTE_OPT_LOG_LEVEL Error CACHE STRING "Kompute log level")
+        add_subdirectory(kompute)
+
+        # Compile our shaders
+        compile_shader(SOURCES
+            kompute-shaders/op_scale.comp
+            kompute-shaders/op_scale_8.comp
+            kompute-shaders/op_add.comp
+            kompute-shaders/op_addrow.comp
+            kompute-shaders/op_mul.comp
+            kompute-shaders/op_silu.comp
+            kompute-shaders/op_relu.comp
+            kompute-shaders/op_gelu.comp
+            kompute-shaders/op_softmax.comp
+            kompute-shaders/op_norm.comp
+            kompute-shaders/op_rmsnorm.comp
+            kompute-shaders/op_diagmask.comp
+            kompute-shaders/op_mul_mat_mat_f32.comp
+            kompute-shaders/op_mul_mat_f16.comp
+            kompute-shaders/op_mul_mat_q8_0.comp
+            kompute-shaders/op_mul_mat_q4_0.comp
+            kompute-shaders/op_mul_mat_q4_1.comp
+            kompute-shaders/op_mul_mat_q6_k.comp
+            kompute-shaders/op_getrows_f32.comp
+            kompute-shaders/op_getrows_f16.comp
+            kompute-shaders/op_getrows_q4_0.comp
+            kompute-shaders/op_getrows_q4_1.comp
+            kompute-shaders/op_getrows_q6_k.comp
+            kompute-shaders/op_rope_f16.comp
+            kompute-shaders/op_rope_f32.comp
+            kompute-shaders/op_cpy_f16_f16.comp
+            kompute-shaders/op_cpy_f16_f32.comp
+            kompute-shaders/op_cpy_f32_f16.comp
+            kompute-shaders/op_cpy_f32_f32.comp
+        )
+
+        # Create a custom target for our generated shaders
+        add_custom_target(generated_shaders DEPENDS
+            shaderop_scale.h
+            shaderop_scale_8.h
+            shaderop_add.h
+            shaderop_addrow.h
+            shaderop_mul.h
+            shaderop_silu.h
+            shaderop_relu.h
+            shaderop_gelu.h
+            shaderop_softmax.h
+            shaderop_norm.h
+            shaderop_rmsnorm.h
+            shaderop_diagmask.h
+            shaderop_mul_mat_mat_f32.h
+            shaderop_mul_mat_f16.h
+            shaderop_mul_mat_q8_0.h
+            shaderop_mul_mat_q4_0.h
+            shaderop_mul_mat_q4_1.h
+            shaderop_mul_mat_q6_k.h
+            shaderop_getrows_f32.h
+            shaderop_getrows_f16.h
+            shaderop_getrows_q4_0.h
+            shaderop_getrows_q4_1.h
+            shaderop_getrows_q6_k.h
+            shaderop_rope_f16.h
+            shaderop_rope_f32.h
+            shaderop_cpy_f16_f16.h
+            shaderop_cpy_f16_f32.h
+            shaderop_cpy_f32_f16.h
+            shaderop_cpy_f32_f32.h
+        )
+
+        # Create a custom command that depends on the generated_shaders
+        add_custom_command(
+            OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
+            COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
+            DEPENDS generated_shaders
+            COMMENT "Ensuring shaders are generated before compiling ggml-kompute.cpp"
+        )
+
+        # Add the stamp to the main sources to ensure dependency tracking
+        set(GGML_SOURCES_KOMPUTE ggml-kompute.cpp           ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp)
+        set(GGML_HEADERS_KOMPUTE ../include/ggml-kompute.h  ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp)
+
+        list(APPEND GGML_CDEF_PUBLIC GGML_USE_KOMPUTE)
+
+        set(GGML_EXTRA_LIBS     ${GGML_EXTRA_LIBS}     kompute)
+        set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${CMAKE_CURRENT_BINARY_DIR})
+    else()
+        message(WARNING "Kompute not found")
+    endif()
 endif()
 
-add_library(${TARGET}
-    ggml.c
-    ggml-alloc.c
-    ggml-backend.c
-    ggml-quants.c
-    ggml-impl.h
-    ggml-backend-impl.h
-    ../include/ggml/ggml.h
-    ../include/ggml/ggml-alloc.h
-    ../include/ggml/ggml-backend.h
-    ${GGML_SOURCES_CUDA}  ${GGML_HEADERS_CUDA}
-    ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}
-    ${GGML_SOURCES_RPC}   ${GGML_HEADERS_RPC}
-    ${GGML_SOURCES_BLAS}  ${GGML_HEADERS_BLAS}
-    )
+if (GGML_CPU_HBM)
+    find_library(memkind memkind REQUIRED)
 
-target_include_directories(${TARGET} PUBLIC
-    .
-    ../include
-    ../include/ggml
-    ${GGML_EXTRA_INCS}
-    )
+    message(STATUS "Using memkind for CPU HBM")
 
-find_library(MATH_LIBRARY m)
-if (MATH_LIBRARY)
-    target_link_libraries(${TARGET} PUBLIC ${MATH_LIBRARY})
+    add_compile_definitions(GGML_USE_CPU_HBM)
+
+    target_link_libraries(ggml PUBLIC memkind)
 endif()
 
-target_link_libraries(${TARGET} PUBLIC ${GGML_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
+function(get_flags CCID CCVER)
+    set(C_FLAGS "")
+    set(CXX_FLAGS "")
 
-if (BUILD_SHARED_LIBS)
-    set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
+    if (CCID MATCHES "Clang")
+        set(C_FLAGS   -Wunreachable-code-break -Wunreachable-code-return)
+        set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi)
 
-    target_link_libraries(${TARGET} PUBLIC
-        ${CMAKE_DL_LIBS}
+        if (
+            (CCID STREQUAL "Clang"      AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR
+            (CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0)
         )
+            list(APPEND C_FLAGS -Wdouble-promotion)
+        endif()
+    elseif (CCID STREQUAL "GNU")
+        set(C_FLAGS   -Wdouble-promotion)
+        set(CXX_FLAGS -Wno-array-bounds)
 
-    target_compile_definitions(${TARGET} PUBLIC
-        GGML_SHARED
-        )
+        if (CCVER VERSION_GREATER_EQUAL 7.1.0)
+            list(APPEND CXX_FLAGS -Wno-format-truncation)
+        endif()
+        if (CCVER VERSION_GREATER_EQUAL 8.1.0)
+            list(APPEND CXX_FLAGS -Wextra-semi)
+        endif()
+    endif()
 
-    target_compile_definitions(${TARGET} PRIVATE
-        GGML_BUILD
-        )
+    set(GF_C_FLAGS   ${C_FLAGS}   PARENT_SCOPE)
+    set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)
+endfunction()
 
-    if (GGML_METAL)
-        set_target_properties(${TARGET} PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
+if (GGML_FATAL_WARNINGS)
+    if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+        list(APPEND C_FLAGS   -Werror)
+        list(APPEND CXX_FLAGS -Werror)
+    elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
+        add_compile_options(/WX)
     endif()
 endif()
 
-target_compile_definitions(${TARGET} PUBLIC
-    ${GGML_EXTRA_FLAGS}
-    )
+if (GGML_ALL_WARNINGS)
+    if (NOT MSVC)
+        list(APPEND WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)
+        list(APPEND C_FLAGS       -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes
+                                  -Werror=implicit-int -Werror=implicit-function-declaration)
+        list(APPEND CXX_FLAGS     -Wmissing-declarations -Wmissing-noreturn)
 
-if (MINGW)
-    target_link_libraries(${TARGET} PUBLIC
-        stdc++
-        )
+        list(APPEND C_FLAGS   ${WARNING_FLAGS})
+        list(APPEND CXX_FLAGS ${WARNING_FLAGS})
+
+        get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION})
+
+        add_compile_options("$<$<COMPILE_LANGUAGE:C>:${C_FLAGS};${GF_C_FLAGS}>"
+                            "$<$<COMPILE_LANGUAGE:CXX>:${CXX_FLAGS};${GF_CXX_FLAGS}>")
+    else()
+        # todo : msvc
+        set(C_FLAGS   "")
+        set(CXX_FLAGS "")
+    endif()
 endif()
 
-if (GGML_SOURCES_CUDA)
-    message(STATUS "GGML CUDA sources found")
-    if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
-        # Only configure gmml CUDA architectures is not globally set
-        if (NOT DEFINED GGML_CUDA_ARCHITECTURES)
-            # Not overriden by user, so set defaults
-            set(GGML_CUDA_ARCHITECTURES 52 61 70)
+set(CUDA_CXX_FLAGS "")
+
+if (GGML_CUDA)
+    set(CUDA_FLAGS -use_fast_math)
+
+    if (GGML_FATAL_WARNINGS)
+        list(APPEND CUDA_FLAGS -Werror all-warnings)
+    endif()
+
+    if (GGML_ALL_WARNINGS AND NOT MSVC)
+        set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c)
+        if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "")
+            list(APPEND NVCC_CMD -ccbin ${CMAKE_CUDA_HOST_COMPILER})
+        endif()
+
+        execute_process(
+            COMMAND ${NVCC_CMD} -Xcompiler --version
+            OUTPUT_VARIABLE CUDA_CCFULLVER
+            ERROR_QUIET
+        )
+
+        if (NOT CUDA_CCFULLVER MATCHES clang)
+            set(CUDA_CCID "GNU")
+            execute_process(
+                COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion"
+                OUTPUT_VARIABLE CUDA_CCVER
+                ERROR_QUIET
+            )
+        else()
+            if (CUDA_CCFULLVER MATCHES Apple)
+                set(CUDA_CCID "AppleClang")
+            else()
+                set(CUDA_CCID "Clang")
+            endif()
+            string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER})
         endif()
-        message(STATUS "GGML Configuring CUDA architectures ${GGML_CUDA_ARCHITECTURES}")
-        set_property(TARGET ggml  PROPERTY CUDA_ARCHITECTURES ${GGML_CUDA_ARCHITECTURES})
+
+        message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
+
+        get_flags(${CUDA_CCID} ${CUDA_CCVER})
+        list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS})  # This is passed to -Xcompiler later
     endif()
-    set_property(TARGET ggml  PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
+
     if (NOT MSVC)
-        target_link_libraries(ggml PUBLIC stdc++)
+        list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
+    endif()
+endif()
+
+if (GGML_LTO)
+    include(CheckIPOSupported)
+    check_ipo_supported(RESULT result OUTPUT output)
+    if (result)
+        set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
+    else()
+        message(WARNING "IPO is not supported: ${output}")
+    endif()
+endif()
+
+if (GGML_CCACHE)
+    find_program(GGML_CCACHE_FOUND ccache)
+
+    if (GGML_CCACHE_FOUND)
+        # TODO: should not be set globally
+        set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
+        set(ENV{CCACHE_SLOPPINESS} time_macros)
+        message(STATUS "ccache found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
+    else()
+        message(STATUS "Warning: ccache not found - consider installing it for faster compilation or disable this warning with GGML_CCACHE=OFF")
+    endif ()
+endif()
+
+# this version of Apple ld64 is buggy
+execute_process(
+    COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v
+    ERROR_VARIABLE output
+    OUTPUT_QUIET
+)
+
+if (output MATCHES "dyld-1015\.7")
+    add_compile_definitions(HAVE_BUGGY_APPLE_LINKER)
+endif()
+
+# architecture specific
+# TODO: probably these flags need to be tweaked on some architectures
+#       feel free to update the Makefile for your architecture and send a pull request or issue
+message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
+if (MSVC)
+    string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR)
+    message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}")
+else ()
+    set(CMAKE_GENERATOR_PLATFORM_LWR "")
+endif ()
+
+if (NOT MSVC)
+    if (GGML_STATIC)
+        add_link_options(-static)
+        if (MINGW)
+            add_link_options(-static-libgcc -static-libstdc++)
+        endif()
+    endif()
+    if (GGML_GPROF)
+        add_compile_options(-pg)
+    endif()
+endif()
+
+set(ARCH_FLAGS "")
+
+if (CMAKE_OSX_ARCHITECTURES      STREQUAL "arm64" OR
+    CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR
+    (NOT CMAKE_OSX_ARCHITECTURES      AND
+     NOT CMAKE_GENERATOR_PLATFORM_LWR AND
+         CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$"))
+
+    message(STATUS "ARM detected")
+
+    if (MSVC)
+        add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead
+        add_compile_definitions(__ARM_NEON)
+        add_compile_definitions(__ARM_FEATURE_FMA)
+
+        set(CMAKE_REQUIRED_FLAGS_PREV ${CMAKE_REQUIRED_FLAGS})
+        string(JOIN " " CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS} "/arch:armv8.2")
+
+        check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_DOTPROD)
+        if (GGML_COMPILER_SUPPORT_DOTPROD)
+            add_compile_definitions(__ARM_FEATURE_DOTPROD)
+        endif ()
+
+        check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmlaq_f32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8)
+
+        if (GGML_COMPILER_SUPPORT_MATMUL_INT8)
+            add_compile_definitions(__ARM_FEATURE_MATMUL_INT8)
+        endif ()
+
+        check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
+        if (GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
+            add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+        endif ()
+
+        set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})
+    else()
+        check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
+        if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
+            list(APPEND ARCH_FLAGS -mfp16-format=ieee)
+        endif()
+        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6")
+            # Raspberry Pi 1, Zero
+            list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access)
+        endif()
+        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7")
+            if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Android")
+                # Android armeabi-v7a
+                list(APPEND ARCH_FLAGS -mfpu=neon-vfpv4 -mno-unaligned-access -funsafe-math-optimizations)
+            else()
+                # Raspberry Pi 2
+                list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations)
+            endif()
+        endif()
+        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8")
+            # Android arm64-v8a
+            # Raspberry Pi 3, 4, Zero 2 (32-bit)
+            list(APPEND ARCH_FLAGS -mno-unaligned-access)
+        endif()
+        if (GGML_SVE)
+            list(APPEND ARCH_FLAGS -march=armv8.6-a+sve)
+        endif()
+    endif()
+elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
+        (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
+         CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
+    message(STATUS "x86 detected")
+    if (MSVC)
+        # instruction set detection for MSVC only
+        if (GGML_NATIVE)
+            # TODO: improve, should not reference files from the parent folder
+            include(../cmake/FindSIMD.cmake)
+        endif ()
+        if (GGML_AVX512)
+            list(APPEND ARCH_FLAGS /arch:AVX512)
+            # MSVC has no compile-time flags enabling specific
+            # AVX512 extensions, neither it defines the
+            # macros corresponding to the extensions.
+            # Do it manually.
+            if (GGML_AVX512_VBMI)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
+            endif()
+            if (GGML_AVX512_VNNI)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
+            endif()
+            if (GGML_AVX512_BF16)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
+            endif()
+        elseif (GGML_AVX2)
+            list(APPEND ARCH_FLAGS /arch:AVX2)
+        elseif (GGML_AVX)
+            list(APPEND ARCH_FLAGS /arch:AVX)
+        endif()
+    else()
+        if (GGML_NATIVE)
+            list(APPEND ARCH_FLAGS -march=native)
+        endif()
+        if (GGML_F16C)
+            list(APPEND ARCH_FLAGS -mf16c)
+        endif()
+        if (GGML_FMA)
+            list(APPEND ARCH_FLAGS -mfma)
+        endif()
+        if (GGML_AVX)
+            list(APPEND ARCH_FLAGS -mavx)
+        endif()
+        if (GGML_AVX2)
+            list(APPEND ARCH_FLAGS -mavx2)
+        endif()
+        if (GGML_AVX512)
+            list(APPEND ARCH_FLAGS -mavx512f)
+            list(APPEND ARCH_FLAGS -mavx512bw)
+        endif()
+        if (GGML_AVX512_VBMI)
+            list(APPEND ARCH_FLAGS -mavx512vbmi)
+        endif()
+        if (GGML_AVX512_VNNI)
+            list(APPEND ARCH_FLAGS -mavx512vnni)
+        endif()
+        if (GGML_AVX512_BF16)
+            list(APPEND ARCH_FLAGS -mavx512bf16)
+        endif()
+    endif()
+elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
+    message(STATUS "PowerPC detected")
+    if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
+        list(APPEND ARCH_FLAGS -mcpu=powerpc64le)
+    else()
+        list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)
+        #TODO: Add  targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
+    endif()
+elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
+    message(STATUS "loongarch64 detected")
+
+    list(APPEND ARCH_FLAGS -march=loongarch64)
+    if (GGML_LASX)
+        list(APPEND ARCH_FLAGS -mlasx)
+    endif()
+    if (GGML_LSX)
+        list(APPEND ARCH_FLAGS -mlsx)
+    endif()
+else()
+    message(STATUS "Unknown architecture")
+endif()
+
+add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
+add_compile_options("$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
+
+if (GGML_CUDA)
+    list(APPEND CUDA_CXX_FLAGS ${ARCH_FLAGS})
+    list(JOIN   CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED)  # pass host compiler flags as a single argument
+
+    if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "")
+        list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})
+    endif()
+
+    add_compile_options("$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>")
+endif()
+
+if (MINGW)
+    # Target Windows 8 for PrefetchVirtualMemory
+    add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
+endif()
+
+#
+# POSIX conformance
+#
+
+# clock_gettime came in POSIX.1b (1993)
+# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
+# posix_memalign came in POSIX.1-2001 / SUSv3
+# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985)
+add_compile_definitions(_XOPEN_SOURCE=600)
+
+# Somehow in OpenBSD whenever POSIX conformance is specified
+# some string functions rely on locale_t availability,
+# which was introduced in POSIX.1-2008, forcing us to go higher
+if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
+    remove_definitions(-D_XOPEN_SOURCE=600)
+    add_compile_definitions(_XOPEN_SOURCE=700)
+endif()
+
+# Data types, macros and functions related to controlling CPU affinity and
+# some memory allocation are available on Linux through GNU extensions in libc
+if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+    add_compile_definitions(_GNU_SOURCE)
+endif()
+
+# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
+# and on macOS its availability depends on enabling Darwin extensions
+# similarly on DragonFly, enabling BSD extensions is necessary
+if (
+    CMAKE_SYSTEM_NAME MATCHES "Darwin" OR
+    CMAKE_SYSTEM_NAME MATCHES "iOS"    OR
+    CMAKE_SYSTEM_NAME MATCHES "tvOS"   OR
+    CMAKE_SYSTEM_NAME MATCHES "DragonFly"
+)
+    add_compile_definitions(_DARWIN_C_SOURCE)
+endif()
+
+# alloca is a non-standard interface that is not visible on BSDs when
+# POSIX conformance is specified, but not all of them provide a clean way
+# to enable it in such cases
+if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD")
+    add_compile_definitions(__BSD_VISIBLE)
+endif()
+if (CMAKE_SYSTEM_NAME MATCHES "NetBSD")
+    add_compile_definitions(_NETBSD_SOURCE)
+endif()
+if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
+    add_compile_definitions(_BSD_SOURCE)
+endif()
+
+if (WIN32)
+    add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
+
+    if (BUILD_SHARED_LIBS)
+        # TODO: should not use this
+        set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
     endif()
 endif()
 
-set (GGML_PUBLIC_HEADERS
-     ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml/ggml.h
-     ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml/ggml-alloc.h
-     ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml/ggml-backend.h)
+#
+# libraries
+#
 
-set_target_properties(${TARGET} PROPERTIES
-                      PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
+# ggml
+
+add_library(ggml
+            ../include/ggml.h
+            ../include/ggml-alloc.h
+            ../include/ggml-backend.h
+            ggml.c
+            ggml-alloc.c
+            ggml-backend.c
+            ggml-quants.c
+            ggml-quants.h
+            ${GGML_SOURCES_CUDA}      ${GGML_HEADERS_CUDA}
+            ${GGML_SOURCES_METAL}     ${GGML_HEADERS_METAL}
+            ${GGML_SOURCES_RPC}       ${GGML_HEADERS_RPC}
+            ${GGML_SOURCES_EXTRA}     ${GGML_HEADERS_EXTRA}
+            ${GGML_SOURCES_SYCL}      ${GGML_HEADERS_SYCL}
+            ${GGML_SOURCES_KOMPUTE}   ${GGML_HEADERS_KOMPUTE}
+            ${GGML_SOURCES_VULKAN}    ${GGML_HEADERS_VULKAN}
+            ${GGML_SOURCES_ROCM}      ${GGML_HEADERS_ROCM}
+            ${GGML_SOURCES_BLAS}      ${GGML_HEADERS_BLAS}
+            ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
+            )
+
+if (EMSCRIPTEN)
+    set_target_properties(ggml PROPERTIES COMPILE_FLAGS "-msimd128")
+endif()
+
+target_compile_definitions(ggml PUBLIC  ${GGML_CDEF_PUBLIC})
+target_include_directories(ggml PUBLIC ../include)
+target_include_directories(ggml PRIVATE . ${GGML_EXTRA_INCLUDES})
+target_compile_features   (ggml PRIVATE c_std_11) # don't bump
+
+target_link_libraries(ggml PRIVATE Threads::Threads ${GGML_EXTRA_LIBS})
 
-install(TARGETS ${TARGET}
-    LIBRARY DESTINATION lib
-    PUBLIC_HEADER DESTINATION include/ggml
-    )
+find_library(MATH_LIBRARY m)
+if (MATH_LIBRARY)
+    target_link_libraries(ggml PRIVATE ${MATH_LIBRARY})
+endif()
+
+if (BUILD_SHARED_LIBS)
+    set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif()
index 174297949504776493bd229918cd21d30bed5699..13c71c310c446d66d99e57167c2588618e8f1026 100644 (file)
@@ -1172,7 +1172,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
             // check if a backend with higher prio wants to offload the op
             if (src_backend_id == sched->n_backends - 1) {
                 for (int b = 0; b < src_backend_id; b++) {
-                    if (ggml_backend_offload_op(sched->backends[b], tensor)) {
+                    if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
                         SET_CAUSE(tensor, "1.off");
                         return b;
                     }
diff --git a/src/ggml-blas.h b/src/ggml-blas.h
deleted file mode 100644 (file)
index f2e37de..0000000
+++ /dev/null
@@ -1,23 +0,0 @@
-#pragma once
-
-#include "ggml.h"
-#include "ggml-backend.h"
-
-
-#ifdef  __cplusplus
-extern "C" {
-#endif
-
-// backend API
-GGML_API GGML_CALL ggml_backend_t ggml_backend_blas_init(void);
-
-GGML_API GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend);
-
-// number of threads used for conversion to float
-// for openblas and blis, this will also set the number of threads used for blas operations
-GGML_API GGML_CALL void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
-
-
-#ifdef  __cplusplus
-}
-#endif
index 593fa4cdaa51431b4c2773a7686c613ffd5dbdaf..0acfda91d3e51b556d5c704abf0e12fe5f8f6ea6 100644 (file)
@@ -152,16 +152,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
     GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
 
     int64_t total_vram = 0;
-#if defined(GGML_CUDA_FORCE_MMQ)
-    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:   yes\n", __func__);
+#ifdef GGML_CUDA_FORCE_MMQ
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:    yes\n", __func__);
 #else
-    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:   no\n", __func__);
-#endif
-#if defined(CUDA_USE_TENSOR_CORES)
-    GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:    no\n", __func__);
+#endif // GGML_CUDA_FORCE_MMQ
+#ifdef GGML_CUDA_FORCE_CUBLAS
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
 #else
-    GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
-#endif
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
+#endif // GGML_CUDA_FORCE_CUBLAS
     GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;
@@ -635,7 +635,7 @@ static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> &
         }
 
         const int cc = ggml_cuda_info().devices[id].cc;
-        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
+        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
     }
     return row_rounding;
 }
@@ -1873,9 +1873,17 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
 
-    int64_t min_compute_capability = INT_MAX;
+    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
+    bool          use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool              use_mul_mat_q =  ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+    bool any_gpus_with_slow_fp16 = false;
 
-    bool any_pascal_with_slow_fp16 = false;
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
@@ -1885,55 +1893,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
                 continue;
             }
 
-            if (min_compute_capability > ggml_cuda_info().devices[id].cc) {
-                min_compute_capability = ggml_cuda_info().devices[id].cc;
-            }
-            if (ggml_cuda_info().devices[id].cc == 610) {
-                any_pascal_with_slow_fp16 = true;
-            }
+            const int cc            = ggml_cuda_info().devices[id].cc;
+            use_mul_mat_vec_q       = use_mul_mat_vec_q       && cc >= MIN_CC_DP4A;
+            use_mul_mat_q           = use_mul_mat_q           && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
         }
     } else {
-        min_compute_capability    = ggml_cuda_info().devices[ctx.device].cc;
-        any_pascal_with_slow_fp16 = ggml_cuda_info().devices[ctx.device].cc == 610;
+        const int cc            = ggml_cuda_info().devices[ctx.device].cc;
+        use_mul_mat_vec_q       = use_mul_mat_vec_q       && cc >= MIN_CC_DP4A;
+        use_mul_mat_q           = use_mul_mat_q           && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
     }
 
-    // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
-
-    bool          use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
-
-    bool              use_mul_mat_q =  ggml_cuda_supports_mmq(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
-
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-
-    const bool fp16_performance_good = min_compute_capability >= CC_RDNA1;
-
-#ifdef CUDA_USE_TENSOR_CORES
-    use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3;
-#endif // CUDA_USE_TENSOR_CORES
-
-#else
-
-    // fp16 performance is good on Volta or newer and on P100 (compute capability 6.0)
-    const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16;
-
-    // mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1
-    use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A;
-    use_mul_mat_q     = use_mul_mat_q     && min_compute_capability >= MIN_CC_DP4A;
-
-#ifdef CUDA_USE_TENSOR_CORES
-    // when tensor cores are available, use them for large batch size
-    // ref: https://github.com/ggerganov/llama.cpp/pull/3776
-    use_mul_mat_q     = use_mul_mat_q     && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
-#endif // CUDA_USE_TENSOR_CORES
-
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-
     // if mmvq is available it's a better choice than dmmv:
 #ifndef GGML_CUDA_FORCE_DMMV
     use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
@@ -1947,14 +1918,15 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
-        // KQ single-batch
+    if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // FP32 precision KQ single-batch for batch size 1 without FlashAttention
         ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
-    } else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
-        // KQV single-batch
+    } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // FP32 precision KQV single-batch for batch size 1 without FlashAttention
         ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || fp16_performance_good) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
-        // KQ + KQV multi-batch
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
+               && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        // KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
@@ -2267,6 +2239,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SQR:
             ggml_cuda_op_sqr(ctx, dst);
             break;
+        case GGML_OP_SQRT:
+            ggml_cuda_op_sqrt(ctx, dst);
+            break;
         case GGML_OP_CLAMP:
             ggml_cuda_op_clamp(ctx, dst);
             break;
@@ -2830,6 +2805,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_RMS_NORM:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
         case GGML_OP_CLAMP:
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
diff --git a/src/ggml-cuda.h b/src/ggml-cuda.h
deleted file mode 100644 (file)
index d7903c6..0000000
+++ /dev/null
@@ -1,44 +0,0 @@
-#pragma once
-
-#include "ggml.h"
-#include "ggml-backend.h"
-
-#ifdef GGML_USE_HIPBLAS
-#define GGML_CUDA_NAME "ROCm"
-#define GGML_CUBLAS_NAME "hipBLAS"
-#else
-#define GGML_CUDA_NAME "CUDA"
-#define GGML_CUBLAS_NAME "cuBLAS"
-#endif
-
-#ifdef  __cplusplus
-extern "C" {
-#endif
-
-#define GGML_CUDA_MAX_DEVICES       16
-
-// backend API
-GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
-
-GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend);
-
-// device buffer
-GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
-
-// split tensor buffer that splits matrices by rows across multiple devices
-GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
-
-// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
-GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
-
-GGML_API GGML_CALL int  ggml_backend_cuda_get_device_count(void);
-GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
-GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
-
-GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
-GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
-
-GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data);
-#ifdef  __cplusplus
-}
-#endif
index de7c2e4349ede23b6365e3abfee4f2b75d9dd1e8..8d00db6c193ff67e4f4eee5de13c8b38b9bab536 100644 (file)
 #define CC_RDNA2      (CC_OFFSET_AMD + 1030)
 #define CC_RDNA3      (CC_OFFSET_AMD + 1100)
 
-// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
-// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
-// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
-// -  7B quantum model: +100-200 MB
-// - 13B quantum model: +200-400 MB
-//
-//#define GGML_CUDA_FORCE_MMQ
-
-// TODO: improve this to be correct for more hardware
-//       for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
-#if !defined(GGML_CUDA_FORCE_MMQ)
-#define CUDA_USE_TENSOR_CORES
-#endif
-
-#define MMVQ_MAX_BATCH_SIZE  8 // max batch size to use MMVQ kernels
-#define  MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available
-
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
 #if defined(_MSC_VER)
@@ -343,15 +326,15 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
 #define INT8_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 
-static bool fast_fp16_available(const int cc) {
+static constexpr bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
 }
 
-static bool fp16_mma_available(const int cc) {
+static constexpr bool fp16_mma_available(const int cc) {
     return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
 }
 
-static bool int8_mma_available(const int cc) {
+static constexpr bool int8_mma_available(const int cc) {
     return cc < CC_OFFSET_AMD && cc >= CC_TURING;
 }
 
@@ -643,19 +626,6 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
     static constexpr int qi = QI3_S;
 };
 
-static int get_mmq_x_max_host(const int cc) {
-#ifdef CUDA_USE_TENSOR_CORES
-    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
-#else
-    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
-#endif // CUDA_USE_TENSOR_CORES
-}
-
-// Round rows to this value for --split-mode row:
-static int get_mmq_y_host(const int cc, const int mmq_x) {
-    return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
-}
-
 //////////////////////
 
 struct ggml_cuda_device_info {
index 37b3b99323b20f276586696647ed6d27ec95ad33..bd7993595467c25e949a7545a24c6f7cddc0688a 100644 (file)
@@ -603,7 +603,7 @@ static void on_no_fattn_vec_case(const int D) {
     if (D == 64) {
         fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
         fprintf(stderr, "By default only f16 KV cache is supported.\n");
-        fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
+        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
         GGML_ASSERT(false);
     } else if (D == 128) {
         fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
@@ -611,7 +611,7 @@ static void on_no_fattn_vec_case(const int D) {
         fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n");
         fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n");
         fprintf(stderr, "  - K == f16,  V == f16,  16.00 BPV\n");
-        fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
+        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
         GGML_ASSERT(false);
     } else {
         fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
index 63e07fbc21291013fac4bab8d188fb8bb6c58a5d..5d87dd8e64c0aa2147ca03dd070e60431fffa8cf 100644 (file)
@@ -20,6 +20,20 @@ struct mma_int_A_I16K4 {
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_A_I16K8 {
@@ -42,6 +56,20 @@ struct mma_int_A_I16K8 {
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_B_J8K4 {
@@ -64,6 +92,20 @@ struct mma_int_B_J8K4 {
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
+            : "+r"(x[0])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_B_J8K8 {
@@ -86,6 +128,20 @@ struct mma_int_B_J8K8 {
         GGML_CUDA_ASSUME(ret <  K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };
 
 struct mma_int_C_I16J8 {
index 1d6b9e6982b6e3cb78fc9ced9fbd9c01d47a7c75..0308beaccbaa3876fd508b194652ca53f15e6d1d 100644 (file)
@@ -30,34 +30,34 @@ void ggml_cuda_op_mul_mat_q(
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            mul_mat_q_case<GGML_TYPE_Q4_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q4_1:
-            mul_mat_q_case<GGML_TYPE_Q4_1>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_0:
-            mul_mat_q_case<GGML_TYPE_Q5_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_1:
-            mul_mat_q_case<GGML_TYPE_Q5_1>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
             break;
         case GGML_TYPE_Q8_0:
-            mul_mat_q_case<GGML_TYPE_Q8_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q2_K:
-            mul_mat_q_case<GGML_TYPE_Q2_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q3_K:
-            mul_mat_q_case<GGML_TYPE_Q3_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q4_K:
-            mul_mat_q_case<GGML_TYPE_Q4_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_K:
-            mul_mat_q_case<GGML_TYPE_Q5_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q6_K:
-            mul_mat_q_case<GGML_TYPE_Q6_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
             break;
         default:
             GGML_ASSERT(false);
@@ -69,7 +69,13 @@ void ggml_cuda_op_mul_mat_q(
     GGML_UNUSED(src1_ddf_i);
 }
 
-bool ggml_cuda_supports_mmq(enum ggml_type type) {
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
+#ifdef GGML_CUDA_FORCE_CUBLAS
+    return false;
+#endif // GGML_CUDA_FORCE_CUBLAS
+
+    bool mmq_supported;
+
     switch (type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
@@ -81,8 +87,32 @@ bool ggml_cuda_supports_mmq(enum ggml_type type) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
-            return true;
+            mmq_supported = true;
+            break;
         default:
-            return false;
+            mmq_supported = false;
+            break;
+    }
+
+    if (!mmq_supported) {
+        return false;
+    }
+
+    if (int8_mma_available(cc)) {
+        return true;
+    }
+
+    if (cc < MIN_CC_DP4A) {
+        return false;
     }
+
+#ifdef GGML_CUDA_FORCE_MMQ
+    return true;
+#endif //GGML_CUDA_FORCE_MMQ
+
+    if (cc < CC_OFFSET_AMD) {
+        return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    }
+
+    return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
index 6d57974fb4e7c45c6997a9a511bfd5d8d6af8f48..31fcbf1397b6bb15475af811299efc8064006dd3 100644 (file)
@@ -7,15 +7,11 @@
 #include <climits>
 #include <cstdint>
 
-#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 
-typedef void (*load_tiles_mmq_t)(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
-typedef void (*vec_dot_mmq_t)(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0);
-typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
+typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
+typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0);
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
 
 struct block_q8_1_mmq {
     half2  ds[4];
@@ -30,87 +26,150 @@ struct tile_x_sizes {
     int sc;
 };
 
-// get_mmq_x_max_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
+static constexpr int get_mmq_x_max_host(const int cc) {
+    return int8_mma_available(cc) ? 128 :
+#ifdef GGML_CUDA_FORCE_MMQ
+        cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128                     : 64;
+#else
+        cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
+#endif // GGML_CUDA_FORCE_MMQ
+}
 
 static constexpr __device__ int get_mmq_x_max_device() {
+#ifdef INT8_MMA_AVAILABLE
+    return 128;
+#else // INT8_MMA_AVAILABLE
+
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    return 64;
-#else
+    return 128;
+#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
 #if __CUDA_ARCH__ >= CC_VOLTA
-#ifdef CUDA_USE_TENSOR_CORES
-    return MMQ_MAX_BATCH_SIZE;
-#else
+#ifdef GGML_CUDA_FORCE_MMQ
+    return MMQ_DP4A_MAX_BATCH_SIZE;
+#else // GGML_CUDA_FORCE_MMQ
     return 128;
-#endif // CUDA_USE_TENSOR_CORES
-#else
+#endif // GGML_CUDA_FORCE_MMQ
+#else // __CUDA_ARCH__ >= CC_VOLTA
+
     return 64;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
+
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // INT8_MMA_AVAILABLE
 }
 
-// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
+static constexpr int get_mmq_y_host(const int cc) {
+    return int8_mma_available(cc) || cc >= CC_VOLTA ? 128 : 64;
+}
 
+static constexpr __device__ int get_mmq_y_device() {
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-static constexpr __device__ int get_mmq_y_device(int mmq_x) {
-    return mmq_x >= 32 ? 128 : 64;
-}
+    return 128;
 #else
 #if __CUDA_ARCH__ >= CC_VOLTA
-static constexpr __device__ int get_mmq_y_device(int mmq_x) {
-    return mmq_x >= 32 ? 128 : 64;
-}
+    return 128;
 #else
-static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
     return 64;
-}
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}
 
-#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
-#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
-#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
-#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
-#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
-#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE       + mmq_y,       0}
-#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
-#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-
-#define GET_TILE_X_SIZES_BODY                           \
-    return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
-        type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 :    \
-        type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 :    \
-        type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 :    \
-        type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 :    \
-        type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K :    \
-        type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K :    \
-        type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K :    \
-        type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K :    \
-        type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K :    \
-        tile_x_sizes{0, 0, 0}
-
-static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
-    GET_TILE_X_SIZES_BODY;
+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
+#define MMQ_DP4A_TXS_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
+#define MMQ_DP4A_TXS_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE       + mmq_y,       0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+
+static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
+    return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
+        type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
+        type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
+        type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
+        type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
+        type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
+        type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
+        type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
+        type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
+        tile_x_sizes{0, 0, 0};
 }
 
-template <int mmq_y>
-static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) {
-    GET_TILE_X_SIZES_BODY;
+#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0               + 4)
+#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1               + 4)
+#define MMQ_MMA_TILE_X_K_Q5_0 (2*WARP_SIZE + WARP_SIZE/QI5_0               + 4)
+#define MMQ_MMA_TILE_X_K_Q5_1 (2*WARP_SIZE + WARP_SIZE/QI5_1               + 4)
+#define MMQ_MMA_TILE_X_K_Q8_0 (1*WARP_SIZE + WARP_SIZE/QI8_0               + 0)
+#define MMQ_MMA_TILE_X_K_Q2_K (1*WARP_SIZE + WARP_SIZE                     + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/QI3_K + WARP_SIZE/4 + 2)
+#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
+
+static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q5_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q5_1 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
+
+static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
+        type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
+        type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
+        type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
+        type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
+        type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
+        type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
+        type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
+        type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
+        0;
 }
 
+#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+#define MMQ_NWARPS 8
+
+static int mmq_get_granularity_host(const int mmq_x, const int cc) {
+    return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
+}
+
+#ifdef INT8_MMA_AVAILABLE
+static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
+    return mmq_x >= 48 ? 16 : 8;
+}
+#else
+static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
+    return 8;
+}
+#endif // INT8_MMA_AVAILABLE
+
 // ------------------------------------------------------------
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_sc);
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI4_0;
     const int kqsx = threadIdx.x % QI4_0;
 
-    float * x_dmf = (float *) x_dm;
-
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + threadIdx.y;
@@ -121,7 +180,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
 
-        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+#else
+        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
@@ -137,17 +200,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 
-        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q4_0       + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-    GGML_UNUSED(x_sc);
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const float * x_df = (const float *) x_dm;
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
@@ -178,76 +245,90 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
-    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
-    const float * x_df = (const float *) x_dm;
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A;
-    float dA[mma_C::ne/2];
+    mma_A A[ntx];
+    float dA[ntx][mma_C::ne/2];
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
 #pragma unroll
-    for (int l = 0; l < mma_A::ne; ++l) {
-        const int i     = i0 + mma_A::get_i(l);
-        const int k     = k0 + mma_A::get_k(l) % QI4_0;
-        const int shift =   4*(mma_A::get_k(l) / QI4_0);
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i     = i0 + n*mma_A::I + mma_A::get_i(l);
+            const int k     = k0              + mma_A::get_k(l) % QI4_0;
+            const int shift =                4*(mma_A::get_k(l) / QI4_0);
+
+            A[n].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808);
+        }
 
-        A.x[l] = __vsubss4((x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
-    }
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-        dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/QI4_0];
+        }
     }
 
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C C;
-        mma_B B;
-        half2 dsB[mma_C::ne/2];
-
 #pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int j =    j0 + mma_B::get_j(l);
-            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+        mma_B  B;
+        float dB[mma_C::ne/2];
+
+        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
 
-            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
-        }
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
 
-            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
+            dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
 
-        C.mma_K8(A, B);
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+            mma_C C;
+            C.mma_K8(A[n], B);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_sc);
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI4_1;
     const int kqsx = threadIdx.x % QI4_1;
@@ -262,7 +343,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
 
-        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#else
+        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
@@ -278,16 +363,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
 
-        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q4_1       + kbxd] = bxi->dm;
+#else
+        x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-    GGML_UNUSED(x_sc);
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
@@ -318,51 +408,53 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
-    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_A_I16K4 mma_A_K4;
     typedef mma_int_B_J8K8  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A;
-    half2 dmA[mma_C::ne/2];
+    mma_A A[ntx];
+    half2 dmA[ntx][mma_C::ne/2];
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
 #pragma unroll
-    for (int l = 0; l < mma_A::ne; ++l) {
-        const int i     = i0 + mma_A::get_i(l);
-        const int k     = k0 + mma_A::get_k(l) % QI4_0;
-        const int shift =   4*(mma_A::get_k(l) / QI4_0);
+    for (int n = 0; n < ntx; ++n) {
+        ((mma_A_K4 *) &A[n])[0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0, MMQ_MMA_TILE_X_K_Q4_1);
+        A[n].x[2]  = (A[n].x[0] >> 4) & 0x0F0F0F0F;
+        A[n].x[3]  = (A[n].x[1] >> 4) & 0x0F0F0F0F;
+        A[n].x[0] &= 0x0F0F0F0F;
+        A[n].x[1] &= 0x0F0F0F0F;
 
-        A.x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
-    }
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-        dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
+            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/QI4_1];
+        }
     }
 
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C C;
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
         mma_B B;
         half2 dsB[mma_C::ne/2];
 
-#pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int j =    j0 + mma_B::get_j(l);
-            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
 
-            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
-        }
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
@@ -370,24 +462,35 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
             dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
         }
 
-        C.mma_K8(A, B);
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            mma_C C;
+            C.mma_K8(A[n], B);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
-            sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+            for (int l = 0; l < mma_C::ne; ++l) {
+                const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_sc);
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI5_0;
     const int kqsx = threadIdx.x % QI5_0;
@@ -412,8 +515,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
         qs0     = __vsubss4(qs0, 0x10101010); // subtract 16
 
-        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
-
         int qs1 = (ql >>  4)   & 0x0F0F0F0F;
         qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
         qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
@@ -421,12 +522,17 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
         qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
 
-        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
-    float * x_dmf = (float *) x_dm;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
@@ -438,19 +544,23 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
 
-        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q5_0       + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-    GGML_UNUSED(x_sc);
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const float * x_dmf = (const float *) x_dm;
-    const int   * y_qs  = (const int   *) y + 4;
-    const float * y_df  = (const float *) y;
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -460,70 +570,57 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
         for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
-            const int index_bx = i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;
-
-            int u[2*VDR_Q5_0_Q8_1_MMQ];
-
-#pragma unroll
-            for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
-                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l)         % WARP_SIZE];
-                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_0) % WARP_SIZE];
-            }
-
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
-                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
+                 x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
-    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
-    const float * x_df = (const float *) x_dm;
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    mma_A A;
-    float dA[mma_C::ne/2];
+    mma_A A[ntx];
+    float dA[ntx][mma_C::ne/2];
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
 #pragma unroll
-    for (int l = 0; l < mma_A::ne; ++l) {
-        const int i     =    i0 + mma_A::get_i(l);
-        const int k     = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
+    for (int n = 0; n < ntx; ++n) {
+        A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_0 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_0);
 
-        A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
-    }
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;
 
-        dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0];
+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + k0/QI5_0];
+        }
     }
 
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C C;
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
         mma_B B;
         float dB[mma_C::ne/2];
 
-#pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int j =    j0 + mma_B::get_j(l);
-            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
 
-            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
-        }
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
@@ -531,23 +628,34 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
             dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
         }
 
-        C.mma_K8(A, B);
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            mma_C C;
+            C.mma_K8(A[n], B);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*C.x[l];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_sc);
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI5_1;
     const int kqsx = threadIdx.x % QI5_1;
@@ -571,15 +679,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
         qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
 
-        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
-
         int qs1 = (ql >>  4) & 0x0F0F0F0F;
         qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
         qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
         qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
         qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
 
-        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
@@ -595,18 +707,23 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
 
-        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q5_1       + kbxd] = bxi->dm;
+#else
+        x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-    GGML_UNUSED(x_sc);
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const int   * y_qs  = (const int   *) y + 4;
-    const half2 * y_ds  = (const half2 *) y;
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -616,69 +733,57 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
         for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
-            const int index_bx = i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1;
-
-            int u[2*VDR_Q5_1_Q8_1_MMQ];
-
-#pragma unroll
-            for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
-                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l)         % WARP_SIZE];
-                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_1) % WARP_SIZE];
-            }
-
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
-                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE],
+                 x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
-    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A;
-    half2 dmA[mma_C::ne/2];
+    mma_A A[ntx];
+    half2 dmA[ntx][mma_C::ne/2];
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
 #pragma unroll
-    for (int l = 0; l < mma_A::ne; ++l) {
-        const int i     =    i0 + mma_A::get_i(l);
-        const int k     = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
+    for (int n = 0; n < ntx; ++n) {
+        A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_1 + QR5_1*k0, MMQ_MMA_TILE_X_K_Q5_1);
 
-        A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
-    }
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + mma_C::get_i(2*l) + n*mma_C::I;
 
-        dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1];
+            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_1 + k0/QI5_1];
+        }
     }
 
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C C;
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
         mma_B B;
         half2 dsB[mma_C::ne/2];
 
-#pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int j =    j0 + mma_B::get_j(l);
-            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
+        B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0) % WARP_SIZE, MMQ_TILE_Y_K);
 
-            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
-        }
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
@@ -686,28 +791,38 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
             dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
         }
 
-        C.mma_K8(A, B);
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            mma_C C;
+            C.mma_K8(A[n], B);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
-            sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+            for (int l = 0; l < mma_C::ne; ++l) {
+                const half2 dmA_dsB = dmA[n][l/2]*dsB[l%2];
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
-    GGML_UNUSED(x_sc);
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_tile + WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI8_0;
     const int kqsx = threadIdx.x % QI8_0;
-    float * x_dmf = (float *) x_dm;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -719,7 +834,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
 
-        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+#else
+        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
@@ -735,19 +854,23 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
 
-        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0         + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
-    GGML_UNUSED(x_sc);
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const float * x_dmf = (const float *) x_dm;
-    const int   * y_qs  = (const int   *) y + 4;
-    const float * y_df  = (const float *) y;
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -758,7 +881,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
             const int i = i0 + threadIdx.x;
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
-                (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
+                (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
                 y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
         }
     }
@@ -766,51 +889,48 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
-    GGML_UNUSED(x_sc);
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
-    const float * x_df = (const float *) x_dm;
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    mma_A A;
-    float dA[mma_C::ne/2];
+    mma_A A[ntx];
+    float dA[ntx][mma_C::ne/2];
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
 #pragma unroll
-    for (int l = 0; l < mma_A::ne; ++l) {
-        const int i = i0 + mma_A::get_i(l);
-        const int k = k0 + mma_A::get_k(l);
+    for (int n = 0; n < ntx; ++n) {
+        A[n].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
 
-        A.x[l] = x_qs[i*(WARP_SIZE + 1) + k];
-    }
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
 
-        dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+        }
     }
 
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C C;
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
         mma_B B;
         float dB[mma_C::ne/2];
 
-#pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int j = j0 + mma_B::get_j(l);
-            const int k = k0 + mma_B::get_k(l);
+        B.load(y_qs + j0*MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
 
-            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
-        }
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
@@ -818,22 +938,34 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
             dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
         }
 
-        C.mma_K8(A, B);
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            mma_C C;
+            C.mma_K8(A[n], B);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI2_K;
     const int kqsx = threadIdx.x % QI2_K;
@@ -862,7 +994,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
                 continue;
             }
 
-            x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
+#else
+            x_qs[i*(WARP_SIZE + 1)       + k] = x_qs_k;
+#endif // INT8_MMA_AVAILABLE
         }
 
         const int sc_m = bxi->scales[kqsx];
@@ -873,15 +1009,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
 #endif // FAST_FP16_AVAILABLE
 
-        x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik;
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + threadIdx.x] = x_dm_ik;
+#else
+        x_dm[i*(WARP_SIZE + 1)       + threadIdx.x] = x_dm_ik;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
@@ -902,61 +1044,63 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K4 mma_A;
     typedef mma_int_B_J8K4  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
-    mma_A   A[2];
-    float  dA[mma_C::ne/2][2];
-    float  mA[mma_C::ne/2][2];
+    mma_A   A[ntx][2];
+    float  dA[ntx][mma_C::ne/2][2];
+    float  mA[ntx][mma_C::ne/2][2];
 
 #pragma unroll
-    for (int l = 0; l < mma_A::ne; ++l) {
-        const int i = i0 + mma_A::get_i(l);
-        const int shift = 2*mma_A::get_k(l);
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i = i0 + n*mma_A::I + mma_A::get_i(l);
+            const int shift = 2*mma_A::get_k(l);
 
-        A[0].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 0] >> shift) & 0x03030303;
-        A[1].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 1] >> shift) & 0x03030303;
-    }
+            A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 0] >> shift) & 0x03030303;
+            A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + 1] >> shift) & 0x03030303;
+        }
 
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
 #pragma unroll
-        for (int kk = 0; kk < 2; ++kk) {
-            const float2 dm = __half22float2(x_dm[i*(WARP_SIZE + 1) + k0 + kk]);
+            for (int kdm = 0; kdm < 2; ++kdm) {
+                const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0 + kdm]);
 
-            dA[l][kk] = dm.x;
-            mA[l][kk] = dm.y;
+                dA[n][l][kdm] = dm.x;
+                mA[n][l][kdm] = dm.y;
+            }
         }
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C Cd[2];
-        mma_C Cm[2];
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
         mma_B B[2];
         float dB[mma_C::ne/2];
 
-#pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int j = j0 + mma_B::get_j(l);
-            const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
+        B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + 0)        % WARP_SIZE, MMQ_TILE_Y_K);
+        B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR2_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);
 
-            B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
-            B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
-        }
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
@@ -964,9 +1108,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
             dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
         }
 
-        Cd[0].mma_K4(A[0], B[0]);
-        Cd[1].mma_K4(A[1], B[1]);
-
+        mma_C Cm[2];
         mma_A A1;
         A1.x[0] = 0x01010101;
         A1.x[1] = 0x01010101;
@@ -974,19 +1116,38 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         Cm[1].mma_K4(A1, B[1]);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/mma_B::J)*mma_C::ne + l] += (Cd[0].x[l]*dA[l/2][0] + Cd[1].x[l]*dA[l/2][1] - Cm[0].x[l]*mA[l/2][0] - Cm[1].x[l]*mA[l/2][1])*dB[l%2];
+    for (int n = 0; n < ntx; ++n) {
+            mma_C Cd[2];
+
+            Cd[0].mma_K4(A[n][0], B[0]);
+            Cd[1].mma_K4(A[n][1], B[1]);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += (
+                    Cd[0].x[l]*dA[n][l/2][0] + Cd[1].x[l]*dA[n][l/2][1] - Cm[0].x[l]*mA[n][l/2][0] - Cm[1].x[l]*mA[n][l/2][1])*dB[l%2];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    int   * x_sc = (int   *) (x_df + WARP_SIZE/QI3_K);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_df + txs.dm);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI3_K;
     const int kqsx = threadIdx.x % QI3_K;
@@ -1018,13 +1179,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
                 continue;
             }
 
-            x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2] = x_qs_k;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + k/2] = x_qs_k;
+#endif // INT8_MMA_AVAILABLE
         }
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
-    float * x_dmf = (float *) x_dm;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
@@ -1036,7 +1200,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd;
 
-        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K       + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
     }
 
 #pragma unroll
@@ -1061,16 +1229,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
 
-        x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc;
+#ifdef INT8_MMA_AVAILABLE
+        x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/4)] = sc;
+#else
+        x_sc[i*(WARP_SIZE/4) + i/4   + threadIdx.x % (WARP_SIZE/4)] = sc;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const float * x_df = (const float *) x_dm;
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_df + txs.dm;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
@@ -1096,69 +1270,72 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K4 mma_A;
     typedef mma_int_B_J8K4  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
-    const float * x_df = (const float *) x_dm;
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+    const int   * x_sc = (const int   *) x_df + WARP_SIZE/QI3_K;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
 
-    mma_A   A[2];
-    int   scA[mma_C::ne/2][2];
-    float  dA[mma_C::ne/2];
+    mma_A   A[ntx][2];
+    int   scA[ntx][mma_C::ne/2][2];
+    float  dA[ntx][mma_C::ne/2];
 
 #pragma unroll
-    for (int l = 0; l < mma_A::ne; ++l) {
-        const int i = i0 + mma_A::get_i(l);
-        const int k = QR3_K*k0 + mma_A::get_k(l);
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int l = 0; l < mma_A::ne; ++l) {
+            const int i = i0 + n*mma_A::I + mma_A::get_i(l);
+            const int k = QR3_K*k0 + mma_A::get_k(l);
 
-        A[0].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + 0]          >> (4*(k%2))) & 0x0F0F0F0F;
-        A[1].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
-        A[0].x[l] = __vsubss4(A[0].x[l], 0x04040404);
-        A[1].x[l] = __vsubss4(A[1].x[l], 0x04040404);
-    }
+            A[n][0].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + 0]          >> (4*(k%2))) & 0x0F0F0F0F;
+            A[n][1].x[l] = (x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
+            A[n][0].x[l] = __vsubss4(A[n][0].x[l], 0x04040404);
+            A[n][1].x[l] = __vsubss4(A[n][1].x[l], 0x04040404);
+        }
 
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-        const int kbx  = k0 / QI3_K;
-        const int ky  = (k0 % QI3_K) * QR3_K;
-        const int8_t * sc = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+            const int kbx  = k0 / QI3_K;
+            const int ky  = (k0 % QI3_K) * QR3_K;
+            const int8_t * sc = ((const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q3_K + kbx*4)) + ky/4;
 
-        scA[l][0] = sc[0];
-        scA[l][1] = sc[1];
-    }
+            scA[n][l][0] = sc[0];
+            scA[n][l][1] = sc[1];
+        }
 
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-        dA[l] = x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + k0/QI3_K];
+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/QI3_K];
+        }
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C C[2];
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
         mma_B B[2];
         float dB[mma_C::ne/2];
 
-#pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int j = j0 + mma_B::get_j(l);
-            const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
+        B[0].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + 0)        % WARP_SIZE, MMQ_TILE_Y_K);
+        B[1].load(y_qs + j0*MMQ_TILE_Y_K + (QR3_K*k0 + mma_B::K) % WARP_SIZE, MMQ_TILE_Y_K);
 
-            B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
-            B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
-        }
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
@@ -1166,23 +1343,37 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
             dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
         }
 
-        C[0].mma_K4(A[0], B[0]);
-        C[1].mma_K4(A[1], B[1]);
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            mma_C C[2];
+            C[0].mma_K4(A[n][0], B[0]);
+            C[1].mma_K4(A[n][1], B[1]);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*scA[l/2][0] + C[1].x[l]*scA[l/2][1])*dA[l/2]*dB[l%2];
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += (C[0].x[l]*scA[n][l/2][0] + C[1].x[l]*scA[n][l/2][1])*dA[n][l/2]*dB[l%2];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
+    int   * x_sc = (int   *) (x_dm + WARP_SIZE/QI4_K);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_dm + txs.dm);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = 0;           // threadIdx.x / QI4_K
     const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
@@ -1197,7 +1388,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
 
-        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#else
+        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_K;  // == 1 if QK_K == 256
@@ -1213,7 +1408,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd;
 
-        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q4_K       + kbxd] = bxi->dm;
+#else
+        x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
     }
 
 #pragma unroll
@@ -1234,15 +1433,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
         scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
 
-        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+#ifdef INT8_MMA_AVAILABLE
+        x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + ksc] = scales8;
+#else
+        x_sc[i*(WARP_SIZE/8) + i/8   + ksc] = scales8;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_dm + txs.dm;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
@@ -1265,71 +1471,79 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
+    const int   * x_sc = (const int   *) x_dm + WARP_SIZE/QI4_K;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+    mma_A   A[ntx][2];
+    int   scA[ntx][mma_C::ne/2][2];
+    int    mA[ntx][mma_C::ne/2][2];
+    half2 dmA[ntx][mma_C::ne/2];
 
-    mma_A   A[2];
-    int   scA[mma_C::ne/2][2];
-    int    mA[mma_C::ne/2][2];
-    half2 dmA[mma_C::ne/2];
 #pragma unroll
-    for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
+    for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_A::ne; ++l) {
-            const int i = i0 + mma_A::get_i(l);
-            const int k = k0 + mma_A::get_k(l);
+        for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 8) {
+            A[n][kvdr/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0, MMQ_MMA_TILE_X_K_Q4_K);
 
-            A[kvdr/4].x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
+#pragma unroll
+            for (int l = 0; l < mma_A::ne; ++l) {
+                A[n][kvdr/4 + 1].x[l]  = (A[n][kvdr/4 + 0].x[l] >> 4) & 0x0F0F0F0F;
+                A[n][kvdr/4 + 0].x[l] &= 0x0F0F0F0F;
+            }
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + mma_C::get_i(2*l);
+        for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
 
-            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
-            const uint8_t *  m = sc + 8;
+                const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + k0/16]) + 2 * ((k0 % 16) / 8);
+                const uint8_t *  m = sc + 8;
 
-            scA[l][kvdr/4] = sc[kvdr/4];
-            mA[l][kvdr/4]  =  m[kvdr/4];
+                scA[n][l][kvdr/4] = sc[kvdr/4];
+                mA[n][l][kvdr/4]  =  m[kvdr/4];
+            }
         }
-    }
 
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
 
-        dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
+            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + k0/QI4_K];
+        }
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        float tmpd[mma_C::ne] = {0.0f};
-        float tmpm[mma_C::ne] = {0.0f};
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+        float tmpd[ntx][mma_C::ne] = {{0.0f}};
+        float tmpm[ntx][mma_C::ne] = {{0.0f}};
 
 #pragma unroll
-        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
-            mma_C   C;
+        for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
             mma_B   B;
             half2 dsB[mma_C::ne/2];
 
-#pragma unroll
-            for (int l = 0; l < mma_B::ne; ++l) {
-                const int j = j0 + mma_B::get_j(l);
-                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+            B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K);
 
-                B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
-            }
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);
@@ -1337,29 +1551,46 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
                 dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
             }
 
-            C.mma_K8(A[kvdr/4], B);
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][kvdr/4], B);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) *  __low2float(dsB[l%2]);
-                tmpm[l] += mA[l/2][kvdr/4]           * __high2float(dsB[l%2]);
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) *  __low2float(dsB[l%2]);
+                    tmpm[n][l] += mA[n][l/2][kvdr/4]           * __high2float(dsB[l%2]);
+                }
             }
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
+        for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
+    int   * x_sc = (int   *) (x_dm + WARP_SIZE/QI5_K);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_dm + txs.dm);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = 0;           // threadIdx.x / QI5_K
     const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
@@ -1386,8 +1617,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
         const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
 
-        x_qs[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
-        x_qs[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq0] = ql0 | qh0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq1] = ql1 | qh1;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kq0] = ql0 | qh0;
+        x_qs[i*(2*WARP_SIZE + 1)     + kq1] = ql1 | qh1;
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_K;  // == 1 if QK_K == 256
@@ -1403,7 +1639,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd;
 
-        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q5_K       + kbxd] = bxi->dm;
+#else
+        x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
     }
 
 #pragma unroll
@@ -1424,17 +1664,24 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
         scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
 
-        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+#ifdef INT8_MMA_AVAILABLE
+        x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + ksc] = scales8;
+#else
+        x_sc[i*(WARP_SIZE/8) + i/8   + ksc] = scales8;
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const int   * y_qs  = (const int   *) y + 4;
-    const half2 * y_ds  = (const half2 *) y;
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_dm + txs.dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -1455,71 +1702,70 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K8 mma_A;
     typedef mma_int_B_J8K8  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
+    const int   * x_sc = (const int   *) x_dm + WARP_SIZE/QI5_K;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    const int i0 = threadIdx.y*mma_A::I;
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+    mma_A   A[ntx][2];
+    int   scA[ntx][mma_C::ne/2][2];
+    int    mA[ntx][mma_C::ne/2][2];
+    half2 dmA[ntx][mma_C::ne/2];
 
-    mma_A   A[2];
-    int   scA[mma_C::ne/2][2];
-    int    mA[mma_C::ne/2][2];
-    half2 dmA[mma_C::ne/2];
 #pragma unroll
-    for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+    for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_A::ne; ++l) {
-            const int i = i0 + mma_A::get_i(l);
-            const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
-
-            A[kvdr/4].x[l] = x_qs[i*(QR5_K*WARP_SIZE + 1) + k];
-        }
+        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
+            A[n][kvdr/4].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + (QR5_K*k0 + QR5_K*kvdr), MMQ_MMA_TILE_X_K_Q5_K);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + mma_C::get_i(2*l);
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
-            const uint8_t *  m = sc + 8;
+                const uint8_t * sc = ((const uint8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + k0/16]) + 2 * ((k0 % 16) / 8);
+                const uint8_t *  m = sc + 8;
 
-            scA[l][kvdr/4] = sc[kvdr/4];
-            mA[l][kvdr/4]  =  m[kvdr/4];
+                scA[n][l][kvdr/4] = sc[kvdr/4];
+                mA[n][l][kvdr/4]  =  m[kvdr/4];
+            }
         }
-    }
 
-#pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+    #pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-        dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
+            dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + k0/QI5_K];
+        }
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        float tmpd[mma_C::ne] = {0.0f};
-        float tmpm[mma_C::ne] = {0.0f};
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+        float tmpd[ntx][mma_C::ne] = {{0.0f}};
+        float tmpm[ntx][mma_C::ne] = {{0.0f}};
 
 #pragma unroll
         for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
-            mma_C   C;
             mma_B   B;
             half2 dsB[mma_C::ne/2];
 
-#pragma unroll
-            for (int l = 0; l < mma_B::ne; ++l) {
-                const int j = j0 + mma_B::get_j(l);
-                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+            B.load(y_qs + j0*MMQ_TILE_Y_K + (2*k0 + 2*kvdr) % WARP_SIZE, MMQ_TILE_Y_K);
 
-                B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
-            }
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);
@@ -1527,29 +1773,46 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
                 dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
             }
 
-            C.mma_K8(A[kvdr/4], B);
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][kvdr/4], B);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) *  __low2float(dsB[l%2]);
-                tmpm[l] += mA[l/2][kvdr/4]           * __high2float(dsB[l%2]);
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    tmpd[n][l] += (C.x[l]*scA[n][l/2][kvdr/4]) *  __low2float(dsB[l%2]);
+                    tmpm[n][l] += mA[n][l/2][kvdr/4]           * __high2float(dsB[l%2]);
+                }
             }
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
+        for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
-    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
-    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    int   * x_sc = (int   *) (x_df + WARP_SIZE/QI6_K);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_df + txs.dm);
+#endif // INT8_MMA_AVAILABLE
 
     const int kbx  = 0;           // threadIdx.x / QI6_K
     const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
@@ -1576,13 +1839,17 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
         const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
 
-        x_qs[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
-        x_qs[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_qs[i*(2*WARP_SIZE + 1)     + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#endif // INT8_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI6_K;  // == 1 if QK_K == 256
     const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
-    float * x_dmf = (float *) x_dm;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
@@ -1594,7 +1861,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
 
-        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q6_K       + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
     }
 
 #pragma unroll
@@ -1607,18 +1878,24 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
 
-        x_sc[i * (WARP_SIZE/8) + i / 8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+#ifdef INT8_MMA_AVAILABLE
+        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+#else
+        x_sc[i*(WARP_SIZE/8) + i/8   + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+#endif // INT8_MMA_AVAILABLE
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const float * x_dmf = (const float *) x_dm;
-    const int   * y_qs  = (const int   *) y + 4;
-    const float * y_df  = (const float *) y;
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_df + txs.dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -1632,80 +1909,77 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
                 &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
-                x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
+                x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
-    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
-    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 #ifdef INT8_MMA_AVAILABLE
 
     typedef mma_int_A_I16K4 mma_A;
     typedef mma_int_B_J8K4  mma_B;
     typedef mma_int_C_I16J8 mma_C;
 
-    const float * x_df = (const float *) x_dm;
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+    const int   * x_sc = (const int   *) x_df + WARP_SIZE/QI6_K;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    const int i0 = threadIdx.y*mma_A::I;
-#ifdef INT8_MMA_AVAILABLE
-    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
-#endif // INT8_MMA_AVAILABLE
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+    mma_A   A[ntx][4];
+    int   scA[ntx][mma_C::ne/2][4];
+    float  dA[ntx][mma_C::ne/2];
 
-    mma_A   A[4];
-    int   scA[mma_C::ne/2][4];
-    float  dA[mma_C::ne/2];
 #pragma unroll
-    for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
+    for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_A::ne; ++l) {
-            const int i = i0 + mma_A::get_i(l);
-            const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
-
-            A[kvdr/2 + 0].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + 0];
-            A[kvdr/2 + 1].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
-        }
+        for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
+            A[n][kvdr/2 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + 0),        MMQ_MMA_TILE_X_K_Q6_K);
+            A[n][kvdr/2 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (QR6_K*k0 + QR6_K*kvdr + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + mma_C::get_i(2*l);
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-            const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
+                const int8_t * sc = ((const int8_t *) &x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/8]);
 
-            scA[l][kvdr/2 + 0] = sc[kvdr/2 + 0];
-            scA[l][kvdr/2 + 1] = sc[kvdr/2 + 1];
+                scA[n][l][kvdr/2 + 0] = sc[kvdr/2 + 0];
+                scA[n][l][kvdr/2 + 1] = sc[kvdr/2 + 1];
+            }
         }
-    }
 
 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + mma_C::get_i(2*l);
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
 
-        dA[l] = x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + k0/QI6_K];
+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K + k0/QI6_K];
+        }
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        float tmp[mma_C::ne] = {0.0f};
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+        float tmp[ntx][mma_C::ne] = {{0.0f}};
 
 #pragma unroll
         for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
-            mma_C C[2];
             mma_B B[2];
             float dB[mma_C::ne/2];
 
-#pragma unroll
-            for (int l = 0; l < mma_B::ne; ++l) {
-                const int j = j0 + mma_B::get_j(l);
-                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
+            const int k0B = (2*k0 + 2*kvdr) % WARP_SIZE;
+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0        + k0B, MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k0B, MMQ_TILE_Y_K);
 
-                B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
-                B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
-            }
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
                 const int j = j0 + mma_C::get_j(l);
@@ -1713,75 +1987,93 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
                 dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
             }
 
-            C[0].mma_K4(A[kvdr/2 + 0], B[0]);
-            C[1].mma_K4(A[kvdr/2 + 1], B[1]);
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C[2];
+                C[0].mma_K4(A[n][kvdr/2 + 0], B[0]);
+                C[1].mma_K4(A[n][kvdr/2 + 1], B[1]);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                tmp[l] += (C[0].x[l]*scA[l/2][kvdr/2 + 0] + C[1].x[l]*scA[l/2][kvdr/2 + 1])*dB[l%2];
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    tmp[n][l] += (C[0].x[l]*scA[n][l/2][kvdr/2 + 0] + C[1].x[l]*scA[n][l/2][kvdr/2 + 1])*dB[l%2];
+                }
             }
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
+        for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+            }
         }
     }
 #else
-    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
 #endif // INT8_MMA_AVAILABLE
 }
 
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
-static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+static __device__ __forceinline__ void mmq_write_back_dp4a(
+    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
+        const int j = j0 + threadIdx.y;
 
-        if (j >= ne1) {
+        if (j > j_max) {
             return;
         }
 
 #pragma unroll
         for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
+            const int i = i0 + threadIdx.x;
 
-            if (need_check && i >= ne0) {
+            if (need_check && i > i_max) {
                 continue;
             }
 
-            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+            dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
         }
     }
 }
 
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
-static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+static __device__ __forceinline__ void mmq_write_back_mma(
+    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
     typedef mma_int_C_I16J8 mma_C;
 
-    const int i0 = threadIdx.y*mma_C::I;
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
 #ifdef INT8_MMA_AVAILABLE
     static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
 #endif // INT8_MMA_AVAILABLE
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
+            for (int l = 0; l < mma_C::ne; ++l) {
+                const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
 
-            if (j >= ne1) {
-                continue;
-            }
+                if (j > j_max) {
+                    continue;
+                }
 
-            const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
+                const int i = i0 + n*mma_C::I + mma_C::get_i(l);
 
-            if (need_check && i >= ne0) {
-                continue;
-            }
+                if (need_check && i > i_max) {
+                    continue;
+                }
 
-            dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
+                dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+            }
         }
     }
 }
@@ -1896,6 +2188,75 @@ static bool mmq_need_sum(const ggml_type type_x) {
     return false;
 }
 
+template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
+static __device__ void mul_mat_q_process_tile(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+    const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
+    const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
+
+    constexpr int              qk         = ggml_cuda_type_traits<type>::qk;
+    constexpr int              qr         = ggml_cuda_type_traits<type>::qr;
+    constexpr int              qi         = ggml_cuda_type_traits<type>::qi;
+    constexpr int              mmq_y      = get_mmq_y_device();
+    constexpr int              vdr        = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
+    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
+
+    extern __shared__ char data_mul_mat_q[];
+    int * tile_y = (int *) data_mul_mat_q;
+    int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
+
+#ifdef INT8_MMA_AVAILABLE
+    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
+    constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
+    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
+
+    constexpr int blocks_per_warp = WARP_SIZE / qi;
+
+    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+
+    const int tile_x_max_i = ne01 - it*mmq_y - 1;
+    const int tile_y_max_j = ne11 - jt*mmq_x - 1;
+
+    const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
+
+    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {
+
+        load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
+
+#pragma unroll
+        for (int kr = 0; kr < qr; ++kr) {
+            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int));
+#pragma unroll
+            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
+                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+
+                tile_y[l] = by0[l];
+            }
+
+            __syncthreads();
+
+// #pragma unroll // unrolling this loop causes too much register pressure
+            for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
+                vec_dot(tile_x, tile_y, sum, k0);
+            }
+
+            __syncthreads();
+        }
+    }
+
+    if (fixup) {
+        write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
+    } else {
+        write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
+    }
+}
+
+
+// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
+
 template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #if defined(RDNA3) || defined(RDNA2)
@@ -1909,75 +2270,153 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 static __global__ void mul_mat_q(
-    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
     const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
 
     // Skip unused template specializations for faster compilation:
-    if (mmq_x > get_mmq_x_max_device()) {
+    if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
         NO_DEVICE_CODE;
         return;
     }
 
-    constexpr int              qk         = ggml_cuda_type_traits<type>::qk;
-    constexpr int              qr         = ggml_cuda_type_traits<type>::qr;
-    constexpr int              qi         = ggml_cuda_type_traits<type>::qi;
-    constexpr int              mmq_y      = get_mmq_y_device(mmq_x);
-    constexpr int              vdr        = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
-    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
+    constexpr int qk    = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi    = ggml_cuda_type_traits<type>::qi;
+    constexpr int mmq_y = get_mmq_y_device();
 
-#ifdef INT8_MMA_AVAILABLE
-    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
-    constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
-#else
-    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
-    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+    // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
+#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+    {
+        constexpr bool fixup = false;
+        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+                blockIdx.x, blockIdx.y, 0, ne00/qk);
+        return;
+    }
+#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
 
-    constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
+    const     int64_t blocks_per_ne00 = ne00 / qk;
+    constexpr int     blocks_per_warp = WARP_SIZE / qi;
 
-    extern __shared__ char data_mul_mat_q[];
-    int   * tile_x_qs = (int   *)  data_mul_mat_q;
-    half2 * tile_x_dm = (half2 *) (tile_x_qs + txs.qs);
-    int   * tile_x_sc = (int   *) (tile_x_dm + txs.dm);
-    int   * tile_y    = (int   *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
+    const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
+
+    // kbc == k block continuous, current index in continuous ijk space.
+    int64_t       kbc      = GGML_PAD((int64_t) blockIdx.x     *blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
+    const int64_t kbc_stop = GGML_PAD((int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
+
+    // kb0 == k index when doing the matrix multiplication for an output tile.
+    int kb0_start = kbc % blocks_per_ne00;
+    int kb0_stop  = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
+    while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
+        const int jt =  kbc /    (blocks_per_ne00*nty);                    // j index of current tile.
+        const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
+
+        constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+             it, jt, kb0_start, kb0_stop);
+
+        kbc += blocks_per_ne00;
+        kbc -= kbc % blocks_per_ne00;
+
+        kb0_start = 0;
+        kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);
+    }
+
+    if (kbc >= kbc_stop) {
+        return;
+    }
 
-    const int blocks_per_row_x = ne00 / qk;
-    const int blocks_per_warp = WARP_SIZE / qi;
+    const int jt =  kbc /    (blocks_per_ne00*nty);
+    const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+    constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks.
+    mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+        (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+            it, jt, kb0_start, kb0_stop);
+}
 
-    const int & ne1 = ne11;
 
-    const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+static __global__ void mul_mat_q_stream_k_fixup(
+    float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
 
-    const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
+    constexpr int     mmq_y           = get_mmq_y_device();
+    constexpr int     qk              = ggml_cuda_type_traits<type>::qk;
+    constexpr int     qi              = ggml_cuda_type_traits<type>::qi;
+    constexpr int     blocks_per_warp = WARP_SIZE / qi;
+    const     int64_t blocks_per_ne00 = ne00 / qk;
 
     float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
 
-    for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
+    const int ntx = (ne11 + mmq_x - 1) / mmq_x;
+    const int nty = (ne01 + mmq_y - 1) / mmq_y;
+
+    bool any_fixup = false;
+
+    const int bidx_start = (blockIdx.y*nty + blockIdx.x)     * block_num_mmq / (gridDim.y*gridDim.x);
+    const int bidx_stop  = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;
+
+    for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
+        const int64_t kbc      = GGML_PAD((int64_t) bidx     *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
+        const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
+
+        // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
+        if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
+            continue;
+        }
+
+        const int jt =  kbc_stop /    (blocks_per_ne00*nty);
+        const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+        // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
+        if (it != blockIdx.x || jt != blockIdx.y) {
+            continue;
+        }
 
-        load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
+        any_fixup = true;
 
 #pragma unroll
-        for (int kr = 0; kr < qr; ++kr) {
-            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int));
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
-                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
 
-                tile_y[l] = by0[l];
+                sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
             }
+        }
+    }
 
-            __syncthreads();
+    if (!any_fixup) {
+        return;
+    }
 
-// #pragma unroll // unrolling this loop causes too much register pressure
-            for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
-                vec_dot(tile_x_qs, tile_x_dm, tile_x_sc, tile_y, sum, k0);
+    dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
+
+    const int i_max = ne01 - blockIdx.x*mmq_y - 1;
+    const int j_max = ne11 - blockIdx.y*mmq_x - 1;
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j > j_max) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            if (need_check && i > i_max) {
+                continue;
             }
 
-            __syncthreads();
+            dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
         }
     }
-
-    write_back(sum, dst, ne0, ne1);
 }
 
 struct mmq_args {
@@ -1987,124 +2426,157 @@ struct mmq_args {
     int64_t ne0;
 };
 
-constexpr int mmq_get_nwarps(int mmq_x) {
-    return mmq_x >= 32 ? 8 : 4;
-}
-
-static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
-    const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
-    const int nwarps = mmq_get_nwarps(mmq_x);
-
-    const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
-    const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
-    return shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
+template<ggml_type type>
+static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
+    const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
+    const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
+    const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+    const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
+    return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
 }
 
-template <ggml_type type, int mmq_x, int nwarps>
-static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
+template <ggml_type type, int mmq_x>
+static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     const int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
-    const int mmq_y = get_mmq_y_host(cc, mmq_x);
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int mmq_y = get_mmq_y_host(cc);
 
-    const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y;
-    const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
 
-    const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);
+    const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
 
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
     static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
     if (!shmem_limit_raised[id]) {
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         shmem_limit_raised[id] = true;
     }
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 
+    const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
+    const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
+    const dim3 block_nums_xy_tiling(nty, ntx, 1);
+
+    const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
+    if (!use_stream_k) {
+        if (args.ne01 % mmq_y == 0) {
+            constexpr bool need_check = false;
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        } else {
+            constexpr bool need_check = true;
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        }
+        return;
+    }
+
+    const dim3 block_nums_mmq(nsm, 1, 1);
+
+    ggml_cuda_pool & pool = ctx.pool();
+    ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
+
     if (args.ne01 % mmq_y == 0) {
-        const bool need_check = false;
-        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
-            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        constexpr bool need_check = false;
+
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
     } else {
-        const bool need_check = true;
-        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
-            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        constexpr bool need_check = true;
+
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
     }
 }
 
 template <ggml_type type>
-void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
+void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     const int id    = ggml_cuda_get_device();
     const int nsm   = ggml_cuda_info().devices[id].nsm;
     const int cc    = ggml_cuda_info().devices[id].cc;
     const int smpbo = ggml_cuda_info().devices[id].smpbo;
 
     const int mmq_x_max = get_mmq_x_max_host(cc);
-    const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
+    const int mmq_y = get_mmq_y_host(cc);
     const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
+    const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
 
     int mmq_x_best  = 0;
-    int nwaves_best = INT_MAX;
+    int nparts_best = INT_MAX;
+
+    for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
+        const int granularity = mmq_get_granularity_host(mmq_x, cc);
+
+        if (mmq_x % granularity != 0 || mmq_get_shmem<type>(mmq_x, mmq_y, cc) > smpbo) {
+            continue;
+        }
 
-    for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) {
-        const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
-        const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
+        const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
+        const int nwaves_xy_tiling = ntiles_x*block_num_y;
+        const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
 
-        if (nwaves < nwaves_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
+        if (nparts < nparts_best) {
             mmq_x_best  = mmq_x;
-            nwaves_best = nwaves;
+            nparts_best = nparts;
         }
     }
 
     switch (mmq_x_best) {
         case   8:
-            launch_mul_mat_q<type,   8, mmq_get_nwarps(  8)>(args, stream);
+            launch_mul_mat_q<type,   8>(ctx, args, stream);
             break;
         case  16:
-            launch_mul_mat_q<type,  16, mmq_get_nwarps( 16)>(args, stream);
+            launch_mul_mat_q<type,  16>(ctx, args, stream);
             break;
         case  24:
-            launch_mul_mat_q<type,  24, mmq_get_nwarps( 24)>(args, stream);
+            launch_mul_mat_q<type,  24>(ctx, args, stream);
             break;
         case  32:
-            launch_mul_mat_q<type,  32, mmq_get_nwarps( 32)>(args, stream);
+            launch_mul_mat_q<type,  32>(ctx, args, stream);
             break;
         case  40:
-            launch_mul_mat_q<type,  40, mmq_get_nwarps( 40)>(args, stream);
+            launch_mul_mat_q<type,  40>(ctx, args, stream);
             break;
         case  48:
-            launch_mul_mat_q<type,  48, mmq_get_nwarps( 48)>(args, stream);
+            launch_mul_mat_q<type,  48>(ctx, args, stream);
             break;
         case  56:
-            launch_mul_mat_q<type,  56, mmq_get_nwarps( 56)>(args, stream);
+            launch_mul_mat_q<type,  56>(ctx, args, stream);
             break;
         case  64:
-            launch_mul_mat_q<type,  64, mmq_get_nwarps( 64)>(args, stream);
+            launch_mul_mat_q<type,  64>(ctx, args, stream);
             break;
         case  72:
-            launch_mul_mat_q<type,  72, mmq_get_nwarps( 72)>(args, stream);
+            launch_mul_mat_q<type,  72>(ctx, args, stream);
             break;
         case  80:
-            launch_mul_mat_q<type,  80, mmq_get_nwarps( 80)>(args, stream);
+            launch_mul_mat_q<type,  80>(ctx, args, stream);
             break;
         case  88:
-            launch_mul_mat_q<type,  88, mmq_get_nwarps( 88)>(args, stream);
+            launch_mul_mat_q<type,  88>(ctx, args, stream);
             break;
         case  96:
-            launch_mul_mat_q<type,  96, mmq_get_nwarps( 96)>(args, stream);
+            launch_mul_mat_q<type,  96>(ctx, args, stream);
             break;
         case 104:
-            launch_mul_mat_q<type, 104, mmq_get_nwarps(104)>(args, stream);
+            launch_mul_mat_q<type, 104>(ctx, args, stream);
             break;
         case 112:
-            launch_mul_mat_q<type, 112, mmq_get_nwarps(112)>(args, stream);
+            launch_mul_mat_q<type, 112>(ctx, args, stream);
             break;
         case 120:
-            launch_mul_mat_q<type, 120, mmq_get_nwarps(120)>(args, stream);
+            launch_mul_mat_q<type, 120>(ctx, args, stream);
             break;
         case 128:
-            launch_mul_mat_q<type, 128, mmq_get_nwarps(128)>(args, stream);
+            launch_mul_mat_q<type, 128>(ctx, args, stream);
             break;
         default:
             fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
@@ -2114,7 +2586,7 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
 }
 
 #define DECL_MMQ_CASE(type)                                                        \
-    template void mul_mat_q_case<type>(const mmq_args & args, cudaStream_t stream) \
+    template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
 
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
@@ -2135,4 +2607,4 @@ void ggml_cuda_op_mul_mat_q(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
 
-bool ggml_cuda_supports_mmq(enum ggml_type type);
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);
index 88c42c4b7a8fbc35bf49f778a9025a64d6f7050a..d9e42fdd6d16c01b2594ffc9e6012aa0e98f37a5 100644 (file)
@@ -1,5 +1,7 @@
 #include "common.cuh"
 
+#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
+
 void ggml_cuda_op_mul_mat_vec_q(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
index a5ff96320f23f4537053b3d80c6d93d79cb2b5ab..f9e208011e2a8f2b1387d2c2173e8056bf0ad068 100644 (file)
@@ -92,6 +92,15 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
+static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sqrtf(x[i]);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -142,6 +151,11 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
     sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
+    sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
@@ -284,3 +298,17 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
index a1d07c04fcd4350a321690dbb6f824c8fdf7a2df..4cfb0479e7169a83a8b433b21b9b381aeb9158d9 100644 (file)
@@ -8,6 +8,7 @@
 #define CUDA_HARDSIGMOID_BLOCK_SIZE 256
 #define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
+#define CUDA_SQRT_BLOCK_SIZE 256
 
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
@@ -28,3 +29,5 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 5e77471f332f443277c835f25fc916dd16fd26ca..1d23361906c34d8191d55a939c1cfea8c72af45d 100644 (file)
@@ -17,7 +17,7 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
-#if defined(_WIN32)
+#if defined(_MSC_VER)
 
 #define m512bh(p) p
 #define m512i(p) p
diff --git a/src/ggml-kompute.h b/src/ggml-kompute.h
deleted file mode 100644 (file)
index 1714654..0000000
+++ /dev/null
@@ -1,46 +0,0 @@
-#pragma once
-
-#include "ggml.h"
-#include "ggml-backend.h"
-
-#include <stdbool.h>
-#include <stddef.h>
-#include <stdint.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-struct ggml_vk_device {
-    int index;
-    int type; // same as VkPhysicalDeviceType
-    size_t heapSize;
-    const char * name;
-    const char * vendor;
-    int subgroupSize;
-    uint64_t bufferAlignment;
-    uint64_t maxAlloc;
-};
-
-struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
-bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
-bool ggml_vk_has_vulkan(void);
-bool ggml_vk_has_device(void);
-struct ggml_vk_device ggml_vk_current_device(void);
-
-//
-// backend API
-//
-
-// forward declaration
-typedef struct ggml_backend * ggml_backend_t;
-
-GGML_API ggml_backend_t ggml_backend_kompute_init(int device);
-
-GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);
-
-GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
-
-#ifdef __cplusplus
-}
-#endif
diff --git a/src/ggml-metal.h b/src/ggml-metal.h
deleted file mode 100644 (file)
index e7543ae..0000000
+++ /dev/null
@@ -1,66 +0,0 @@
-// An interface allowing to compute ggml_cgraph with Metal
-//
-// This is a fully functional interface that extends ggml with GPU support for Apple devices.
-// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)
-//
-// How it works?
-//
-// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this
-// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you
-// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)
-//
-// You only need to make sure that all memory buffers that you used during the graph creation
-// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is
-// used during the graph evaluation to determine the arguments of the compute kernels.
-//
-// Synchronization between device and host memory (for example for input and output tensors)
-// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.
-//
-
-#pragma once
-
-#include "ggml.h"
-#include "ggml-backend.h"
-
-#include <stddef.h>
-#include <stdbool.h>
-
-// max memory buffers that can be mapped to the device
-#define GGML_METAL_MAX_BUFFERS 64
-
-struct ggml_tensor;
-struct ggml_cgraph;
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-//
-// backend API
-// user-code should use only these functions
-//
-
-GGML_API void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data);
-
-GGML_API ggml_backend_t ggml_backend_metal_init(void);
-
-GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
-
-GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
-
-GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
-
-GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
-
-// helper to check if the device supports a specific family
-// ideally, the user code should be doing these checks
-// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
-GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
-
-// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
-GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
-
-#ifdef __cplusplus
-}
-#endif
-
index f894274cacc93e893614efc700c00c9274e6f326..79902c9a80616a9bb62af3f368284fb9d4981c5f 100644 (file)
@@ -735,6 +735,12 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
 }
 
 static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
+    for (size_t i = 0, n = 3; i < n; ++i) {
+        if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
+            return false;
+        }
+    }
+
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
index ed4a9ad254cf58c5c13b1fd7ad5237f8d1ba96bf..0eb52e485089f98d2f841bd1578cff2e9cfcea20 100644 (file)
@@ -8814,7 +8814,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 #endif
 }
 
-#if defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx)
+#if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx)
 static const int8_t keven_signs_q2xs[1024] = {
      1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,
      1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,
@@ -8947,6 +8947,61 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
 
     *s = 0.125f * hsum_float_8(accumf);
 
+#elif defined(__AVX__)
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[4];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+            const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+            const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]);
+            const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
+            const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]);
+            const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
+            const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);
+            const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]);
+            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
+            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
+            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
+            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+            const uint16_t ls1 = aux32[1] >> 28;
+            const uint16_t ls2 = aux32[3] >> 28;
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
+            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
 #elif defined(__POWER9_VECTOR__)
     const vector int v0 = vec_splats((int32_t)0);
     vector float vsumf0 = vec_splats(0.0f);
@@ -9290,6 +9345,165 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
     }
 
     *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__AVX__)
+    const __m128i mone = _mm_set1_epi8(1);
+    static const char block_sign_shuffle_mask_1[32] = {
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+    };
+    static const char block_sign_shuffle_mask_2[32] = {
+        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+    };
+    static const uint8_t bit_selector_mask_bytes[32] = {
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes);
+    const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1);
+    const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1);
+    const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1);
+    const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2);
+    const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1);
+
+    static const uint8_t k_bit_helper[32] = {
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+    };
+    const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper);
+    const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1);
+    const __m128i m511 = _mm_set1_epi16(511);
+    const __m128i m4 = _mm_set1_epi8(0xf);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    uint64_t aux64;
+
+    // somewhat hacky, but gives a significant boost in performance
+    __m256i aux_gindex;
+    const uint16_t * gindex = (const uint16_t *)&aux_gindex;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+
+        memcpy(&aux64, x[i].scales, 8);
+        __m128i stmp = _mm_set1_epi64x(aux64);
+        stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
+        const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
+
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+            const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2);
+            const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1);  q2 += 16;
+            aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511));
+
+            const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9);
+            const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9);
+            const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13);
+            const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13);
+            const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0);
+            const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1);
+
+            const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0);
+            const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1);
+            const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0);
+            const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1);
+
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+
+            const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
+            const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]);
+            const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
+            const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]);
+            const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]);
+            const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]);
+            const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+            const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);
+
+            // AVX2 full_signs_1 is full_sign_bits_0 here
+            // AVX2 full_signs_2 is full_sign_bits_1 here
+            __m128i signs_0, signs_1;
+            signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);
+            signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);
+            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone));
+            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone));
+
+            signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0);
+            signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1);
+            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone));
+            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone));
+
+            signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0);
+            signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1);
+            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+            const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone));
+            const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone));
+
+            signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0);
+            signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1);
+            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+            const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone));
+            const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone));
+
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+            const __m128i dot3_0  = _mm_maddubs_epi16(q2_3_0, q8s_3_0);
+            const __m128i dot3_1  = _mm_maddubs_epi16(q2_3_1, q8s_3_1);
+            const __m128i dot4_0  = _mm_maddubs_epi16(q2_4_0, q8s_4_0);
+            const __m128i dot4_1  = _mm_maddubs_epi16(q2_4_1, q8s_4_1);
+
+            __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
+            const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);
+            const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
+            const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);
+            const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
+            const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);
+            const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
+            const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);
+            const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+
+            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0));
+            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1));
+            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0));
+            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1));
+            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0));
+            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1));
+            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0));
+            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1));
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
 #elif defined(__loongarch_asx)
 
     const __m256i mone = __lasx_xvreplgr2vr_b(1);
@@ -9693,6 +9907,98 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
 
     *s = 0.125f * hsum_float_8(accumf);
 
+#elif defined(__AVX__)
+   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+   };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m128i m4 = _mm_set1_epi8(0xf);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
+    const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
+    const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
+    const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
+
+    uint64_t aux64;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const int8_t  * restrict q8 = y[i].qs;
+
+        memcpy(&aux64, x[i].scales, 8);
+        const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
+        const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8);
+        const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8));
+
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
+                                                  iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
+            const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
+                                                  iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]);
+            const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
+                                                  iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
+            const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
+                                                  iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]);
+            qs += 8;
+
+            __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
+            __m128i aux128_1 = aux128_0;
+            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+            const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+            const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+            const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
+            const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
+
+            aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
+            aux128_1 = aux128_0;
+            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+            const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+            const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+            const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
+            const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
+
+            signs += 4;
+
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0)));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1)));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0)));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1)));
+            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
 #elif defined(__POWER9_VECTOR__)
     static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
                                         0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
@@ -10019,6 +10325,63 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
 
     *s = 0.25f * hsum_float_8(accumf);
 
+#elif defined(__AVX__)
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[2];
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
+            q3 += 8;
+            const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
+            q3 += 8;
+            memcpy(aux32, gas, 8); gas += 8;
+            const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);
+            const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]);
+            const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
+            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
+            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
+            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
+            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+            const uint16_t ls1 = aux32[0] >> 28;
+            const uint16_t ls2 = aux32[1] >> 28;
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
+            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = 0.25f * hsum_float_8(accumf);
+
 #elif defined(__POWER9_VECTOR__)
     const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
 
@@ -10370,6 +10733,112 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
 
     *s = hsum_float_8(accumf);
 
+#elif defined(__AVX__)
+   static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                       0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+   };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
+    const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
+    const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
+    const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
+
+    const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256);
+    const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16);
+    const __m128i idx_mask  = _mm_set1_epi32(256);
+
+    typedef union {
+        __m128i  vec[4];
+        uint32_t index[16];
+    } index_t;
+
+    index_t idx;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
+            const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
+            const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
+            idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);
+            idx.vec[1] = idx.vec[0];
+            idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);
+            idx.vec[3] = idx.vec[2];
+
+            idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask);
+            idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask);
+            idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask);
+            idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask);
+
+            idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));
+            idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
+            idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));
+            idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
+
+            const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
+            const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
+            const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);
+            const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
+
+            __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16));
+            __m128i aux128_1 = aux128_0;
+            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+            const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+            const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+            const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
+            const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
+
+            aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16));
+            aux128_1 = aux128_0;
+            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+            const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+            const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+            const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
+            const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
+
+            signs += 4;
+
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
+            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
+            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = hsum_float_8(accumf);
+
 #elif defined(__POWER9_VECTOR__)
     static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
                                         0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
@@ -10607,6 +11076,14 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
 }
 
 
+#if defined(__AVX__)
+static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
+    const __m128i ax = _mm_sign_epi8(x, x);
+    const __m128i sy = _mm_sign_epi8(y, x);
+    return _mm_maddubs_epi16(ax, sy);
+}
+#endif
+
 #if defined(__AVX2__)
 static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
     const __m256i ax = _mm256_sign_epi8(x, x);
@@ -10724,6 +11201,54 @@ void ggml_vec_dot_iq1_s_q8_K  (int n, float * restrict s, size_t bs, const void
 
     *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
 
+#elif defined __AVX__
+    __m256 accum = _mm256_setzero_ps();
+    float accum1 = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
+
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        int sumi1 = 0;
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
+            const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]);
+            const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
+            const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]);
+            qs += 8;
+            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+
+            const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
+            const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
+            const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
+            const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
+            const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
+            const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2));
+
+            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
+            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
+            sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
+                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
+        }
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum);
+        accum1 += d * sumi1;
+
+    }
+
+    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
+
 #elif defined(__POWER9_VECTOR__)
     const vector unsigned char v0 = vec_splats((unsigned char)0x0);
     const vector unsigned short vsign = vec_splats((unsigned short)0x8000);
@@ -11062,6 +11587,92 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
 
     *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
 
+#elif defined __AVX__
+    const __m128i mask = _mm_set1_epi16(0x7);
+    const __m128i mone = _mm_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint8_t  * qh = x[i].qh;
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q1b_1_0 = _mm_set_epi64x(
+                    iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);
+            const __m128i q1b_1_1 = _mm_set_epi64x(
+                    iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]);
+            const __m128i q1b_2_0 = _mm_set_epi64x(
+                    iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);
+            const __m128i q1b_2_1 = _mm_set_epi64x(
+                    iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]);
+            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+
+            const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
+            const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
+            const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
+            const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
+
+            const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+            const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+            const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+            const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+
+            const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0);
+            const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1);
+            const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0);
+            const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1);
+
+            __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0);
+            __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3);
+            __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6);
+            __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9);
+
+            scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone);
+            scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone);
+            scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone);
+            scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone);
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0);
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1);
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0);
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1);
+            const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0);
+            const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1);
+            const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0);
+            const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1);
+
+            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
+            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
+            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0));
+            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1));
+
+            qs += 8; qh += 4;
+        }
+
+        const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
+
+        accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1);
+        accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2);
+    }
+
+    *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
+
 #else
 
     int sum1[2], sum2[2], delta[4];
@@ -11192,6 +11803,44 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
 
     *s = hsum_float_8(_mm256_add_ps(accum1, accum2));
 
+#elif defined __AVX__
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+    const __m128i mone = _mm_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (int ib = 0; ib < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[1].qs);
+        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
+        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[0].qs + 1);
+        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[1].qs);
+        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[1].qs + 1);
+
+        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+        const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
+        const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
+        const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
+        const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
+        const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
+        const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
+        const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
+        const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
+        accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
+                _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
+        accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
+                _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);
+
+        y += 2;
+        x += 2;
+    }
+
+    *s = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
 #elif defined(__POWER9_VECTOR__)
     const vector signed char lowMask = vec_splats((signed char)0xF);
     const vector signed int v0 = vec_splats((int32_t)0);
@@ -11382,6 +12031,54 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
 
     *s = hsum_float_8(accum);
 
+#elif defined __AVX__
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const uint8_t * qs = x[ibl].qs;
+        const int8_t  * q8 = y[ibl].qs;
+        uint16_t sh = x[ibl].scales_h;
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
+            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
+            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+            const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+            const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+            const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+            const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
+            const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
+            const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
+            const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
+            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;
+            sh >>= 4;
+            const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1));
+            const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1));
+            const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2));
+            const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2));
+            sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0);
+            sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1);
+            sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0);
+            sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1);
+        }
+        __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0);
+        __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1);
+        accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
+                _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum);
+    }
+
+    *s = hsum_float_8(accum);
+
 #elif defined(__POWER9_VECTOR__)
     const vector signed char lowMask = vec_splats((signed char)0xF);
     const vector int v0 = vec_splats((int32_t)0);
@@ -13056,7 +13753,7 @@ static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const u
         const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) {
     int num_neighbors = neighbours[0];
     GGML_ASSERT(num_neighbors > 0);
-    float best_score = 0;
+    float best_score = -FLT_MAX;
     int grid_index = -1;
     for (int j = 1; j <= num_neighbors; ++j) {
         const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
@@ -13254,7 +13951,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
                     sumw[j+1] = sumw[j] + weight[i];
                 }
             }
-            float best_score = 0, scale = max;
+            float best_score = -FLT_MIN, scale = max;
             int besti1 = -1, besti2 = -1, best_shift = 0;
             for (int i1 = 0; i1 <= block_size; ++i1) {
                 for (int i2 = i1; i2 <= block_size; ++i2) {
@@ -13430,7 +14127,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
                 idx[2*j] = j;
             }
             qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
-            float best_score = 0, scale = max;
+            float best_score = -FLT_MIN, scale = max;
             int besti1 = -1, besti2 = -1, best_k = -1;
             // 0: +, +
             // 1: +, -
index 22d9524b8d7642a88473a97dcd8bea493f4cb40c..b01ad267446fb7ee7c889dbeb3243f8fe3fd5fce 100644 (file)
@@ -73,9 +73,13 @@ struct rpc_tensor {
     uint64_t view_offs;
     uint64_t data;
     char name[GGML_MAX_NAME];
+
+    char padding[4];
 };
 #pragma pack(pop)
 
+static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
+
 // RPC commands
 enum rpc_cmd {
     ALLOC_BUFFER = 0,
@@ -599,9 +603,8 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
     int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
     output.resize(output_size, 0);
     memcpy(output.data(), &n_nodes, sizeof(n_nodes));
-    uint64_t * out_nodes = (uint64_t *)(output.data() + sizeof(n_nodes));
     for (uint32_t i = 0; i < n_nodes; i++) {
-        out_nodes[i] = reinterpret_cast<uint64_t>(cgraph->nodes[i]);
+        memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
     }
     uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
     *out_ntensors = n_tensors;
@@ -1036,7 +1039,9 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
     }
     std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
     for (uint32_t i = 0; i < n_nodes; i++) {
-        graph->nodes[i] = create_node(nodes[i], ctx, tensor_ptrs, tensor_map);
+        int64_t id;
+        memcpy(&id, &nodes[i], sizeof(id));
+        graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
     }
     ggml_status status = ggml_backend_graph_compute(backend, graph);
     // output serialization format: | status (1 byte) |
diff --git a/src/ggml-rpc.h b/src/ggml-rpc.h
deleted file mode 100644 (file)
index aa14483..0000000
+++ /dev/null
@@ -1,24 +0,0 @@
-#pragma once
-
-#include "ggml.h"
-#include "ggml-backend.h"
-
-#ifdef  __cplusplus
-extern "C" {
-#endif
-
-#define GGML_RPC_MAX_SERVERS       16
-
-// backend API
-GGML_API GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
-GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend);
-
-GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
-
-GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
-
-GGML_API GGML_CALL void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
-
-#ifdef  __cplusplus
-}
-#endif
index 6bd42b9609882d96c18f5108a94cb201d7de8400..4a668a2c34d3ead78dc4011c0347d5eb59f1404e 100644 (file)
 #include "ggml-backend-impl.h"
 
 #include "ggml-sycl/backend.hpp"
-
-/*
-Following definition copied from DPCT head files, which are used by ggml-sycl.cpp
-*/
-// COPY from DPCT head files
-#include <sycl/sycl.hpp>
-#include <oneapi/mkl.hpp>
-#include <map>
-
-#if defined(__linux__)
-#include <sys/mman.h>
-#elif defined(_WIN64)
-#ifndef NOMINMAX
-#define NOMINMAX
-#endif
-#include <windows.h>
-#else
-#error "Only support Windows and Linux."
-#endif
-
-#if defined(__linux__)
-#include <unistd.h>
-#include <sys/syscall.h>
-#endif
-#if defined(_WIN64)
-#ifndef NOMINMAX
-#define NOMINMAX
-#endif
-#include <windows.h>
-#endif
-
-#define DPCT_COMPATIBILITY_TEMP (900)
-
-#if defined(_MSC_VER)
-#define __dpct_align__(n) __declspec(align(n))
-#define __dpct_inline__ __forceinline
-#else
-#define __dpct_align__(n) __attribute__((aligned(n)))
-#define __dpct_inline__ __inline__ __attribute__((always_inline))
-#endif
-
-#if defined(_MSC_VER)
-#define __dpct_noinline__ __declspec(noinline)
-#else
-#define __dpct_noinline__ __attribute__((noinline))
-#endif
+#include "ggml-sycl/presets.hpp"
 
 bool   ggml_sycl_loaded(void);
 void   ggml_sycl_free_data(struct ggml_tensor * tensor);
-void   ggml_sycl_assign_buffers(struct ggml_tensor * tensor);
-void   ggml_sycl_assign_buffers_no_scratch(struct ggml_tensor * tensor);
-void   ggml_sycl_assign_buffers_force_inplace(struct ggml_tensor * tensor);
-void   ggml_sycl_assign_buffers_no_alloc(struct ggml_tensor * tensor);
 void   ggml_sycl_copy_to_device(struct ggml_tensor * tensor);
 void   ggml_sycl_set_main_device(int main_device);
 void   ggml_sycl_set_mul_mat_q(bool mul_mat_q);
-void   ggml_sycl_set_scratch_size(size_t scratch_size);
-void   ggml_sycl_free_scratch(void);
 void   ggml_sycl_get_device_description(int device, char * description, size_t description_size);
 bool   ggml_backend_is_sycl(ggml_backend_t backend);
 int    ggml_backend_sycl_get_device(ggml_backend_t backend);
 static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
+static inline int get_sycl_env(const char *env_name, int default_val);
+static inline int get_work_group_size(const sycl::device& device);
 
 void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
                     const void *ptr_src, size_t size) {
@@ -108,45 +59,6 @@ void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
     free(host_buf);
 }
 
-static __dpct_inline__ int get_int_from_int8(const int8_t *x8, const int &i32) {
-    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
-
-    int x32 = 0;
-    x32 |= x16[0] <<  0;
-    x32 |= x16[1] << 16;
-
-    return x32;
-}
-
-static __dpct_inline__ int get_int_from_uint8(const uint8_t *x8,
-                                              const int &i32) {
-    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
-
-    int x32 = 0;
-    x32 |= x16[0] <<  0;
-    x32 |= x16[1] << 16;
-
-    return x32;
-}
-
-static __dpct_inline__ int get_int_from_int8_aligned(const int8_t *x8,
-                                                     const int &i32) {
-    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
-}
-
-static __dpct_inline__ int get_int_from_uint8_aligned(const uint8_t *x8,
-                                                      const int &i32) {
-    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
-}
-
-template <typename T>
-using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
-                             int k, queue_ptr stream);
-typedef to_t_sycl_t<float> to_fp32_sycl_t;
-typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
-
-typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
-typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 typedef void (*ggml_sycl_func_t)(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 typedef void (*ggml_sycl_op_mul_mat_t)(
@@ -162,22 +74,6 @@ typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const gg
                                        const float *src1_dd, float *dst_dd,
                                        const queue_ptr &main_stream);
 
-typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
-typedef void (*allocate_tiles_sycl_t)(int **x_ql, sycl::half2 **x_dm,
-                                      int **x_qh, int **x_sc);
-typedef void (*load_tiles_sycl_t)(const void *__restrict__ vx,
-                                  int *__restrict__ x_ql,
-                                  sycl::half2 *__restrict__ x_dm,
-                                  int *__restrict__ x_qh,
-                                  int *__restrict__ x_sc, const int &i_offset,
-                                  const int &i_max, const int &k,
-                                  const int &blocks_per_row);
-typedef float (*vec_dot_q_mul_mat_sycl_t)(
-    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
-    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
-    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ms,
-    const int &i, const int &j, const int &k);
-
 static __dpct_inline__ float warp_reduce_sum(float x,
                                              const sycl::nd_item<3> &item_ct1) {
 #pragma unroll
@@ -664,8029 +560,1392 @@ static void rms_norm_f32(const float * x, float * dst, const int ncols, const fl
     }
 }
 
-static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
-                                            const int iqs, dfloat2 &v) {
-    const block_q4_0 * x = (const block_q4_0 *) vx;
+static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
+                          const sycl::nd_item<3> &item_ct1) {
+    const int ix = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                   item_ct1.get_local_id(2);
 
-    const dfloat d = x[ib].d;
+    if (ix >= kx_padded) {
+        return;
+    }
 
-    const int vui = x[ib].qs[iqs];
+    const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                   item_ct1.get_local_id(1);
 
-    v.x() = vui & 0xF;
-    v.y() = vui >> 4;
+    const int i_padded = iy*kx_padded + ix;
 
-#ifdef GGML_SYCL_F16
-    // v = v - {8.0f, 8.0f};
-    // v = v * {d, d};
-    v.s0() = (v.s0() - 8.0f) * d;
-    v.s1() = (v.s1() - 8.0f) * d;
+    block_q8_1 * y = (block_q8_1 *) vy;
 
-#else
-    v.x() = (v.x() - 8.0f) * d;
-    v.y() = (v.y() - 8.0f) * d;
-#endif // GGML_SYCL_F16
-}
+    const int ib = i_padded / QK8_1; // block index
+    const int iqs = i_padded % QK8_1; // quant index
 
-static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
-                                            const int iqs, dfloat2 &v) {
-    const block_q4_1 * x = (const block_q4_1 *) vx;
+    const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
+    float amax = sycl::fabs((float)xi);
+    float sum = xi;
 
-    const dfloat d = x[ib].dm[0];
-    const dfloat m = x[ib].dm[1];
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        amax = sycl::fmax(amax, dpct::permute_sub_group_by_xor(
+                                    item_ct1.get_sub_group(), amax, mask));
+        sum +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), sum, mask);
+    }
 
-    const int vui = x[ib].qs[iqs];
+    const float d = amax / 127;
+    const int8_t q = amax == 0.0f ? 0 : sycl::round(xi / d);
 
-    v.x() = vui & 0xF;
-    v.y() = vui >> 4;
+    y[ib].qs[iqs] = q;
 
-#ifdef GGML_SYCL_F16
-    // v = v * {d, d};
-    // v = v + {m, m};
-    v.s0() = (v.s0() * d) + m;
-    v.s1() = (v.s1() * d) + m;
+    if (iqs > 0) {
+        return;
+    }
 
-#else
-    v.x() = (v.x() * d) + m;
-    v.y() = (v.y() * d) + m;
-#endif // GGML_SYCL_F16
+    reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
+    reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
 }
 
-static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
-                                            const int iqs, dfloat2 &v) {
-    const block_q5_0 * x = (const block_q5_0 *) vx;
+template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void k_get_rows(
+            const void * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12,
+            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
 
-    const dfloat d = x[ib].d;
+    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+                     item_ct1.get_local_id(2)) *
+                    2;
+    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1);
+    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) /
+                    ne12;
+    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) %
+                    ne12;
 
-    uint32_t qh;
-    memcpy(&qh, x[ib].qh, sizeof(qh));
+    if (i00 >= ne00) {
+        return;
+    }
 
-    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
-    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
-    v.x() = ((x[ib].qs[iqs] & 0xf) | xh_0);
-    v.y() = ((x[ib].qs[iqs] >> 4) | xh_1);
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
 
-#ifdef GGML_SYCL_F16
-    // v = v - {16.0f, 16.0f};
-    // v = v * {d, d};
-    v.s0() = (v.s0() - 16.0f) * d;
-    v.s1() = (v.s1() - 16.0f) * d;
+    const int ib = i00/qk; // block index
+    const int iqs = (i00%qk)/qr; // quant index
+    const int iybs = i00 - i00%qk; // dst block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
 
-#else
-    v.x() = (v.x() - 16.0f) * d;
-    v.y() = (v.y() - 16.0f) * d;
-#endif // GGML_SYCL_F16
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel(src0_row, ib, iqs, v);
+
+    dst_row[iybs + iqs + 0] = v.x();
+    dst_row[iybs + iqs + y_offset] = v.y();
 }
 
-static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
-                                            const int iqs, dfloat2 &v) {
-    const block_q5_1 * x = (const block_q5_1 *) vx;
+template<typename src0_t, typename dst_t>
+static void k_get_rows_float(
+            const src0_t * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12,
+            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
 
-    const dfloat d = x[ib].dm[0];
-    const dfloat m = x[ib].dm[1];
+    const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+                    item_ct1.get_local_id(2);
+    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1);
+    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) /
+                    ne12;
+    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) %
+                    ne12;
 
-    uint32_t qh;
-    memcpy(&qh, x[ib].qh, sizeof(qh));
+    if (i00 >= ne00) {
+        return;
+    }
 
-    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
-    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
-    v.x() = ((x[ib].qs[iqs] & 0xf) | xh_0);
-    v.y() = ((x[ib].qs[iqs] >> 4) | xh_1);
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
 
-#ifdef GGML_SYCL_F16
-    // v = v * {d, d};
-    // v = v + {m, m};
-    v.s0() = (v.s0() * d) + m;
-    v.s1() = (v.s1() * d) + m;
-#else
-    v.x() = (v.x() * d) + m;
-    v.y() = (v.y() * d) + m;
-#endif // GGML_SYCL_F16
+    dst_row[i00] = src0_row[i00];
 }
 
-static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
-                                            const int iqs, dfloat2 &v) {
-    const block_q8_0 * x = (const block_q8_0 *) vx;
-
-    const dfloat d = x[ib].d;
-
-    v.x() = x[ib].qs[iqs + 0];
-    v.y() = x[ib].qs[iqs + 1];
+static void mul_mat_p021_f16_f32(
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
+    const sycl::nd_item<3> &item_ct1) {
 
-#ifdef GGML_SYCL_F16
-    // v = v * {d, d};
-    v.s0() *= d;
-    v.s1() *= d;
-#else
-    v.x() *= d;
-    v.y() *= d;
-#endif // GGML_SYCL_F16
-}
+    const sycl::half *x = (const sycl::half *)vx;
 
-template<typename dst_t>
-static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
-                                  const sycl::nd_item<3> &item_ct1) {
+    const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                      item_ct1.get_local_id(1);
+    const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                        item_ct1.get_local_id(0);
+    const int channel_x = channel / (nchannels_y / nchannels_x);
 
-    const int i = item_ct1.get_group(2);
+    const int nrows_y = ncols_x;
+    const int nrows_dst = nrows_x;
+    const int row_dst = row_x;
 
-    // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int il  = tid/8;
-    const int ir  = tid%8;
-    const int ib = 8*i + ir;
-    if (ib >= nb32) {
-        return;
-    }
+    float tmp = 0.0f;
 
-    dst_t * y = yy + 256*i + 32*ir + 4*il;
+    for (int col_x0 = 0; col_x0 < ncols_x;
+         col_x0 += item_ct1.get_local_range(2)) {
+        const int col_x = col_x0 + item_ct1.get_local_id(2);
 
-    const block_q4_0 * x = (const block_q4_0 *)vx + ib;
-    const float d = sycl::vec<sycl::half, 1>(x->d)
-                        .convert<float, sycl::rounding_mode::automatic>()[0];
-    const float dm = -8*d;
+        if (col_x >= ncols_x) {
+            break;
+        }
 
-    const uint8_t * q = x->qs + 4*il;
+        // x is transposed and permuted
+        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
+        const float xi =
+            sycl::vec<sycl::half, 1>(x[ix])
+                .convert<float, sycl::rounding_mode::automatic>()[0];
 
-    for (int l = 0; l < 4; ++l) {
-        y[l+ 0] = d * (q[l] & 0xF) + dm;
-        y[l+16] = d * (q[l] >>  4) + dm;
-    }
-}
+        const int row_y = col_x;
 
-template<typename dst_t>
-static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
-                                  const sycl::nd_item<3> &item_ct1) {
 
-    const int i = item_ct1.get_group(2);
+        // y is not transposed but permuted
+        const int iy = channel*nrows_y + row_y;
 
-    // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int il  = tid/8;
-    const int ir  = tid%8;
-    const int ib = 8*i + ir;
-    if (ib >= nb32) {
-        return;
+        tmp += xi * y[iy];
     }
 
-    dst_t * y = yy + 256*i + 32*ir + 4*il;
-
-    const block_q4_1 * x = (const block_q4_1 *)vx + ib;
-    const sycl::float2 d =
-        x->dm.convert<float, sycl::rounding_mode::automatic>();
+    // dst is not transposed and not permuted
+    const int idst = channel*nrows_dst + row_dst;
 
-    const uint8_t * q = x->qs + 4*il;
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
 
-    for (int l = 0; l < 4; ++l) {
-        y[l + 0] = d.x() * (q[l] & 0xF) + d.y();
-        y[l + 16] = d.x() * (q[l] >> 4) + d.y();
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[idst] = tmp;
     }
 }
 
+static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
+    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
+    const sycl::nd_item<3> &item_ct1) {
 
-//================================== k-quants
-
-template<typename dst_t>
-static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
-                                  const sycl::nd_item<3> &item_ct1) {
+    const sycl::half *x = (const sycl::half *)vx;
 
-    const int i = item_ct1.get_group(2);
-    const block_q2_K * x = (const block_q2_K *) vx;
+    const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                      item_ct1.get_local_id(1);
+    const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                        item_ct1.get_local_id(0);
+    const int channel_x = channel / channel_x_divisor;
 
-    const int tid = item_ct1.get_local_id(2);
-    const int n   = tid/32;
-    const int l   = tid - 32*n;
-    const int is  = 8*n + l/16;
+    const int nrows_y   = ncols_x;
+    const int nrows_dst = nrows_x;
+    const int row_dst   = row_x;
 
-    const uint8_t q = x[i].qs[32*n + l];
-    dst_t * y = yy + i*QK_K + 128*n;
+    const int idst = channel*nrows_dst + row_dst;
 
-    float dall = x[i].dm[0];
-    float dmin = x[i].dm[1];
-    y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
-    y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
-    y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
-    y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
-}
+    float tmp = 0.0f;
 
-template<typename dst_t>
-static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
-                                  const sycl::nd_item<3> &item_ct1) {
+    for (int col_x0 = 0; col_x0 < ncols_x;
+         col_x0 += item_ct1.get_local_range(2)) {
+        const int col_x = col_x0 + item_ct1.get_local_id(2);
 
-    const int i = item_ct1.get_group(2);
-    const block_q3_K * x = (const block_q3_K *) vx;
+        if (col_x >= ncols_x) {
+            break;
+        }
 
-    const int r = item_ct1.get_local_id(2) / 4;
-    const int tid = r/2;
-    const int is0 = r%2;
-    const int l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
-    const int n = tid / 4;
-    const int j = tid - 4*n;
+        const int row_y = col_x;
 
-    uint8_t m = 1 << (4*n + j);
-    int is = 8*n + 2*j + is0;
-    int shift = 2*j;
+        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
+        const int iy = channel*nrows_y + row_y;
 
-    int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
-                is <  8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
-                is < 12 ? (x[i].scales[is-8] >>  4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
-                          (x[i].scales[is-8] >>  4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
-    float d_all = x[i].d;
-    float dl = d_all * (us - 32);
+        const float xi =
+            sycl::vec<sycl::half, 1>(x[ix])
+                .convert<float, sycl::rounding_mode::automatic>()[0];
 
-    dst_t * y = yy + i*QK_K + 128*n + 32*j;
-    const uint8_t * q = x[i].qs + 32*n;
-    const uint8_t * hm = x[i].hmask;
+        tmp += xi * y[iy];
+    }
 
-    for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
-}
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
 
-static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
-    if (j < 4) {
-        d = q[j] & 63; m = q[j + 4] & 63;
-    } else {
-        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
-        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[idst] = tmp;
     }
 }
 
-template<typename dst_t>
-static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
-                                  const sycl::nd_item<3> &item_ct1) {
-    const block_q4_K * x = (const block_q4_K *) vx;
-
-    const int i = item_ct1.get_group(2);
+static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    float * dsti = (float *) cdsti;
 
-    // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int il  = tid/8;
-    const int ir  = tid%8;
-    const int is  = 2*il;
-    const int n   = 4;
+    *dsti = *xi;
+}
 
-    dst_t * y = yy + i*QK_K + 64*il + n*ir;
+static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    sycl::half *dsti = (sycl::half *)cdsti;
 
-    const float dall = x[i].dm[0];
-    const float dmin = x[i].dm[1];
+    *dsti = sycl::vec<float, 1>(*xi)
+                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+}
 
-    const uint8_t * q = x[i].qs + 32*il + n*ir;
+static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+    const sycl::half *xi = (const sycl::half *)cxi;
+    sycl::half *dsti = (sycl::half *)cdsti;
 
-    uint8_t sc, m;
-    get_scale_min_k4(is + 0, x[i].scales, sc, m);
-    const float d1 = dall * sc; const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, x[i].scales, sc, m);
-    const float d2 = dall * sc; const float m2 = dmin * m;
-    for (int l = 0; l < n; ++l) {
-        y[l + 0] = d1 * (q[l] & 0xF) - m1;
-        y[l +32] = d2 * (q[l] >>  4) - m2;
-    }
+    *dsti = *xi;
 }
 
-template<typename dst_t>
-static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
-                                  const sycl::nd_item<3> &item_ct1) {
-    const block_q5_K * x = (const block_q5_K *) vx;
+static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
+    const sycl::half *xi = (const sycl::half *)cxi;
+    float * dsti = (float *) cdsti;
 
-    const int i = item_ct1.get_group(2);
+    *dsti = *xi;
+}
 
-    // assume 64 threads - this is very slightly better than the one below
-    const int tid = item_ct1.get_local_id(2);
-    const int il  = tid/16;   // il is in 0...3
-    const int ir  = tid%16;   // ir is in 0...15
-    const int is  = 2*il;     // is is in 0...6
+static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
+    const int16_t *xi = (const int16_t *)cxi;
+    int16_t *dsti = (int16_t *)cdsti;
 
-    dst_t * y = yy + i*QK_K + 64*il + 2*ir;
-
-    const float dall = x[i].dm[0];
-    const float dmin = x[i].dm[1];
-
-    const uint8_t * ql = x[i].qs + 32*il + 2*ir;
-    const uint8_t * qh = x[i].qh + 2*ir;
-
-    uint8_t sc, m;
-    get_scale_min_k4(is + 0, x[i].scales, sc, m);
-    const float d1 = dall * sc; const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, x[i].scales, sc, m);
-    const float d2 = dall * sc; const float m2 = dmin * m;
-
-    uint8_t   hm  = 1 << (2*il);
-    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
-    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
-    hm <<= 1;
-    y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
-    y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+    *dsti = *xi;
 }
 
-template<typename dst_t>
-static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
-                                  const sycl::nd_item<3> &item_ct1) {
-    const block_q6_K * x = (const block_q6_K *) vx;
+static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
+    const int32_t *xi = (const int32_t *)cxi;
+    int32_t *dsti = (int32_t *)cdsti;
 
-    const int i = item_ct1.get_group(2);
+    *dsti = *xi;
+}
 
-    // assume 64 threads - this is very slightly better than the one below
-    const int tid = item_ct1.get_local_id(2);
-    const int ip  = tid/32;   // ip is 0 or 1
-    const int il  = tid - 32*ip; // 0...32
-    const int is  = 8*ip + il/16;
+template <cpy_kernel_t cpy_1>
+static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+                        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                        const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                        const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
 
-    dst_t * y = yy + i*QK_K + 128*ip + il;
+    if (i >= ne) {
+        return;
+    }
 
-    const float d = x[i].d;
+    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+    // then combine those indices with the corresponding byte offsets to get the total offsets
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
 
-    const uint8_t * ql = x[i].ql + 64*ip + il;
-    const uint8_t   qh = x[i].qh[32*ip + il];
-    const int8_t  * sc = x[i].scales + is;
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
 
-    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
-    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
-    y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
-    y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+    cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
-template<typename dst_t>
-static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
-                                     const sycl::nd_item<3> &item_ct1,
-                                     const uint64_t *iq2xxs_grid_ptr,
-                                     const uint8_t *ksigns_iq2xs_ptr,
-                                     const uint8_t *kmask_iq2xs_ptr) {
+static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q8_0 * dsti = (block_q8_0 *) cdsti;
 
-    const int i = item_ct1.get_group(2);
-    const block_iq2_xxs * x = (const block_iq2_xxs  *) vx;
+    float amax = 0.0f; // absolute max
 
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
-    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
-    const uint16_t * q2 = x[i].qs + 4*ib;
-    const uint8_t  * aux8 = (const uint8_t *)q2;
-    const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid_ptr + aux8[il]);
-    const uint32_t aux32 = q2[2] | (q2[3] << 16);
-    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
-    const uint8_t signs = ksigns_iq2xs_ptr[(aux32 >> 7*il) & 127];
-    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs_ptr[j] ? -1.f : 1.f);
-}
-
-template<typename dst_t>
-static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy,
-                                    const sycl::nd_item<3> &item_ct1,
-                                    const uint64_t *iq2xs_grid,
-                                    const uint8_t *ksigns_iq2xs,
-                                    const uint8_t *kmask_iq2xs) {
-
-    const int i = item_ct1.get_group(2);
-    const block_iq2_xs * x = (const block_iq2_xs *) vx;
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = xi[j];
+        amax = sycl::fmax(amax, sycl::fabs((float)v));
+    }
 
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
-    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
-    const uint16_t * q2 = x[i].qs + 4*ib;
-    const uint8_t  * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
-    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
-    const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
-    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
-}
+    const float d = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
 
-template <typename dst_t>
-__dpct_inline__ static void
-dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
-                       const sycl::nd_item<3> &item_ct1) {
+    dsti->d = d;
 
-    const int i = item_ct1.get_group(2);
-    const block_iq2_s * x = (const block_iq2_s *) vx;
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = xi[j]*id;
 
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
-    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
-    const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
-    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
-    const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
-#pragma unroll
-    for (int j = 0; j < 8; ++j) {
-        y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+        dsti->qs[j] = sycl::round((float)x0);
     }
 }
 
-template<typename dst_t>
-static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
-                                     const sycl::nd_item<3> &item_ct1,
-                                     const uint32_t *iq3xxs_grid,
-                                     const uint8_t *ksigns_iq2xs,
-                                     const uint8_t *kmask_iq2xs) {
-
-    const int i = item_ct1.get_group(2);
-    const block_iq3_xxs * x = (const block_iq3_xxs  *) vx;
-
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
-    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
-    const uint8_t  * q3 = x[i].qs + 8*ib;
-    const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
-    const uint8_t  * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
-    const uint8_t  * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
-    const uint32_t aux32 = gas[0] | (gas[1] << 16);
-    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
-    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
-    for (int j = 0; j < 4; ++j) {
-        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
-        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
-    }
-}
-
-template <typename dst_t>
-__dpct_inline__ static void
-dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
-                       const sycl::nd_item<3> &item_ct1,
-                       const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {
+static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_0 * dsti = (block_q4_0 *) cdsti;
 
-    const int i = item_ct1.get_group(2);
-    const block_iq3_s * x = (const block_iq3_s *) vx;
+    float amax = 0.0f;
+    float vmax = 0.0f;
 
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
-    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
-    const uint8_t * qs = x[i].qs + 8*ib;
-    const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
-    const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
-    const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
-    const uint8_t signs = x[i].signs[4*ib + il];
-#pragma unroll
-    for (int j = 0; j < 4; ++j) {
-        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
-        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+    for (int j = 0; j < QK4_0; ++j) {
+        const float v = xi[j];
+        if (amax < sycl::fabs((float)v)) {
+            amax = sycl::fabs((float)v);
+            vmax = v;
+        }
     }
-}
-
-template <typename dst_t>
-__dpct_inline__ static void
-dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
-                       const sycl::nd_item<3> &item_ct1,
-                       const uint32_t *iq1s_grid_gpu) {
 
-    const int i = item_ct1.get_group(2);
-    const block_iq1_s * x = (const block_iq1_s  *) vx;
+    const float d  = vmax / -8;
+    const float id = d ? 1.0f/d : 0.0f;
 
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
-    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
-    const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
-    const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
-    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
-    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
-    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
-    grid32[0] &= 0x0f0f0f0f;
-#pragma unroll
-    for (int j = 0; j < 8; ++j) {
-        y[j] = d * (q[j] + delta);
-    }
-}
+    dsti->d = d;
 
-template <typename dst_t>
-__dpct_inline__ static void
-dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
-                       const sycl::nd_item<3> &item_ct1,
-                       const uint32_t *iq1s_grid_gpu) {
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = xi[0       + j]*id;
+        const float x1 = xi[QK4_0/2 + j]*id;
 
-    const int i = item_ct1.get_group(2);
-    const block_iq1_m * x = (const block_iq1_m  *) vx;
+        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f));
 
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
-    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
-    const uint16_t * sc = (const uint16_t *)x[i].scales;
-    iq1m_scale_t scale;
-    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
-    const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
-    const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
-    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
-    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
-    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
-    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
-    grid32[0] &= 0x0f0f0f0f;
-#pragma unroll
-    for (int j = 0; j < 8; ++j) {
-        y[j] = d * (q[j] + delta);
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
     }
 }
 
-template <typename dst_t>
-__dpct_inline__ static void
-dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
-                        const sycl::nd_item<3> &item_ct1) {
+static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_1 * dsti = (block_q4_1 *) cdsti;
 
-    const int i = item_ct1.get_group(2);
-    const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
+    float vmin = FLT_MAX;
+    float vmax = -FLT_MAX;
 
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
-    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
-    const uint8_t  * q4 = x[ib].qs + 4*il;
-    const float d = (float)x[ib].d;
-#pragma unroll
-    for (int j = 0; j < 4; ++j) {
-        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
-        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
+    for (int j = 0; j < QK4_1; ++j) {
+        const float v = xi[j];
+
+        if (v < vmin) vmin = v;
+        if (v > vmax) vmax = v;
     }
 
-}
+    const float d  = (vmax - vmin) / ((1 << 4) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
 
+    dsti->dm.x() = d;
+    dsti->dm.y() = vmin;
 
-template <typename dst_t>
-__dpct_inline__ static void
-dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
-                        const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_group(2);
-    const block_iq4_xs * x = (const block_iq4_xs *)vx;
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const float x0 = (xi[0       + j] - vmin)*id;
+        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
 
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
-    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
-    const uint8_t  * q4 = x[i].qs + 16*ib + 4*il;
-    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
-#pragma unroll
-    for (int j = 0; j < 4; ++j) {
-        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
-        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
+        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f));
+        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f));
+
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
     }
 }
 
+template <cpy_kernel_t cpy_blck, int qk>
+static void cpy_f32_q(const char * cx, char * cdst, const int ne,
+                      const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                      const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                      const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
+    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                   item_ct1.get_local_id(2)) *
+                  qk;
 
+    if (i >= ne) {
+        return;
+    }
 
-/*
-DPCT1110:4: The total declared local variable size in device function
-dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
-pressure. Consult with your hardware vendor to find the total register size
-available and adjust the code, or use smaller sub-group size to avoid high
-register pressure.
-*/
-static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
-                                        const float *__restrict__ yy,
-                                        float *__restrict__ dst,
-                                        const int ncols, int nrows,
-                                        const sycl::nd_item<3> &item_ct1) {
-
-    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
-
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-    if (row > nrows) return;
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
 
-    const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
 
-    const block_q2_K * x = (const block_q2_K *)vx + ib0;
+    cpy_blck(cx + x_offset, cdst + dst_offset);
+}
 
-    float tmp = 0; // partial sum for thread in warp
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
+    return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
+}
 
-    const int tid =
-        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15
-    const int ix =
-        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+struct rope_corr_dims {
+    float v[4];
+};
 
-    const int step = 16/K_QUANTS_PER_ITERATION;
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
 
-    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im;                        // 0...15 or 0...7
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
+    }
+    *cos_theta = sycl::cos(theta) * mscale;
+    *sin_theta = sycl::sin(theta) * mscale;
+}
 
-    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
-    const int q_offset = 32*im + l0;
-    const int s_offset = 8*im;
-    const int y_offset = 128*im + l0;
+// rope == RoPE == rotary positional embedding
+template<typename T, bool has_pos>
+static void rope(
+    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims
+,
+    const sycl::nd_item<3> &item_ct1) {
+    const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                         item_ct1.get_local_id(1));
 
-    uint32_t aux[4];
-    const uint8_t * d = (const uint8_t *)aux;
-    const uint8_t * m = (const uint8_t *)(aux + 2);
+    if (col >= ncols) {
+        return;
+    }
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                    item_ct1.get_local_id(2);
+    const int i = row*ncols + col;
+    const int i2 = row/p_delta_rows;
 
-        const float   * y = yy + i * QK_K + y_offset;
-        const uint8_t * q = x[i].qs + q_offset;
+    const int p = has_pos ? pos[i2] : 0;
+    const float theta_base = p * dpct::pow(freq_base, -float(col) / ncols);
 
-        const float dall = x[i].dm[0];
-        const float dmin = x[i].dm[1];
+    float cos_theta, sin_theta;
+    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-        const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
-        aux[0] = a[0] & 0x0f0f0f0f;
-        aux[1] = a[1] & 0x0f0f0f0f;
-        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
-        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
+    const float x0 = x[i + 0];
+    const float x1 = x[i + 1];
 
-        float sum1 = 0, sum2 = 0;
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
-            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
-                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
-                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
-                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
-                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
-                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
-                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
-                  +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
-            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
-                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
+    dst[i + 0] = x0*cos_theta - x1*sin_theta;
+    dst[i + 1] = x0*sin_theta + x1*cos_theta;
+}
 
-        }
-        tmp += dall * sum1 - dmin * sum2;
+template<typename T, bool has_pos, bool has_freq_facs>
+static void rope_neox(
+    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
+    const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
+    const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                         item_ct1.get_local_id(1));
 
+    if (col >= ncols) {
+        return;
     }
 
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
+    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                    item_ct1.get_local_id(2);
+    const int ib = col / n_dims;
+    const int ic = col % n_dims;
 
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[row] = tmp;
-    }
-}
+    if (ib > 0) {
+        const int i = row*ncols + ib*n_dims + ic;
 
-/*
-DPCT1110:5: The total declared local variable size in device function
-dequantize_mul_mat_vec_q3_k exceeds 128 bytes and may cause high register
-pressure. Consult with your hardware vendor to find the total register size
-available and adjust the code, or use smaller sub-group size to avoid high
-register pressure.
-*/
-static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
-                                        const float *__restrict__ yy,
-                                        float *__restrict__ dst,
-                                        const int ncols, int nrows,
-                                        const sycl::nd_item<3> &item_ct1) {
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
 
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-    if (row > nrows) return;
+        return;
+    }
 
-    const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
+    const int i  = row*ncols + ib*n_dims + ic/2;
+    const int i2 = row/p_delta_rows;
 
-    const block_q3_K * x = (const block_q3_K *)vx + ib0;
+    float cur_rot = inv_ndims * ic - ib;
 
-    float tmp = 0; // partial sum for thread in warp
+    const int p = has_pos ? pos[i2] : 0;
+    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
 
-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
+    const float theta_base =
+        p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;
 
-    const int tid =
-        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
-    const int ix =
-        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+    float cos_theta, sin_theta;
+    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
-    const int step = 16/K_QUANTS_PER_ITERATION;
-    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im;                        // 0....15 or 0...7
+    const float x0 = x[i + 0];
+    const float x1 = x[i + n_dims/2];
 
-    const uint8_t m = 1 << (4*im);
+    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+}
 
-    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
-    const int q_offset =  32*im + l0;
-    const int y_offset = 128*im + l0;
+static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
+                           const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(1);
+    const int col = item_ct1.get_local_id(2);
 
-    uint16_t utmp[4];
-    const int8_t * s = (const int8_t *)utmp;
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += item_ct1.get_local_range(2)) {
+        sum += x[row * ncols + i];
+    }
 
-    const uint16_t s_shift = 4*im;
+    sum = warp_reduce_sum(sum, item_ct1);
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-
-        const float   * y  = yy + i * QK_K + y_offset;
-        const uint8_t * q = x[i].qs + q_offset;
-        const uint8_t * h = x[i].hmask + l0;
+    if (col == 0) {
+        dst[row] = sum;
+    }
+}
 
-        const uint16_t * a = (const uint16_t *)x[i].scales;
-        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
-        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
-        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
-        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
 
-        const float d = x[i].d;
+template<typename T>
+static inline void ggml_sycl_swap(T & a, T & b) {
+    T tmp = a;
+    a = b;
+    b = tmp;
+}
 
-        float sum = 0;
-        for (int l = 0; l < n; ++l) {
-            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
-                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
-                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
-                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
-            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
-                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
-                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
-                + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
-        }
-        tmp += d * sum;
+template <ggml_sort_order order>
+__dpct_inline__ static void
+k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
+                  const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
+    // bitonic sort
+    int col = item_ct1.get_local_id(2);
+    int row = item_ct1.get_group(1);
 
+    if (col >= ncols_pad) {
+        return;
     }
 
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    const float * x_row = x + row * ncols;
+    auto dst_row = (int *)dpct_local;
+
+    // initialize indices
+    dst_row[col] = col;
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            /*
+            DPCT1118:1: SYCL group functions and algorithms must be encountered
+            in converged control flow. You may need to adjust the code.
+            */
+            item_ct1.barrier(sycl::access::fence_space::local_space);
+        }
     }
 
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[row] = tmp;
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
     }
 }
 
-/*
-DPCT1110:6: The total declared local variable size in device function
-dequantize_mul_mat_vec_q4_k exceeds 128 bytes and may cause high register
-pressure. Consult with your hardware vendor to find the total register size
-available and adjust the code, or use smaller sub-group size to avoid high
-register pressure.
-*/
-static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
-                                        const float *__restrict__ yy,
-                                        float *__restrict__ dst,
-                                        const int ncols, int nrows,
-                                        const sycl::nd_item<3> &item_ct1) {
 
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
+                              const sycl::nd_item<3> &item_ct1) {
+    const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
                     item_ct1.get_local_id(1);
-    if (row > nrows) return;
-    const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
-
-    const block_q4_K * x = (const block_q4_K *)vx + ib0;
-
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
-    const int tid =
-        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
-    const int ix =
-        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                    item_ct1.get_local_id(2);
 
-    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4
+    if (col >= ncols) {
+        return;
+    }
 
-    const int il  = tid/step;                            // 0...3
-    const int ir  = tid - step*il;                       // 0...7 or 0...3
-    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4
+    const int i = row*ncols + col;
+    //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
+    //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
+}
 
-    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const int in = il%2;
 
-    const int l0 = n*(2*ir + in);
-    const int q_offset = 32*im + l0;
-    const int y_offset = 64*im + l0;
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+                         const int nrows_y, const float scale, const float max_bias, const float m0,
+                         const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
-    uint16_t aux[4];
-    const uint8_t * sc = (const uint8_t *)aux;
+    const int tid = item_ct1.get_local_id(2);
+    const int rowx = item_ct1.get_group(2);
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
 
-#if K_QUANTS_PER_ITERATION == 2
-    uint32_t q32[4];
-    const uint8_t * q4 = (const uint8_t *)q32;
-#else
-    uint16_t q16[4];
-    const uint8_t * q4 = (const uint8_t *)q16;
-#endif
+    const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
 
-    float tmp = 0; // partial sum for thread in warp
+    const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+    const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    float slope = 1.0f;
 
-        const float   * y1 = yy + i*QK_K + y_offset;
-        const float   * y2 = y1 + 128;
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const uint32_t h = rowx/nrows_y; // head index
 
-        const float dall = x[i].dm[0];
-        const float dmin = x[i].dm[1];
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 
-        const uint16_t * a = (const uint16_t *)x[i].scales;
-        aux[0] = a[im+0] & kmask1;
-        aux[1] = a[im+2] & kmask1;
-        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
-        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+        slope = sycl::pow(base, float(exp));
+    }
 
-#if K_QUANTS_PER_ITERATION == 2
-        const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
-        const uint32_t * q2 = q1 + 16;
+    float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
+    float max_val = -INFINITY;
 
-        q32[0] = q1[0] & 0x0f0f0f0f;
-        q32[1] = q1[0] & 0xf0f0f0f0;
-        q32[2] = q2[0] & 0x0f0f0f0f;
-        q32[3] = q2[0] & 0xf0f0f0f0;
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
 
-        sycl::float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < 4; ++l) {
-            s.x() += y1[l] * q4[l + 0]; s.y() += y1[l + 32] * q4[l + 4];
-            s.z() += y2[l] * q4[l + 8]; s.w() += y2[l + 32] * q4[l + 12];
-            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
-        }
-        tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f +
-                       s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) -
-               dmin * smin;
-#else
-        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
-        const uint16_t * q2 = q1 + 32;
-
-        q16[0] = q1[0] & 0x0f0f;
-        q16[1] = q1[0] & 0xf0f0;
-        q16[2] = q2[0] & 0x0f0f;
-        q16[3] = q2[0] & 0xf0f0;
-
-        float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < 2; ++l) {
-            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
-            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
-            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        if (ncols_template == 0 && col >= ncols) {
+            break;
         }
-        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
-#endif
 
-    }
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
 
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
+        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
 
-    if (tid == 0) {
-        dst[row] = tmp;
+        vals[col] = val;
+        max_val = sycl::max(max_val, val);
     }
-}
 
-/*
-DPCT1110:7: The total declared local variable size in device function
-dequantize_mul_mat_vec_q5_k exceeds 128 bytes and may cause high register
-pressure. Consult with your hardware vendor to find the total register size
-available and adjust the code, or use smaller sub-group size to avoid high
-register pressure.
-*/
-static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
-                                        const float *__restrict__ yy,
-                                        float *__restrict__ dst,
-                                        const int ncols,
-                                        const sycl::nd_item<3> &item_ct1) {
-
-    const int row = item_ct1.get_group(2);
-    const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
-
-    const block_q5_K * x = (const block_q5_K *)vx + ib0;
-
-    float tmp = 0; // partial sum for thread in warp
-
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
-    const int tid = item_ct1.get_local_id(2) / 2; // 0...15
-    const int ix = item_ct1.get_local_id(2) % 2;
-
-    const int il  = tid/4;     // 0...3
-    const int ir  = tid - 4*il;// 0...3
-    const int n   = 2;
-
-    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const int in = il%2;
-
-    const int l0 = n*(2*ir + in);
-    const int q_offset = 32*im + l0;
-    const int y_offset = 64*im + l0;
-
-    const uint8_t hm1  = 1 << (2*im);
-    const uint8_t hm2  = hm1 << 4;
-
-    uint16_t aux[4];
-    const uint8_t * sc = (const uint8_t *)aux;
-
-    uint16_t q16[8];
-    const uint8_t * q4 = (const uint8_t *)q16;
-
-    for (int i = ix; i < num_blocks_per_row; i += 2) {
-
-        const uint8_t * ql1 = x[i].qs + q_offset;
-        const uint8_t * qh  = x[i].qh + l0;
-        const float   * y1  = yy + i*QK_K + y_offset;
-        const float   * y2  = y1 + 128;
-
-        const float dall = x[i].dm[0];
-        const float dmin = x[i].dm[1];
-
-        const uint16_t * a = (const uint16_t *)x[i].scales;
-        aux[0] = a[im+0] & kmask1;
-        aux[1] = a[im+2] & kmask1;
-        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
-        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
-
-        sycl::float4 sum = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        const uint16_t * q1 = (const uint16_t *)ql1;
-        const uint16_t * q2 = q1 + 32;
-        q16[0] = q1[0] & 0x0f0f;
-        q16[1] = q1[8] & 0x0f0f;
-        q16[2] = (q1[0] >> 4) & 0x0f0f;
-        q16[3] = (q1[8] >> 4) & 0x0f0f;
-        q16[4] = q2[0] & 0x0f0f;
-        q16[5] = q2[8] & 0x0f0f;
-        q16[6] = (q2[0] >> 4) & 0x0f0f;
-        q16[7] = (q2[8] >> 4) & 0x0f0f;
-        for (int l = 0; l < n; ++l) {
-            sum.x() +=
-                y1[l + 0] * (q4[l + 0] + (qh[l + 0] & (hm1 << 0) ? 16 : 0)) +
-                y1[l + 16] * (q4[l + 2] + (qh[l + 16] & (hm1 << 0) ? 16 : 0));
-            sum.y() +=
-                y1[l + 32] * (q4[l + 4] + (qh[l + 0] & (hm1 << 1) ? 16 : 0)) +
-                y1[l + 48] * (q4[l + 6] + (qh[l + 16] & (hm1 << 1) ? 16 : 0));
-            sum.z() +=
-                y2[l + 0] * (q4[l + 8] + (qh[l + 0] & (hm2 << 0) ? 16 : 0)) +
-                y2[l + 16] * (q4[l + 10] + (qh[l + 16] & (hm2 << 0) ? 16 : 0));
-            sum.w() +=
-                y2[l + 32] * (q4[l + 12] + (qh[l + 0] & (hm2 << 1) ? 16 : 0)) +
-                y2[l + 48] * (q4[l + 14] + (qh[l + 16] & (hm2 << 1) ? 16 : 0));
-            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
-                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val, item_ct1);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
         }
-        tmp += dall * (sum.x() * sc[0] + sum.y() * sc[1] + sum.z() * sc[4] +
-                       sum.w() * sc[5]) -
-               dmin * smin;
-    }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
 
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
 
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[row] = tmp;
+        max_val = buf[lane_id];
+        max_val = warp_reduce_max(max_val, item_ct1);
     }
-}
-
-static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows,
-                                        const sycl::nd_item<3> &item_ct1) {
 
-    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
-
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-    if (row > nrows) return;
-
-    const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
-
-    const block_q6_K * x = (const block_q6_K *)vx + ib0;
+    float tmp = 0.f;
 
-    const int tid =
-        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
-    const int ix =
-        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+                if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
 
-    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
+        const float val = sycl::native::exp(vals[col] - max_val);
+        tmp += val;
+        vals[col] = val;
+    }
 
-    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im;                        // 0...15 or 0...7
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp, item_ct1);
+    if (block_size > WARP_SIZE) {
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
 
-#if K_QUANTS_PER_ITERATION == 1
-    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
-    const int is = 0;
-#else
-    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
-    const int is = in / 4;
-#endif
-    const int ql_offset = 64*im + l0;
-    const int qh_offset = 32*im + l0;
-    const int s_offset  =  8*im + is;
-    const int y_offset = 128*im + l0;
-
-    float tmp = 0; // partial sum for thread in warp
-
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-
-        const float   * y  = yy + i * QK_K + y_offset;
-        const uint8_t * ql = x[i].ql + ql_offset;
-        const uint8_t * qh = x[i].qh + qh_offset;
-        const int8_t  * s  = x[i].scales + s_offset;
-
-        const float d = x[i].d;
-
-#if K_QUANTS_PER_ITERATION == 1
-        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
-                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
-                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
-                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
-                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
-                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
-                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
-                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
-        tmp += sum;
-#else
-        float sum = 0;
-        for (int l = 0; l < 4; ++l) {
-            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
-                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
-                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
-                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
         }
-        tmp += sum;
-#endif
+        item_ct1.barrier(sycl::access::fence_space::local_space);
 
+        tmp = buf[lane_id];
+        tmp = warp_reduce_sum(tmp, item_ct1);
     }
 
-    // sum up partial sums and write back result
+    const float inv_sum = 1.f / tmp;
+
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
 
-    if (tid == 0) {
-        dst[row] = tmp;
+        const int idst = rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
     }
 }
 
-static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
-    const sycl::half *x = (const sycl::half *)vx;
-
-    // automatic half -> float type cast if dfloat == float
-    v.x() = x[ib + iqs + 0];
-    v.y() = x[ib + iqs + 1];
-}
+static void scale_f32(const float * x, float * dst, const float scale, const int k,
+                      const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
 
-static void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
-    const float * x = (const float *) vx;
+    if (i >= k) {
+        return;
+    }
 
-    // automatic half -> float type cast if dfloat == float
-    v.x() = x[ib + iqs + 0];
-    v.y() = x[ib + iqs + 1];
+    dst[i] = scale * x[i];
 }
 
-static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int ix = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                   item_ct1.get_local_id(2);
+static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
+                      const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
 
-    if (ix >= kx_padded) {
+    if (i >= k) {
         return;
     }
 
-    const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                   item_ct1.get_local_id(1);
+    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
+}
 
-    const int i_padded = iy*kx_padded + ix;
+template <typename T>
+static void im2col_kernel(const float *x, T *dst, int offset_delta,
+                           int IW, int IH, int OW, int KW, int KH,
+                           int pelements, int CHW, int s0, int s1, int p0,
+                           int p1, int d0, int d1,
+                           const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_id(2) +
+                  item_ct1.get_group(2) * item_ct1.get_local_range(2);
+    if (i >= pelements) {
+        return;
+    }
 
-    block_q8_1 * y = (block_q8_1 *) vy;
+    const int ksize = OW * (KH > 1 ? KW : 1);
+    const int kx = i / ksize;
+    const int kd = kx * ksize;
+    const int ky = (i - kd) / OW;
+    const int ix = i % OW;
 
-    const int ib = i_padded / QK8_1; // block index
-    const int iqs = i_padded % QK8_1; // quant index
+    const int64_t iiw = ix * s0 + kx * d0 - p0;
+    const int64_t iih = item_ct1.get_group(1) * s1 + ky * d1 - p1;
 
-    const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
-    float amax = sycl::fabs((float)xi);
-    float sum = xi;
+    const int64_t offset_dst =
+        (item_ct1.get_group(1) * OW + ix) * CHW +
+        (item_ct1.get_group(0) * (KW * KH) + ky * KW + kx);
 
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        amax = sycl::fmax(amax, dpct::permute_sub_group_by_xor(
-                                    item_ct1.get_sub_group(), amax, mask));
-        sum +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), sum, mask);
-    }
-
-    const float d = amax / 127;
-    const int8_t q = amax == 0.0f ? 0 : sycl::round(xi / d);
-
-    y[ib].qs[iqs] = q;
-
-    if (iqs > 0) {
-        return;
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] =
+            sycl::vec<float, 1>(0.0f)
+                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+    } else {
+        const int64_t offset_src = item_ct1.get_group(0) * offset_delta;
+        dst[offset_dst] =
+            sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
+                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
     }
-
-    reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
-    reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
 }
 
-template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void k_get_rows(
-            const void * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12,
-            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
-    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                     item_ct1.get_local_id(2)) *
-                    2;
-    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) /
-                    ne12;
-    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) %
-                    ne12;
-
-    if (i00 >= ne00) {
-        return;
-    }
-
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+template <typename Ti, typename To>
+static  void pool2d_nchw_kernel(
+        const int ih, const int iw, const int oh, const int ow,
+        const int kh, const int kw, const int sh, const int sw,
+        const int ph, const int pw, const int parallel_elements,
+        const Ti* src, To* dst, const enum ggml_op_pool op,
+        const sycl::nd_item<3> &item_ct1) {
+        int idx = item_ct1.get_local_id(2) +
+                  item_ct1.get_group(2) * item_ct1.get_local_range(2);
+        if (idx >= parallel_elements) {
+            return;
+        }
 
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+        const int I_HW = ih * iw;
+        const int O_HW = oh * ow;
+        const int nc = idx / O_HW;
+        const int cur_oh = idx % O_HW / ow;
+        const int cur_ow = idx % O_HW % ow;
+        const Ti* i_ptr = src + nc * I_HW;
+        To* o_ptr = dst + nc * O_HW;
+        const int start_h = cur_oh * sh - ph;
+        const int bh = sycl::max(0, start_h);
+        const int eh = sycl::min(ih, start_h + kh);
+        const int start_w = cur_ow * sw - pw;
+        const int bw = sycl::max(0, start_w);
+        const int ew = sycl::min(iw, start_w + kw);
 
-    const int ib = i00/qk; // block index
-    const int iqs = (i00%qk)/qr; // quant index
-    const int iybs = i00 - i00%qk; // dst block start index
-    const int y_offset = qr == 1 ? 1 : qk/2;
+        To res = 0;
 
-    // dequantize
-    dfloat2 v;
-    dequantize_kernel(src0_row, ib, iqs, v);
+        switch (op) {
+            case GGML_OP_POOL_AVG: res = 0; break;
+            case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
+        }
 
-    dst_row[iybs + iqs + 0] = v.x();
-    dst_row[iybs + iqs + y_offset] = v.y();
+        for (int i = bh; i < eh; i += 1) {
+            for (int j = bw; j < ew; j += 1) {
+#if DPCT_COMPATIBILITY_TEMP >= 350
+                /*
+                DPCT1098:106: The '*' expression is used instead of the __ldg
+                call. These two expressions do not provide the exact same
+                functionality. Check the generated code for potential precision
+                and/or performance issues.
+                */
+                Ti cur = *(i_ptr + i * iw + j);
+#else
+                Ti cur = i_ptr[i * iw + j];
+#endif
+                switch (op) {
+                    case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
+                    case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
+                }
+            }
+        }
+        o_ptr[cur_oh * ow + cur_ow] = res;
 }
 
-template<typename src0_t, typename dst_t>
-static void k_get_rows_float(
-            const src0_t * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12,
-            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
-    const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                    item_ct1.get_local_id(2);
-    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) /
-                    ne12;
-    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) %
-                    ne12;
-
-    if (i00 >= ne00) {
-        return;
-    }
-
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+template <int qk, int qr, dequantize_kernel_t dq>
+static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                          ggml_tensor *dst, const void *src0_dd,
+                          const int32_t *src1_dd, float *dst_dd,
+                          queue_ptr stream) {
 
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    dst_row[i00] = src0_row[i00];
-}
+    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
+    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
+    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
 
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
-                             const sycl::nd_item<3> &item_ct1) {
-    const int i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                       item_ct1.get_local_id(2));
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
 
-    if (i >= k) {
-        return;
-    }
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
 
-    const int ib = i/qk; // block index
-    const int iqs = (i%qk)/qr; // quant index
-    const int iybs = i - i%qk; // y block start index
-    const int y_offset = qr == 1 ? 1 : qk/2;
+    GGML_ASSERT(ne00 % 2 == 0);
 
-    // dequantize
-    dfloat2 v;
-    dequantize_kernel(vx, ib, iqs, v);
+    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             k_get_rows<qk, qr, dq>(
+                                 src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
+                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
+                         });
 
-    y[iybs + iqs + 0] = v.x();
-    y[iybs + iqs + y_offset] = v.y();
+    (void) dst;
 }
 
-template <typename src_t, typename dst_t>
-static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
+template <typename src0_t>
+static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                const ggml_tensor *src1, ggml_tensor *dst,
+                                const src0_t *src0_dd, const int32_t *src1_dd,
+                                float *dst_dd, queue_ptr stream) {
 
-    const src_t * x = (src_t *) vx;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    y[i] = x[i];
-}
+    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
+    const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
+    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
 
-// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
-// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
 
-#define VDR_Q4_0_Q8_1_MMVQ 2
-#define VDR_Q4_0_Q8_1_MMQ  4
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
 
-template <int vdr>
-static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u,
-                                                    const float &d4,
-                                                    const sycl::half2 &ds8) {
-    int sumi = 0;
-#pragma unroll
-    for (int i = 0; i < vdr; ++i) {
-        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
-        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
 
-        // SIMD dot product of quantized values
-        sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);
-        sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
+                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
+            });
     }
 
-    const sycl::float2 ds8f =
-        ds8.convert<float, sycl::rounding_mode::automatic>();
-
-    // second part effectively subtracts 8 from each quant value
-    return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y());
+    (void) dst;
 }
 
-#define VDR_Q4_1_Q8_1_MMVQ 2
-#define VDR_Q4_1_Q8_1_MMQ  4
-
-template <int vdr>
-static __dpct_inline__ float vec_dot_q4_1_q8_1_impl(const int *v, const int *u,
-                                                    const sycl::half2 &dm4,
-                                                    const sycl::half2 &ds8) {
-
-    int sumi = 0;
-
-#pragma unroll
-    for (int i = 0; i < vdr; ++i) {
-        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
-        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
-
-        // SIMD dot product of quantized values
-        sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);
-        sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
-    }
-
-#ifdef GGML_SYCL_F16
-    const sycl::float2 tmp =
-        (dm4 * ds8).convert<float, sycl::rounding_mode::automatic>();
-    const float d4d8 = tmp.x();
-    const float m4s8 = tmp.y();
-#else
-    const sycl::float2 dm4f =
-        dm4.convert<float, sycl::rounding_mode::automatic>();
-    const sycl::float2 ds8f =
-        ds8.convert<float, sycl::rounding_mode::automatic>();
-    const float d4d8 = dm4f.x() * ds8f.x();
-    const float m4s8 = dm4f.y() * ds8f.y();
-#endif // GGML_SYCL_F16
-
-    // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
-    return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
-}
+template<float (*bin_op)(const float, const float)>
+struct bin_bcast_sycl {
+    template <typename src0_t, typename src1_t, typename dst_t>
+    void operator()(ggml_backend_sycl_context & ctx,
+                    const struct ggml_tensor *src0,
+                    const struct ggml_tensor *src1, struct ggml_tensor *dst,
+                    const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
+                    queue_ptr stream) {
 
-#define VDR_Q5_0_Q8_1_MMVQ 2
-#define VDR_Q5_0_Q8_1_MMQ  4
+        GGML_TENSOR_BINARY_OP_LOCALS
 
-template <int vdr>
-static __dpct_inline__ float
-vec_dot_q5_0_q8_1_impl(const int *vl, const int *vh, const int *u,
-                       const float &d5, const sycl::half2 &ds8) {
-    int sumi = 0;
+        int nr0 = ne10/ne0;
+        int nr1 = ne11/ne1;
+        int nr2 = ne12/ne2;
+        int nr3 = ne13/ne3;
 
-#pragma unroll
-    for (int i = 0; i < vdr; ++i) {
-        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
-        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
-        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12
-        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20
-        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28
-        sumi = dpct::dp4a(vi0, u[2 * i + 0],
-                          sumi); // SIMD dot product of quantized values
+        int nr[4] = { nr0, nr1, nr2, nr3 };
 
-        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
-        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
-        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12
-        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20
-        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28
-        sumi = dpct::dp4a(vi1, u[2 * i + 1],
-                          sumi); // SIMD dot product of quantized values
-    }
+        // collapse dimensions until first broadcast dimension
+        int64_t cne0[] = {ne0, ne1, ne2, ne3};
+        int64_t cne1[] = {ne10, ne11, ne12, ne13};
+        size_t cnb0[] = {nb0, nb1, nb2, nb3};
+        size_t cnb1[] = {nb10, nb11, nb12, nb13};
+        auto collapse = [](int64_t cne[]) {
+            cne[0] *= cne[1];
+            cne[1] = cne[2];
+            cne[2] = cne[3];
+            cne[3] = 1;
+        };
 
-    const sycl::float2 ds8f =
-        ds8.convert<float, sycl::rounding_mode::automatic>();
+        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+            cnb[1] *= cne[1];
+            cnb[2] *= cne[2];
+            cnb[3] *= cne[3];
+        };
 
-    // second part effectively subtracts 16 from each quant value
-    return d5 * (sumi * ds8f.x() - (16 * vdr / QI5_0) * ds8f.y());
-}
-
-#define VDR_Q5_1_Q8_1_MMVQ 2
-#define VDR_Q5_1_Q8_1_MMQ  4
-
-template <int vdr>
-static __dpct_inline__ float
-vec_dot_q5_1_q8_1_impl(const int *vl, const int *vh, const int *u,
-                       const sycl::half2 &dm5, const sycl::half2 &ds8) {
-
-    int sumi = 0;
-
-#pragma unroll
-    for (int i = 0; i < vdr; ++i) {
-        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
-        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
-        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12
-        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20
-        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28
-        sumi = dpct::dp4a(vi0, u[2 * i + 0],
-                          sumi); // SIMD dot product of quantized values
-
-        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
-        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
-        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12
-        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20
-        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28
-        sumi = dpct::dp4a(vi1, u[2 * i + 1],
-                          sumi); // SIMD dot product of quantized values
-    }
-
-#ifdef GGML_SYCL_F16
-     const sycl::float2 tmp =
-        (dm5 * ds8).convert<float, sycl::rounding_mode::automatic>();
-    const float d5d8 = tmp.x();
-    const float m5s8 = tmp.y();
-
-
-#else
-    const sycl::float2 dm5f =
-        dm5.convert<float, sycl::rounding_mode::automatic>();
-    const sycl::float2 ds8f =
-        ds8.convert<float, sycl::rounding_mode::automatic>();
-    const float d5d8 = dm5f.x() * ds8f.x();
-    const float m5s8 = dm5f.y() * ds8f.y();
-#endif // GGML_SYCL_F16
-
-    // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
-    return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
-}
-
-#define VDR_Q8_0_Q8_1_MMVQ 2
-#define VDR_Q8_0_Q8_1_MMQ 8
-
-template <int vdr>
-static __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int *v, const int *u,
-                                                    const float &d8_0,
-                                                    const float &d8_1) {
-
-    int sumi = 0;
-
-#pragma unroll
-    for (int i = 0; i < vdr; ++i) {
-        // SIMD dot product of quantized values
-        sumi = dpct::dp4a(v[i], u[i], sumi);
-    }
-
-    return d8_0*d8_1 * sumi;
-}
-
-template <int vdr>
-static __dpct_inline__ float vec_dot_q8_1_q8_1_impl(const int *v, const int *u,
-                                                    const sycl::half2 &dm8,
-                                                    const sycl::half2 &ds8) {
-
-    int sumi = 0;
-
-#pragma unroll
-    for (int i = 0; i < vdr; ++i) {
-        // SIMD dot product of quantized values
-        sumi = dpct::dp4a(v[i], u[i], sumi);
-    }
-
-#ifdef GGML_SYCL_F16
-    const sycl::float2 tmp =
-        (dm8 * ds8).convert<float, sycl::rounding_mode::automatic>();
-    const float d8d8 = tmp.x();
-    const float m8s8 = tmp.y();
-#else
-    const sycl::float2 dm8f =
-        dm8.convert<float, sycl::rounding_mode::automatic>();
-    const sycl::float2 ds8f =
-        ds8.convert<float, sycl::rounding_mode::automatic>();
-    const float d8d8 = dm8f.x() * ds8f.x();
-    const float m8s8 = dm8f.y() * ds8f.y();
-#endif // GGML_SYCL_F16
-
-    // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
-    return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
-}
-
-#define VDR_Q2_K_Q8_1_MMVQ 1
-#define VDR_Q2_K_Q8_1_MMQ  2
-
-// contiguous v/x values
-static __dpct_inline__ float vec_dot_q2_K_q8_1_impl_mmvq(
-    const int &v, const int *__restrict__ u, const uint8_t *__restrict__ scales,
-    const sycl::half2 &dm2, const float *__restrict__ d8) {
-
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
-
-#pragma unroll
-    for (int i = 0; i < QR2_K; ++i) {
-        const int sc = scales[2*i];
-
-        const int vi = (v >> (2*i)) & 0x03030303;
-
-        sumf_d +=
-            d8[i] * (dpct::dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
-
-        // fill int with 4x m
-        int m = sc >> 4;
-        m |= m <<  8;
-        m |= m << 16;
-        sumf_m += d8[i] *
-                  dpct::dp4a(
-                      m, u[i],
-                      0); // multiply constant q2_K part with sum of q8_1 values
-    }
-
-    const sycl::float2 dm2f =
-        dm2.convert<float, sycl::rounding_mode::automatic>();
-
-    return dm2f.x() * sumf_d - dm2f.y() * sumf_m;
-}
-
-// contiguous u/y values
-static __dpct_inline__ float
-vec_dot_q2_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
-                           const uint8_t *__restrict__ scales,
-                           const sycl::half2 &dm2, const float &d8) {
-
-    int sumi_d = 0;
-    int sumi_m = 0;
-
-#pragma unroll
-    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
-        int sumi_d_sc = 0;
-
-        const int sc = scales[i0 / (QI8_1/2)];
-
-        // fill int with 4x m
-        int m = sc >> 4;
-        m |= m <<  8;
-        m |= m << 16;
-
-#pragma unroll
-        for (int i = i0; i < i0 + QI8_1/2; ++i) {
-            sumi_d_sc = dpct::dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
-            sumi_m = dpct::dp4a(m, u[i],
-                                sumi_m); // multiply sum of q8_1 values with m
-        }
-
-        sumi_d += sumi_d_sc * (sc & 0xF);
-    }
-
-    const sycl::float2 dm2f =
-        dm2.convert<float, sycl::rounding_mode::automatic>();
-
-    return d8 * (dm2f.x() * sumi_d - dm2f.y() * sumi_m);
-}
-
-#define VDR_Q3_K_Q8_1_MMVQ 1
-#define VDR_Q3_K_Q8_1_MMQ  2
-
-// contiguous v/x values
-static __dpct_inline__ float vec_dot_q3_K_q8_1_impl_mmvq(
-    const int &vl, const int &vh, const int *__restrict__ u,
-    const uint8_t *__restrict__ scales, const int &scale_offset,
-    const float &d3, const float *__restrict__ d8) {
-
-    float sumf = 0.0f;
-
-#pragma unroll
-    for (int i = 0; i < QR3_K; ++i) {
-        const int isc = scale_offset + 2*i;
-
-        const int isc_low = isc % (QK_K/32);
-        const int sc_shift_low = 4 * (isc / (QK_K/32));
-        const int sc_low  = (scales[isc_low] >> sc_shift_low) & 0xF;
-
-        const int isc_high = isc % (QK_K/64);
-        const int sc_shift_high = 2 * (isc / (QK_K/64));
-        const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
-
-        const int sc = (sc_low | sc_high) - 32;
-
-        const int vil = (vl >> (2*i)) & 0x03030303;
-
-        const int vih = ((vh >> i) << 2) & 0x04040404;
-
-        const int vi =
-            dpct::vectorized_binary<sycl::char4>(vil, vih, dpct::sub_sat());
-
-        sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
-    }
-
-    return d3 * sumf;
-}
-
-// contiguous u/y values
-static __dpct_inline__ float
-vec_dot_q3_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
-                           const int8_t *__restrict__ scales, const float &d3,
-                           const float &d8) {
-
-    int sumi = 0;
-
-#pragma unroll
-    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
-        int sumi_sc = 0;
-
-        for (int i = i0; i < i0 + QI8_1/2; ++i) {
-            sumi_sc = dpct::dp4a(v[i], u[i], sumi_sc); // SIMD dot product
-        }
-
-        sumi += sumi_sc * scales[i0 / (QI8_1/2)];
-    }
-
-    return d3*d8 * sumi;
-}
-
-#define VDR_Q4_K_Q8_1_MMVQ 2
-#define VDR_Q4_K_Q8_1_MMQ  8
-
-// contiguous v/x values
-static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_vmmq(
-    const int *__restrict__ v, const int *__restrict__ u,
-    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
-    const sycl::half2 &dm4, const float *__restrict__ d8) {
-
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
-
-#pragma unroll
-    for (int i = 0; i < QR4_K; ++i) {
-        const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
-        const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
-
-        const int dot1 =
-            dpct::dp4a(v1i, u[2 * i + 1],
-                       dpct::dp4a(v0i, u[2 * i + 0], 0)); // SIMD dot product
-        const int dot2 =
-            dpct::dp4a(0x01010101, u[2 * i + 1],
-                       dpct::dp4a(0x01010101, u[2 * i + 0], 0)); // sum of u
-
-        sumf_d += d8[i] * (dot1 * sc[i]);
-        sumf_m += d8[i] * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
-    }
-
-    const sycl::float2 dm4f =
-        dm4.convert<float, sycl::rounding_mode::automatic>();
-
-    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
-}
-
-// contiguous u/y values
-static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_mmq(
-    const int *__restrict__ v, const int *__restrict__ u,
-    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
-    const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {
-
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
-
-#pragma unroll
-    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
-        int sumi_d = 0;
-
-#pragma unroll
-        for (int j = 0; j < QI8_1; ++j) {
-            sumi_d = dpct::dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F,
-                                u[i * QI8_1 + j], sumi_d); // SIMD dot product
-        }
-
-        const sycl::float2 ds8f =
-            ds8[i].convert<float, sycl::rounding_mode::automatic>();
-
-        sumf_d += ds8f.x() * (sc[i] * sumi_d);
-        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
-    }
-
-    const sycl::float2 dm4f =
-        dm4.convert<float, sycl::rounding_mode::automatic>();
-
-    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
-}
-
-#define VDR_Q5_K_Q8_1_MMVQ 2
-#define VDR_Q5_K_Q8_1_MMQ  8
-
-// contiguous v/x values
-static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_vmmq(
-    const int *__restrict__ vl, const int *__restrict__ vh,
-    const int *__restrict__ u, const uint8_t *__restrict__ sc,
-    const uint8_t *__restrict__ m, const sycl::half2 &dm5,
-    const float *__restrict__ d8) {
-
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
-
-#pragma unroll
-    for (int i = 0; i < QR5_K; ++i) {
-        const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
-        const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
-
-        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
-        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
-
-        const int v0i = vl0i | vh0i;
-        const int v1i = vl1i | vh1i;
-
-        const int dot1 =
-            dpct::dp4a(v0i, u[2 * i + 0],
-                       dpct::dp4a(v1i, u[2 * i + 1], 0)); // SIMD dot product
-        const int dot2 =
-            dpct::dp4a(0x01010101, u[2 * i + 0],
-                       dpct::dp4a(0x01010101, u[2 * i + 1], 0)); // sum of u
-
-        sumf_d += d8[i] * (dot1 * sc[i]);
-        sumf_m += d8[i] * (dot2 * m[i]);
-
-    }
-
-    const sycl::float2 dm5f =
-        dm5.convert<float, sycl::rounding_mode::automatic>();
-
-    return dm5f.x() * sumf_d - dm5f.y() * sumf_m;
-}
-
-// contiguous u/y values
-static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_mmq(
-    const int *__restrict__ v, const int *__restrict__ u,
-    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
-    const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {
-
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
-
-#pragma unroll
-    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
-        int sumi_d = 0;
-
-#pragma unroll
-        for (int j = 0; j < QI8_1; ++j) {
-            sumi_d = dpct::dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j],
-                                sumi_d); // SIMD dot product
-        }
-
-        const sycl::float2 ds8f =
-            ds8[i].convert<float, sycl::rounding_mode::automatic>();
-
-        sumf_d += ds8f.x() * (sc[i] * sumi_d);
-        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
-    }
-
-    const sycl::float2 dm4f =
-        dm4.convert<float, sycl::rounding_mode::automatic>();
-
-    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
-}
-
-#define VDR_Q6_K_Q8_1_MMVQ 1
-#define VDR_Q6_K_Q8_1_MMQ  8
-
-// contiguous v/x values
-static __dpct_inline__ float
-vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh,
-                            const int *__restrict__ u,
-                            const int8_t *__restrict__ scales, const float &d,
-                            const float *__restrict__ d8) {
-
-    float sumf = 0.0f;
-
-#pragma unroll
-    for (int i = 0; i < QR6_K; ++i) {
-        const int sc = scales[4*i];
-
-        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
-
-        const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
-
-        const int vi = dpct::vectorized_binary<sycl::char4>(
-            (vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32
-
-        sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
-    }
-
-    return d*sumf;
-}
-
-// contiguous u/y values
-static __dpct_inline__ float
-vec_dot_q6_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
-                           const int8_t *__restrict__ sc, const float &d6,
-                           const float *__restrict__ d8) {
-
-    float sumf_d = 0.0f;
-
-#pragma unroll
-    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
-        sycl::int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
-
-#pragma unroll
-        for (int i = i0; i < i0 + 2; ++i) {
-            sumi_d.x() = dpct::dp4a(v[2 * i + 0], u[2 * i + 0],
-                                    sumi_d.x()); // SIMD dot product
-            sumi_d.x() = dpct::dp4a(v[2 * i + 1], u[2 * i + 1],
-                                    sumi_d.x()); // SIMD dot product
-
-            sumi_d.y() = dpct::dp4a(v[2 * i + 4], u[2 * i + 4],
-                                    sumi_d.y()); // SIMD dot product
-            sumi_d.y() = dpct::dp4a(v[2 * i + 5], u[2 * i + 5],
-                                    sumi_d.y()); // SIMD dot product
-        }
-
-        sumf_d += d8[i0 / 4] *
-                  (sc[i0 / 2 + 0] * sumi_d.x() + sc[i0 / 2 + 1] * sumi_d.y());
-    }
-
-    return d6 * sumf_d;
-}
-
-static __dpct_inline__ float
-vec_dot_q4_0_q8_1(const void *__restrict__ vbq,
-                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
-    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
-
-    int v[VDR_Q4_0_Q8_1_MMVQ];
-    int u[2*VDR_Q4_0_Q8_1_MMVQ];
-
-#pragma unroll
-    for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
-        v[i]     = get_int_from_uint8(bq4_0->qs, iqs + i);
-        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
-        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
-    }
-
-    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
-}
-
-template <int mmq_y>
-static __dpct_inline__ void
-allocate_tiles_q4_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
-                    int *tile_x_qs_q4_0, float *tile_x_d_q4_0) {
-    (void)x_qh; (void)x_sc;
-
-    *x_ql = tile_x_qs_q4_0;
-    *x_dm = (sycl::half2 *)tile_x_d_q4_0;
-}
-
-template <int mmq_y, int nwarps, bool need_check>
-static __dpct_inline__ void
-load_tiles_q4_0(const void *__restrict__ vx, int *__restrict__ x_ql,
-                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
-                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
-                const int &k, const int &blocks_per_row) {
-    (void)x_qh; (void)x_sc;
-    GGML_SYCL_ASSUME(i_offset >= 0);
-    GGML_SYCL_ASSUME(i_offset <  nwarps);
-    GGML_SYCL_ASSUME(k >= 0);
-    GGML_SYCL_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI4_0;
-    const int kqsx = k % QI4_0;
-
-    const block_q4_0 * bx0 = (const block_q4_0 *) vx;
-
-    float * x_dmf = (float *) x_dm;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
-
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
-        // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
-    const int kbxd = k % blocks_per_tile_x_row;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
-        int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
-    }
-}
-
-static __dpct_inline__ float vec_dot_q4_0_q8_1_mul_mat(
-    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
-    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
-    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
-    const int &i, const int &j, const int &k) {
-    (void)x_qh; (void)x_sc;
-
-    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
-    const float * x_dmf = (const float *) x_dm;
-
-    int u[2*VDR_Q4_0_Q8_1_MMQ];
-
-#pragma unroll
-    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
-        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
-        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
-    }
-
-    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
-        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
-         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
-}
-
-static __dpct_inline__ float
-vec_dot_q4_1_q8_1(const void *__restrict__ vbq,
-                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
-    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
-
-    int v[VDR_Q4_1_Q8_1_MMVQ];
-    int u[2*VDR_Q4_1_Q8_1_MMVQ];
-
-#pragma unroll
-    for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
-        v[i]    = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);
-        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
-        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
-    }
-
-    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
-}
-
-template <int mmq_y>
-static __dpct_inline__ void
-allocate_tiles_q4_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
-                    int *tile_x_qs_q4_1, sycl::half2 *tile_x_dm_q4_1) {
-    (void)x_qh; (void)x_sc;
-
-    *x_ql = tile_x_qs_q4_1;
-    *x_dm = tile_x_dm_q4_1;
-}
-
-template <int mmq_y, int nwarps, bool need_check>
-static __dpct_inline__ void
-load_tiles_q4_1(const void *__restrict__ vx, int *__restrict__ x_ql,
-                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
-                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
-                const int &k, const int &blocks_per_row) {
-    (void)x_qh; (void)x_sc;
-
-    GGML_SYCL_ASSUME(i_offset >= 0);
-    GGML_SYCL_ASSUME(i_offset <  nwarps);
-    GGML_SYCL_ASSUME(k >= 0);
-    GGML_SYCL_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI4_1;
-    const int kqsx = k % QI4_1;
-
-    const block_q4_1 * bx0 = (const block_q4_1 *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
-
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
-    const int kbxd = k % blocks_per_tile_x_row;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
-        int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
-    }
-}
-
-static __dpct_inline__ float vec_dot_q4_1_q8_1_mul_mat(
-    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
-    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
-    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
-    const int &i, const int &j, const int &k) {
-    (void)x_qh; (void)x_sc;
-
-    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
-
-    int u[2*VDR_Q4_1_Q8_1_MMQ];
-
-#pragma unroll
-    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
-        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
-        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
-    }
-
-    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
-        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
-         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
-}
-
-static __dpct_inline__ float
-vec_dot_q5_0_q8_1(const void *__restrict__ vbq,
-                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
-    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
-
-    int vl[VDR_Q5_0_Q8_1_MMVQ];
-    int vh[VDR_Q5_0_Q8_1_MMVQ];
-    int  u[2*VDR_Q5_0_Q8_1_MMVQ];
-
-#pragma unroll
-    for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
-        vl[i]    = get_int_from_uint8(bq5_0->qs, iqs + i);
-        vh[i]    = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i));
-        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
-        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0);
-    }
-
-    return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
-}
-
-template <int mmq_y>
-static __dpct_inline__ void
-allocate_tiles_q5_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
-                    int *tile_x_ql_q5_0, float *tile_x_d_q5_0) {
-    (void)x_qh; (void)x_sc;
-
-    *x_ql = tile_x_ql_q5_0;
-    *x_dm = (sycl::half2 *)tile_x_d_q5_0;
-}
-
-template <int mmq_y, int nwarps, bool need_check>
-static __dpct_inline__ void
-load_tiles_q5_0(const void *__restrict__ vx, int *__restrict__ x_ql,
-                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
-                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
-                const int &k, const int &blocks_per_row) {
-    (void)x_qh; (void)x_sc;
-
-    GGML_SYCL_ASSUME(i_offset >= 0);
-    GGML_SYCL_ASSUME(i_offset <  nwarps);
-    GGML_SYCL_ASSUME(k >= 0);
-    GGML_SYCL_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI5_0;
-    const int kqsx = k % QI5_0;
-
-    const block_q5_0 * bx0 = (const block_q5_0 *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;
-
-        const int ql = get_int_from_uint8(bxi->qs, kqsx);
-        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));
-
-        int qs0 = (ql >>  0)   & 0x0F0F0F0F;
-        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
-        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12
-        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20
-        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
-        qs0 = dpct::vectorized_binary<sycl::char4>(
-            qs0, 0x10101010, dpct::sub_sat()); // subtract 16
-
-        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
-
-        int qs1 = (ql >>  4)   & 0x0F0F0F0F;
-        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
-        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
-        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20
-        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
-        qs1 = dpct::vectorized_binary<sycl::char4>(
-            qs1, 0x10101010, dpct::sub_sat()); // subtract 16
-
-        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
-    const int kbxd = k % blocks_per_tile_x_row;
-    float * x_dmf = (float *) x_dm;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
-        int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
-    }
-}
-
-static __dpct_inline__ float vec_dot_q5_0_q8_1_mul_mat(
-    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
-    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
-    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
-    const int &i, const int &j, const int &k) {
-    (void)x_qh; (void)x_sc;
-
-    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
-    const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
-    const float * x_dmf = (const float *) x_dm;
-    const float * y_df  = (const float *) y_ds;
-
-    int u[2*VDR_Q5_0_Q8_1_MMQ];
-
-#pragma unroll
-    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
-        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
-        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
-    }
-
-    return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
-        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
-}
-
-static __dpct_inline__ float
-vec_dot_q5_1_q8_1(const void *__restrict__ vbq,
-                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
-    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
-
-    int vl[VDR_Q5_1_Q8_1_MMVQ];
-    int vh[VDR_Q5_1_Q8_1_MMVQ];
-    int  u[2*VDR_Q5_1_Q8_1_MMVQ];
-
-#pragma unroll
-    for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
-        vl[i]   = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);
-        vh[i]   = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));
-        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
-        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
-    }
-
-    return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
-}
-
-template <int mmq_y>
-static __dpct_inline__ void
-allocate_tiles_q5_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
-                    int *tile_x_ql_q5_1, sycl::half2 *tile_x_dm_q5_1) {
-    (void)x_qh; (void)x_sc;
-
-    *x_ql = tile_x_ql_q5_1;
-    *x_dm = tile_x_dm_q5_1;
-}
-
-template <int mmq_y, int nwarps, bool need_check>
-static __dpct_inline__ void
-load_tiles_q5_1(const void *__restrict__ vx, int *__restrict__ x_ql,
-                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
-                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
-                const int &k, const int &blocks_per_row) {
-    (void)x_qh; (void)x_sc;
-
-    GGML_SYCL_ASSUME(i_offset >= 0);
-    GGML_SYCL_ASSUME(i_offset < nwarps);
-    GGML_SYCL_ASSUME(k >= 0);
-    GGML_SYCL_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI5_1;
-    const int kqsx = k % QI5_1;
-
-    const block_q5_1 * bx0 = (const block_q5_1 *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
-
-        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
-        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
-
-        int qs0 = (ql >>  0) & 0x0F0F0F0F;
-        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
-        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
-        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
-        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
-
-        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
-
-        int qs1 = (ql >>  4) & 0x0F0F0F0F;
-        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
-        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
-        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
-        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
-
-        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
-    const int kbxd = k % blocks_per_tile_x_row;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
-        int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
-    }
-}
-
-static __dpct_inline__ float vec_dot_q5_1_q8_1_mul_mat(
-    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
-    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
-    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
-    const int &i, const int &j, const int &k) {
-    (void)x_qh; (void)x_sc;
-
-    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
-    const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
-
-    int u[2*VDR_Q5_1_Q8_1_MMQ];
-
-#pragma unroll
-    for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
-        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
-        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
-    }
-
-    return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
-        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
-}
-
-static __dpct_inline__ float
-vec_dot_q8_0_q8_1(const void *__restrict__ vbq,
-                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
-    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
-
-    int v[VDR_Q8_0_Q8_1_MMVQ];
-    int u[VDR_Q8_0_Q8_1_MMVQ];
-
-#pragma unroll
-    for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
-        v[i] = get_int_from_int8(bq8_0->qs, iqs + i);
-        u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
-    }
-
-    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d,
-                                                      bq8_1->ds[0]);
-}
-
-template <int mmq_y>
-static __dpct_inline__ void
-allocate_tiles_q8_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
-                    int *tile_x_qs_q8_0, float *tile_x_d_q8_0) {
-    (void)x_qh; (void)x_sc;
-
-    *x_ql = tile_x_qs_q8_0;
-    *x_dm = (sycl::half2 *)tile_x_d_q8_0;
-}
-
-template <int mmq_y, int nwarps, bool need_check>
-static __dpct_inline__ void
-load_tiles_q8_0(const void *__restrict__ vx, int *__restrict__ x_ql,
-                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
-                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
-                const int &k, const int &blocks_per_row) {
-    (void)x_qh; (void)x_sc;
-
-    GGML_SYCL_ASSUME(i_offset >= 0);
-    GGML_SYCL_ASSUME(i_offset <  nwarps);
-    GGML_SYCL_ASSUME(k >= 0);
-    GGML_SYCL_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI8_0;
-    const int kqsx = k % QI8_0;
-    float * x_dmf = (float *) x_dm;
-
-    const block_q8_0 * bx0 = (const block_q8_0 *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
-
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
-    const int kbxd = k % blocks_per_tile_x_row;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
-        int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
-    }
-}
-
-static __dpct_inline__ float vec_dot_q8_0_q8_1_mul_mat(
-    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
-    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
-    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
-    const int &i, const int &j, const int &k) {
-    (void)x_qh; (void)x_sc;
-
-    const float * x_dmf = (const float *) x_dm;
-    const float * y_df  = (const float *) y_ds;
-
-    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
-        (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
-         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
-}
-
-static __dpct_inline__ float
-vec_dot_q2_K_q8_1(const void *__restrict__ vbq,
-                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
-    const block_q2_K * bq2_K = (const block_q2_K *) vbq;
-
-    const int bq8_offset = QR2_K * (iqs / QI8_1);
-    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
-
-    const uint8_t * scales = bq2_K->scales + scale_offset;
-
-    const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
-    int    u[QR2_K];
-    float d8[QR2_K];
-
-#pragma unroll
-    for (int i = 0; i < QR2_K; ++ i) {
-        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
-        d8[i] = bq8_1[bq8_offset + i].ds[0];
-    }
-
-    return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
-}
-
-template <int mmq_y>
-static __dpct_inline__ void
-allocate_tiles_q2_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
-                    int *tile_x_ql_q2_K, sycl::half2 *tile_x_dm_q2_K,
-                    int *tile_x_sc_q2_K) {
-    (void)x_qh;
-
-    *x_ql = tile_x_ql_q2_K;
-    *x_dm = tile_x_dm_q2_K;
-    *x_sc = tile_x_sc_q2_K;
-}
-
-template <int mmq_y, int nwarps, bool need_check>
-static __dpct_inline__ void
-load_tiles_q2_K(const void *__restrict__ vx, int *__restrict__ x_ql,
-                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
-                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
-                const int &k, const int &blocks_per_row) {
-    (void)x_qh;
-
-    GGML_SYCL_ASSUME(i_offset >= 0);
-    GGML_SYCL_ASSUME(i_offset <  nwarps);
-    GGML_SYCL_ASSUME(k >= 0);
-    GGML_SYCL_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI2_K;
-    const int kqsx = k % QI2_K;
-
-    const block_q2_K * bx0 = (const block_q2_K *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;
-
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
-    const int kbxd = k % blocks_per_tile_x_row;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
-        int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
-        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);
-
-        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
-    }
-}
-
-static __dpct_inline__ float vec_dot_q2_K_q8_1_mul_mat(
-    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
-    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
-    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
-    const int &i, const int &j, const int &k) {
-    (void)x_qh;
-
-    const int kbx = k / QI2_K;
-    const int ky  = (k % QI2_K) * QR2_K;
-    const float * y_df = (const float *) y_ds;
-
-    int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
-
-    const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
-    const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
-
-#pragma unroll
-    for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
-        v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
-    }
-
-    const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
-
-    const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
-    return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
-}
-
-static __dpct_inline__ float
-vec_dot_q3_K_q8_1(const void *__restrict__ vbq,
-                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
-    const block_q3_K * bq3_K = (const block_q3_K *) vbq;
-
-    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
-    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
-
-    const float d = bq3_K->d;
-
-    const int vl = get_int_from_uint8(bq3_K->qs, iqs);
-
-    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
-    const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;
-
-    int    u[QR3_K];
-    float d8[QR3_K];
-
-#pragma unroll
-    for (int i = 0; i < QR3_K; ++i) {
-        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
-        d8[i] = bq8_1[bq8_offset + i].ds[0];
-    }
-
-    return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
-}
-
-template <int mmq_y>
-static __dpct_inline__ void
-allocate_tiles_q3_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
-                    int *tile_x_ql_q3_K, sycl::half2 *tile_x_dm_q3_K,
-                    int *tile_x_qh_q3_K, int *tile_x_sc_q3_K) {
-
-    *x_ql = tile_x_ql_q3_K;
-    *x_dm = tile_x_dm_q3_K;
-    *x_qh = tile_x_qh_q3_K;
-    *x_sc = tile_x_sc_q3_K;
-}
-
-template <int mmq_y, int nwarps, bool need_check>
-static __dpct_inline__ void
-load_tiles_q3_K(const void *__restrict__ vx, int *__restrict__ x_ql,
-                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
-                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
-                const int &k, const int &blocks_per_row) {
-
-    GGML_SYCL_ASSUME(i_offset >= 0);
-    GGML_SYCL_ASSUME(i_offset <  nwarps);
-    GGML_SYCL_ASSUME(k >= 0);
-    GGML_SYCL_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI3_K;
-    const int kqsx = k % QI3_K;
-
-    const block_q3_K * bx0 = (const block_q3_K *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;
-
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
-    const int kbxd = k % blocks_per_tile_x_row;
-    float * x_dmf = (float *) x_dm;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
-        int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
-        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
-
-        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
-        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
-        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
-
-        const int ksc = k % (QI3_K/4);
-
-        const int ksc_low = ksc % (QI3_K/8);
-        const int shift_low = 4 * (ksc / (QI3_K/8));
-        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
-
-        const int ksc_high = QI3_K/8;
-        const int shift_high = 2 * ksc;
-        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
-
-        const int sc = dpct::vectorized_binary<sycl::char4>(
-            sc_low | sc_high, 0x20202020, dpct::sub_sat());
-
-        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
-    }
-}
-
-static __dpct_inline__ float vec_dot_q3_K_q8_1_mul_mat(
-    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
-    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
-    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
-    const int &i, const int &j, const int &k) {
-
-    const int kbx  = k / QI3_K;
-    const int ky  = (k % QI3_K) * QR3_K;
-    const float * x_dmf = (const float *) x_dm;
-    const float * y_df  = (const float *) y_ds;
-
-    const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
-
-    int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
-
-#pragma unroll
-    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
-        const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
-        const int shift = 2 * ((ky % 32) / 8);
-        const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
-
-        const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
-        const int vlh = (vh << 2) & 0x04040404;
-
-        v[l] = dpct::vectorized_binary<sycl::char4>(vll, vlh, dpct::sub_sat());
-    }
-
-    const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
-    return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
-}
-
-static __dpct_inline__ float
-vec_dot_q4_K_q8_1(const void *__restrict__ vbq,
-                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
-    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
-
-    int    v[2];
-    int    u[2*QR4_K];
-    float d8[QR4_K];
-
-    // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
-    const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
-
-    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
-    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
-    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
-    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
-
-    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
-    v[0] = q4[0];
-    v[1] = q4[4];
-
-    const uint16_t * scales = (const uint16_t *)bq4_K->scales;
-    uint16_t aux[2];
-    const int j = bq8_offset/2;
-    if (j < 2) {
-        aux[0] = scales[j+0] & 0x3f3f;
-        aux[1] = scales[j+2] & 0x3f3f;
-    } else {
-        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
-        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
-    }
-    const uint8_t * sc = (const uint8_t *)aux;
-    const uint8_t * m  = sc + 2;
-
-    for (int i = 0; i < QR4_K; ++i) {
-        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        d8[i] = bq8i->ds[0];
-
-        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
-        u[2*i+0] = q8[0];
-        u[2*i+1] = q8[4];
-    }
-
-    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
-}
-
-template <int mmq_y>
-static __dpct_inline__ void
-allocate_tiles_q4_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
-                    int *tile_x_ql_q4_K, sycl::half2 *tile_x_dm_q4_K,
-                    int *tile_x_sc_q4_K) {
-    (void)x_qh;
-
-    *x_ql = tile_x_ql_q4_K;
-    *x_dm = tile_x_dm_q4_K;
-    *x_sc = tile_x_sc_q4_K;
-}
-
-template <int mmq_y, int nwarps, bool need_check>
-static __dpct_inline__ void
-load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql,
-                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
-                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
-                const int &k, const int &blocks_per_row) {
-    (void)x_qh;
-
-    GGML_SYCL_ASSUME(i_offset >= 0);
-    GGML_SYCL_ASSUME(i_offset <  nwarps);
-    GGML_SYCL_ASSUME(k >= 0);
-    GGML_SYCL_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI4_K; // == 0 if QK_K == 256
-    const int kqsx = k % QI4_K; // == k if QK_K == 256
-
-    const block_q4_K * bx0 = (const block_q4_K *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;
-
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
-    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
-        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
-        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
-
-        const int * scales = (const int *) bxi->scales;
-
-        const int ksc = k % (WARP_SIZE/8);
-
-        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
-        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
-        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
-
-        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
-    }
-}
-
-static __dpct_inline__ float vec_dot_q4_K_q8_1_mul_mat(
-    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
-    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
-    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
-    const int &i, const int &j, const int &k) {
-    (void)x_qh;
-
-    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
-
-    const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
-    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
-                                      x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
-}
-
-static __dpct_inline__ float
-vec_dot_q5_K_q8_1(const void *__restrict__ vbq,
-                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
-    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
-
-    int   vl[2];
-    int   vh[2];
-    int    u[2*QR5_K];
-    float d8[QR5_K];
-
-    const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
-    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
-    const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));
-
-    vl[0] = ql[0];
-    vl[1] = ql[4];
-
-    vh[0] = qh[0] >> bq8_offset;
-    vh[1] = qh[4] >> bq8_offset;
-
-    const uint16_t * scales = (const uint16_t *)bq5_K->scales;
-    uint16_t aux[2];
-    const int j = bq8_offset/2;
-    if (j < 2) {
-        aux[0] = scales[j+0] & 0x3f3f;
-        aux[1] = scales[j+2] & 0x3f3f;
-    } else {
-        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
-        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
-    }
-    const uint8_t * sc = (const uint8_t *)aux;
-    const uint8_t * m  = sc + 2;
-
-#pragma unroll
-    for (int i = 0; i < QR5_K; ++i) {
-        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        d8[i] = bq8i->ds[0];
-
-        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
-        u[2*i+0] = q8[0];
-        u[2*i+1] = q8[4];
-    }
-
-    return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
-}
-
-template <int mmq_y>
-static __dpct_inline__ void
-allocate_tiles_q5_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
-                    int *tile_x_ql_q5_K, sycl::half2 *tile_x_dm_q5_K,
-                    int *tile_x_sc_q5_K) {
-    (void)x_qh;
-
-    *x_ql = tile_x_ql_q5_K;
-    *x_dm = tile_x_dm_q5_K;
-    *x_sc = tile_x_sc_q5_K;
-}
-
-template <int mmq_y, int nwarps, bool need_check>
-static __dpct_inline__ void
-load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql,
-                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
-                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
-                const int &k, const int &blocks_per_row) {
-    (void)x_qh;
-
-    GGML_SYCL_ASSUME(i_offset >= 0);
-    GGML_SYCL_ASSUME(i_offset <  nwarps);
-    GGML_SYCL_ASSUME(k >= 0);
-    GGML_SYCL_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI5_K; // == 0 if QK_K == 256
-    const int kqsx = k % QI5_K; // == k if QK_K == 256
-
-    const block_q5_K * bx0 = (const block_q5_K *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
-        const int ky = QR5_K*kqsx;
-
-        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
-        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
-        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
-
-        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
-        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
-        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
-
-        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
-        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);
-
-        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
-        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
-    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
-        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
-        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
-
-        const int * scales = (const int *) bxi->scales;
-
-        const int ksc = k % (WARP_SIZE/8);
-
-        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
-        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
-        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
-
-        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
-    }
-}
-
-static __dpct_inline__ float vec_dot_q5_K_q8_1_mul_mat(
-    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
-    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
-    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
-    const int &i, const int &j, const int &k) {
-    (void)x_qh;
-
-    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
-
-    const int index_x = i * (QR5_K*WARP_SIZE + 1) +  QR5_K*k;
-    const int index_y = j * WARP_SIZE             + (QR5_K*k) % WARP_SIZE;
-    return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
-                                      x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
-}
-
-static __dpct_inline__ float
-vec_dot_q6_K_q8_1(const void *__restrict__ vbq,
-                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
-    const block_q6_K * bq6_K = (const block_q6_K *) vbq;
-
-    const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
-    const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
-    const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
-
-    const int vl = get_int_from_uint8(bq6_K->ql, iqs);
-    const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;
-
-    const int8_t * scales = bq6_K->scales + scale_offset;
-
-    int    u[QR6_K];
-    float d8[QR6_K];
-
-#pragma unroll
-    for (int i = 0; i < QR6_K; ++i) {
-        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
-        d8[i] = bq8_1[bq8_offset + 2 * i].ds[0];
-    }
-
-    return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
-}
-
-template <int mmq_y>
-static __dpct_inline__ void
-allocate_tiles_q6_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
-                    int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_sc) {
-    (void)x_qh;
-
-    *x_ql = tile_x_ql;
-    *x_dm = tile_x_dm;
-    *x_sc = tile_x_sc;
-}
-
-template <int mmq_y, int nwarps, bool need_check>
-static __dpct_inline__ void
-load_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql,
-                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
-                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
-                const int &k, const int &blocks_per_row) {
-    (void)x_qh;
-
-    GGML_SYCL_ASSUME(i_offset >= 0);
-    GGML_SYCL_ASSUME(i_offset <  nwarps);
-    GGML_SYCL_ASSUME(k >= 0);
-    GGML_SYCL_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
-    const int kqsx = k % QI6_K; // == k if QK_K == 256
-
-    const block_q6_K * bx0 = (const block_q6_K *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
-        const int ky = QR6_K*kqsx;
-
-        const int ql = get_int_from_uint8(bxi->ql, kqsx);
-        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
-        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
-
-        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
-        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
-        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;
-
-        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
-        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);
-
-        x_ql[i * (2 * WARP_SIZE + 1) + kq0] =
-            dpct::vectorized_binary<sycl::char4>(ql0 | qh0, 0x20202020,
-                                                 dpct::sub_sat());
-        x_ql[i * (2 * WARP_SIZE + 1) + kq1] =
-            dpct::vectorized_binary<sycl::char4>(ql1 | qh1, 0x20202020,
-                                                 dpct::sub_sat());
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
-    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
-    float * x_dmf = (float *) x_dm;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
-        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
-        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
-
-        if (need_check) {
-            i = sycl::min(i, i_max);
-        }
-
-        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;
-
-        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
-    }
-}
-
-static __dpct_inline__ float vec_dot_q6_K_q8_1_mul_mat(
-    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
-    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
-    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
-    const int &i, const int &j, const int &k) {
-    (void)x_qh;
-
-    const float * x_dmf = (const float *) x_dm;
-    const float * y_df  = (const float *) y_ds;
-
-    const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
-
-    const int index_x = i * (QR6_K*WARP_SIZE + 1) +  QR6_K*k;
-    const int index_y = j * WARP_SIZE             + (QR6_K*k) % WARP_SIZE;
-    return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
-}
-
-
-static __dpct_inline__ float
-vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq,
-                     const block_q8_1 *__restrict__ bq8_1, const int &iqs,
-                     const uint64_t *iq2xxs_grid, const uint8_t *ksigns_iq2xs,
-                     const uint8_t *kmask_iq2xs) {
-    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
-
-#if QR2_XXS == 8
-    const int ib32 = iqs;
-    const uint16_t * q2 = bq2->qs + 4*ib32;
-    const uint8_t  * aux8 = (const uint8_t *)q2;
-    const int8_t   * q8 = bq8_1[ib32].qs;
-    uint32_t aux32 = q2[2] | (q2[3] << 16);
-    int sumi = 0;
-    for (int l = 0; l < 4; ++l) {
-        const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
-        const uint8_t  signs = ksigns_iq2xs[aux32 & 127];
-        for (int j = 0; j < 8; ++j) {
-            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-        }
-        q8 += 8;
-        aux32 >>= 7;
-    }
-    const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.25f;
-    return d * sumi;
-#else
-    // iqs is 0...15
-    const int ib32 = iqs/2;
-    const int il = iqs%2;
-    const uint16_t * q2 = bq2->qs + 4*ib32;
-    const uint8_t  * aux8 = (const uint8_t *)q2;
-    const uint8_t  * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
-    const uint8_t  * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
-    const uint32_t aux32 = q2[2] | (q2[3] << 16);
-    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * bq8_1[ib32].ds[0] * 0.25f;
-    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
-    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
-    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
-    int sumi1 = 0, sumi2 = 0;
-    for (int j = 0; j < 8; ++j) {
-        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
-        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
-    }
-    return d * (sumi1 + sumi2);
-#endif
-}
-
-static __dpct_inline__ float
-vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq,
-                    const block_q8_1 *__restrict__ bq8_1, const int &iqs,
-                    const uint64_t *iq2xs_grid, const uint64_t *ksigns64) {
-#if DPCT_COMPATIBILITY_TEMP >=                                                 \
-    MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
-
-    const int ib32 = iqs;
-    const uint16_t * q2 = bq2->qs + 4*ib32;
-    const int8_t   * q8 = bq8_1[ib32].qs;
-    const uint8_t ls1 = bq2->scales[ib32] & 0xf;
-    const uint8_t ls2 = bq2->scales[ib32] >>  4;
-    int sumi1 = 0;
-    for (int l = 0; l < 2; ++l) {
-        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
-        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
-        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
-            grid[0] ^ signs[0], signs[0], std::minus<>());
-        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
-            grid[1] ^ signs[1], signs[1], std::minus<>());
-        sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
-        sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
-        q8 += 8;
-    }
-    int sumi2 = 0;
-    for (int l = 2; l < 4; ++l) {
-        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
-        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
-        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
-            grid[0] ^ signs[0], signs[0], std::minus<>());
-        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
-            grid[1] ^ signs[1], signs[1], std::minus<>());
-        sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
-        sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
-        q8 += 8;
-    }
-    const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
-    return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
-#else
-    assert(false);
-    return 0.f;
-#endif
-}
-
-static __dpct_inline__ float
-vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
-                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-    const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
-
-    const int ib32 = iqs;
-    const int8_t  * q8 = bq8_1[ib32].qs;
-    const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;
-    const uint8_t ls1 = bq2->scales[ib32] & 0xf;
-    const uint8_t ls2 = bq2->scales[ib32] >>  4;
-    int sumi1 = 0;
-    for (int l = 0; l < 2; ++l) {
-        const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
-        const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
-            ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
-            std::equal_to<>());
-        const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
-            ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
-            std::equal_to<>());
-        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
-            grid[0] ^ signs0, signs0, std::minus<>());
-        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
-            grid[1] ^ signs1, signs1, std::minus<>());
-        sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
-        sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
-        q8 += 8;
-    }
-    int sumi2 = 0;
-    for (int l = 2; l < 4; ++l) {
-        const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
-        const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
-            ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
-            std::equal_to<>());
-        const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
-            ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
-            std::equal_to<>());
-        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
-            grid[0] ^ signs0, signs0, std::minus<>());
-        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
-            grid[1] ^ signs1, signs1, std::minus<>());
-        sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
-        sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
-        q8 += 8;
-    }
-    const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
-    return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
-}
-
-static __dpct_inline__ float
-vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
-                     const block_q8_1 *__restrict__ bq8_1, const int &iqs,
-                     const uint32_t *iq3xxs_grid, const uint64_t *ksigns64) {
-#if DPCT_COMPATIBILITY_TEMP >=                                                 \
-    MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
-
-    const int ib32 = iqs;
-    const uint8_t  * q3 = bq2->qs + 8*ib32;
-    const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32;
-    const int8_t   * q8 = bq8_1[ib32].qs;
-    uint32_t aux32 = gas[0] | (gas[1] << 16);
-    int sumi = 0;
-    for (int l = 0; l < 4; ++l) {
-        const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0];
-        const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1];
-        const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127));
-        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
-            grid1[0] ^ signs[0], signs[0], std::minus<>());
-        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
-            grid2[0] ^ signs[1], signs[1], std::minus<>());
-        sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
-        sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
-        q8 += 8;
-        aux32 >>= 7;
-    }
-    const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.5f;
-    return d * sumi;
-#else
-    assert(false);
-    return 0.f;
-#endif
-}
-
-static __dpct_inline__ float
-vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
-                   const block_q8_1 *__restrict__ bq8_1, const int &iqs,
-                   const uint32_t *iq3s_grid) {
-    const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
-
-    const int ib32 = iqs;
-    const uint8_t  * qs = bq2->qs + 8*ib32;
-    const int8_t   * q8 = bq8_1[ib32].qs;
-    int sumi = 0;
-    for (int l = 0; l < 4; ++l) {
-        const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
-        const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
-        uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
-            ((bq2->signs[4 * ib32 + l] & 0xf) * 0x01010101) & 0x08040201,
-            0x08040201, std::equal_to<>());
-        uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
-            ((bq2->signs[4 * ib32 + l] >> 4) * 0x01010101) & 0x08040201,
-            0x08040201, std::equal_to<>());
-        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
-            grid1[0] ^ signs0, signs0, std::minus<>());
-        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
-            grid2[0] ^ signs1, signs1, std::minus<>());
-        sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
-        sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
-        q8 += 8;
-    }
-    const float d =
-        (float)bq2->d *
-        (1 + 2 * ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) *
-        bq8_1[ib32].ds[0];
-    return d * sumi;
-}
-
-static __dpct_inline__ float
-vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
-                   const block_q8_1 *__restrict__ bq8_1, const int &iqs,
-                   const uint32_t *iq1s_grid_gpu) {
-    const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
-
-    const int ib32 = iqs;
-    int sumi = 0;
-    const int * q8 = (const int *)bq8_1[ib32].qs;
-    for (int l = 0; l < 4; ++l) {
-        const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
-        int grid0 = grid[0] & 0x0f0f0f0f;
-        int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
-        sumi = dpct::dp4a(q8[2 * l + 1], grid1,
-                          dpct::dp4a(q8[2 * l + 0], grid0, sumi));
-    }
-
-    const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA;
-    const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);
-    const float d = d1q * bq8_1[ib32].ds[0];
-    const float m = d1q * bq8_1[ib32].ds[1];
-    return d * sumi + m * delta;
-}
-
-static __dpct_inline__ float
-vec_dot_iq1_m_q8_1(const void *__restrict__ vbq,
-                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-    const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
-
-    const int ib32 = iqs;
-    int   sumi[2] = {0, 0};
-    float sumf[2] = {0.f, 0.f};
-
-    const int * q8 = (const int *)bq8_1[ib32].qs;
-    for (int l = 0; l < 4; ++l) {
-        const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8)));
-        int grid0 = grid[0] & 0x0f0f0f0f;
-        int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
-        sumi[l / 2] = dpct::dp4a(q8[2 * l + 1], grid1,
-                                 dpct::dp4a(q8[2 * l + 0], grid0, sumi[l / 2]));
-        const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA;
-        const int sumy = dpct::dp4a(q8[2 * l + 1], 0x01010101,
-                                    dpct::dp4a(q8[2 * l + 0], 0x01010101, 0));
-        sumf[l/2] += delta*sumy;
-    }
-
-    iq1m_scale_t scale;
-    const uint16_t * sc = (const uint16_t *)bq1->scales;
-    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
-    const float d = (float)scale.f16 * bq8_1[ib32].ds[0];
-    return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
-}
-
-static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4,
-                                                  const uint8_t *values,
-                                                  int &val1, int &val2) {
-
-    uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
-    aux32 = q4 & 0x0f0f0f0f;
-    uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
-    uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
-    val1 = v1 | (v2 << 16);
-    aux32 = (q4 >> 4) & 0x0f0f0f0f;
-    v1 = values[q8[0]] | (values[q8[1]] << 8);
-    v2 = values[q8[2]] | (values[q8[3]] << 8);
-    val2 = v1 | (v2 << 16);
-}
-
-
-static __dpct_inline__ float
-vec_dot_iq4_nl_q8_1(const void *__restrict__ vbq,
-                    const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
-    const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
-
-    const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
-    const int32_t  * q8 = (const int32_t  *)bq8_1->qs + iqs;
-
-    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
-
-    int v1, v2;
-    int sumi1 = 0, sumi2 = 0;
-    for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
-        const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
-        get_int_from_table_16(aux, values, v1, v2);
-        sumi1 = dpct::dp4a(v1, q8[l + 0], sumi1);
-        sumi2 = dpct::dp4a(v2, q8[l + 4], sumi2);
-    }
-
-    const float d = (float)bq->d * bq8_1->ds[0];
-    return d * (sumi1 + sumi2);
-}
-
-
-static __dpct_inline__ float
-vec_dot_iq4_xs_q8_1(const void *__restrict__ vbq,
-                    const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
-    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
-    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
-
-    // iqs is 0...7
-    const int ib32 = iqs;
-    const int32_t  * q8 = (const int *)bq8_1[ib32].qs;
-    const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
-    const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
-    const float d = (float)bq4->d * (ls - 32) * bq8_1[ib32].ds[0];
-    int v1, v2;
-    int sumi1 = 0, sumi2 = 0;
-    for (int j = 0; j < 4; ++j) {
-        get_int_from_table_16(q4[j], values, v1, v2);
-        sumi1 = dpct::dp4a(v1, q8[j + 0], sumi1);
-        sumi2 = dpct::dp4a(v2, q8[j + 4], sumi2);
-    }
-    return d * (sumi1 + sumi2);
-}
-
-template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
-          int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
-          vec_dot_q_mul_mat_sycl_t vec_dot>
-/*
-DPCT1110:8: The total declared local variable size in device function mul_mat_q
-exceeds 128 bytes and may cause high register pressure. Consult with your
-hardware vendor to find the total register size available and adjust the code,
-or use smaller sub-group size to avoid high register pressure.
-*/
-static __dpct_inline__ void
-mul_mat_q(const void *__restrict__ vx, const void *__restrict__ vy,
-          float *__restrict__ dst, const int ncols_x, const int nrows_x,
-          const int ncols_y, const int nrows_y, const int nrows_dst,
-          int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_qh,
-          int *tile_x_sc, const sycl::nd_item<3> &item_ct1, int *tile_y_qs,
-          sycl::half2 *tile_y_ds) {
-
-    const block_q_t  * x = (const block_q_t  *) vx;
-    const block_q8_1 * y = (const block_q8_1 *) vy;
-
-    const int blocks_per_row_x = ncols_x / qk;
-    const int blocks_per_col_y = nrows_y / QK8_1;
-    const int blocks_per_warp = WARP_SIZE / qi;
-
-    const int & ncols_dst = ncols_y;
-
-    const int row_dst_0 = item_ct1.get_group(2) * mmq_y;
-    const int & row_x_0 = row_dst_0;
-
-    const int col_dst_0 = item_ct1.get_group(1) * mmq_x;
-    const int & col_y_0 = col_dst_0;
-
-    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
-
-    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
-
-        load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm,
-                   tile_x_qh, tile_x_sc, item_ct1.get_local_id(1),
-                   nrows_x - row_x_0 - 1, item_ct1.get_local_id(2),
-                   blocks_per_row_x);
-
-#pragma unroll
-        for (int ir = 0; ir < qr; ++ir) {
-            const int kqs = ir * WARP_SIZE + item_ct1.get_local_id(2);
-            const int kbxd = kqs / QI8_1;
-
-#pragma unroll
-            for (int i = 0; i < mmq_x; i += nwarps) {
-                const int col_y_eff = dpct::min(
-                    (unsigned int)(col_y_0 + item_ct1.get_local_id(1) + i),
-                    ncols_y - 1); // to prevent out-of-bounds memory accesses
-
-                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
-
-                const int index_y = (item_ct1.get_local_id(1) + i) * WARP_SIZE +
-                                    kqs % WARP_SIZE;
-                tile_y_qs[index_y] = get_int_from_int8_aligned(
-                    by0->qs, item_ct1.get_local_id(2) % QI8_1);
-            }
-
-#pragma unroll
-            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
-                const int ids =
-                    (ids0 + item_ct1.get_local_id(1) * QI8_1 +
-                     item_ct1.get_local_id(2) / (WARP_SIZE / QI8_1)) %
-                    mmq_x;
-                const int kby = item_ct1.get_local_id(2) % (WARP_SIZE / QI8_1);
-                const int col_y_eff = sycl::min(col_y_0 + ids, ncols_y - 1);
-
-                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
-                const sycl::half2 *dsi_src =
-                    &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) +
-                       ir * (WARP_SIZE / QI8_1) + kby]
-                         .ds;
-                sycl::half2 *dsi_dst =
-                    &tile_y_ds[ids * (WARP_SIZE / QI8_1) + kby];
-                if (need_sum) {
-                    *dsi_dst = *dsi_src;
-                } else {
-                    float * dfi_dst = (float *) dsi_dst;
-                    *dfi_dst = (*dsi_src)[0];
-                }
-            }
-
-            /*
-            DPCT1118:9: SYCL group functions and algorithms must be encountered
-            in converged control flow. You may need to adjust the code.
-            */
-            /*
-            DPCT1065:56: Consider replacing sycl::nd_item::barrier() with
-            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
-            better performance if there is no access to global memory.
-            */
-            item_ct1.barrier();
-
-// #pragma unroll // unrolling this loop causes too much register pressure
-            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
-#pragma unroll
-                for (int j = 0; j < mmq_x; j += nwarps) {
-#pragma unroll
-                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
-                        sum[i / WARP_SIZE][j / nwarps] += vec_dot(
-                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
-                            tile_y_qs, tile_y_ds, item_ct1.get_local_id(2) + i,
-                            item_ct1.get_local_id(1) + j, k);
-                    }
-                }
-            }
-
-            /*
-            DPCT1118:10: SYCL group functions and algorithms must be encountered
-            in converged control flow. You may need to adjust the code.
-            */
-            /*
-            DPCT1065:57: Consider replacing sycl::nd_item::barrier() with
-            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
-            better performance if there is no access to global memory.
-            */
-            item_ct1.barrier();
-        }
-    }
-
-#pragma unroll
-    for (int j = 0; j < mmq_x; j += nwarps) {
-        const int col_dst = col_dst_0 + j + item_ct1.get_local_id(1);
-
-        if (col_dst >= ncols_dst) {
-            return;
-        }
-
-#pragma unroll
-        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
-            const int row_dst = row_dst_0 + item_ct1.get_local_id(2) + i;
-
-            if (row_dst >= nrows_dst) {
-                continue;
-            }
-
-            dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
-        }
-    }
-}
-
-#define  MMQ_X_Q4_0_RDNA2  64
-#define  MMQ_Y_Q4_0_RDNA2  128
-#define NWARPS_Q4_0_RDNA2  8
-#define  MMQ_X_Q4_0_RDNA1  64
-#define  MMQ_Y_Q4_0_RDNA1  64
-#define NWARPS_Q4_0_RDNA1  8
-#if defined(SYCL_USE_XMX)
-#define  MMQ_X_Q4_0_AMPERE 4
-#define  MMQ_Y_Q4_0_AMPERE 32
-#define NWARPS_Q4_0_AMPERE 4
-#else
-#define  MMQ_X_Q4_0_AMPERE 64
-#define  MMQ_Y_Q4_0_AMPERE 128
-#define NWARPS_Q4_0_AMPERE 4
-#endif
-#define  MMQ_X_Q4_0_PASCAL 64
-#define  MMQ_Y_Q4_0_PASCAL 64
-#define NWARPS_Q4_0_PASCAL 8
-
-template <bool need_check> static void
-    mul_mat_q4_0(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
-    const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_0, float *tile_x_d_q4_0,
-    int *tile_y_qs, sycl::half2 *tile_y_ds) {
-    int   * tile_x_ql = nullptr;
-    sycl::half2 *tile_x_dm = nullptr;
-    int   * tile_x_qh = nullptr;
-    int   * tile_x_sc = nullptr;
-
-//sycl_todo: change according to hardware
-
-    const int mmq_x  =  MMQ_X_Q4_0_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q4_0_AMPERE;
-    const int nwarps = NWARPS_Q4_0_AMPERE;
-    allocate_tiles_q4_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
-                               tile_x_qs_q4_0, tile_x_d_q4_0);
-    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
-              load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ,
-              vec_dot_q4_0_q8_1_mul_mat>(
-        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
-        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
-}
-
-#define  MMQ_X_Q4_1_RDNA2  64
-#define  MMQ_Y_Q4_1_RDNA2  128
-#define NWARPS_Q4_1_RDNA2  8
-#define  MMQ_X_Q4_1_RDNA1  64
-#define  MMQ_Y_Q4_1_RDNA1  64
-#define NWARPS_Q4_1_RDNA1  8
-#if defined(SYCL_USE_XMX)
-#define  MMQ_X_Q4_1_AMPERE 4
-#define  MMQ_Y_Q4_1_AMPERE 32
-#define NWARPS_Q4_1_AMPERE 4
-#else
-#define  MMQ_X_Q4_1_AMPERE 64
-#define  MMQ_Y_Q4_1_AMPERE 128
-#define NWARPS_Q4_1_AMPERE 4
-#endif
-#define  MMQ_X_Q4_1_PASCAL 64
-#define  MMQ_Y_Q4_1_PASCAL 64
-#define NWARPS_Q4_1_PASCAL 8
-
-template <bool need_check> static void
-    mul_mat_q4_1(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
-    const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_1,
-    sycl::half2 *tile_x_dm_q4_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
-    int   * tile_x_ql = nullptr;
-    sycl::half2 *tile_x_dm = nullptr;
-    int   * tile_x_qh = nullptr;
-    int   * tile_x_sc = nullptr;
-
-//sycl_todo: change according to hardware
-    const int mmq_x  =  MMQ_X_Q4_1_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q4_1_AMPERE;
-    const int nwarps = NWARPS_Q4_1_AMPERE;
-    allocate_tiles_q4_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
-                               tile_x_qs_q4_1, tile_x_dm_q4_1);
-    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
-              load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ,
-              vec_dot_q4_1_q8_1_mul_mat>(
-        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
-        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
-}
-
-#define  MMQ_X_Q5_0_RDNA2  64
-#define  MMQ_Y_Q5_0_RDNA2  128
-#define NWARPS_Q5_0_RDNA2  8
-#define  MMQ_X_Q5_0_RDNA1  64
-#define  MMQ_Y_Q5_0_RDNA1  64
-#define NWARPS_Q5_0_RDNA1  8
-#if defined(SYCL_USE_XMX)
-#define  MMQ_X_Q5_0_AMPERE 4
-#define  MMQ_Y_Q5_0_AMPERE 32
-#define NWARPS_Q5_0_AMPERE 4
-#else
-#define  MMQ_X_Q5_0_AMPERE 128
-#define  MMQ_Y_Q5_0_AMPERE 64
-#define NWARPS_Q5_0_AMPERE 4
-#endif
-#define  MMQ_X_Q5_0_PASCAL 64
-#define  MMQ_Y_Q5_0_PASCAL 64
-#define NWARPS_Q5_0_PASCAL 8
-
-template <bool need_check> static void
-    mul_mat_q5_0(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
-    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_0, float *tile_x_d_q5_0,
-    int *tile_y_qs, sycl::half2 *tile_y_ds) {
-    int   * tile_x_ql = nullptr;
-    sycl::half2 *tile_x_dm = nullptr;
-    int   * tile_x_qh = nullptr;
-    int   * tile_x_sc = nullptr;
-
-//sycl_todo: change according to hardware
-    const int mmq_x  =  MMQ_X_Q5_0_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q5_0_AMPERE;
-    const int nwarps = NWARPS_Q5_0_AMPERE;
-    allocate_tiles_q5_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
-                               tile_x_ql_q5_0, tile_x_d_q5_0);
-    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
-              load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ,
-              vec_dot_q5_0_q8_1_mul_mat>(
-        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
-        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
-}
-
-#define  MMQ_X_Q5_1_RDNA2  64
-#define  MMQ_Y_Q5_1_RDNA2  128
-#define NWARPS_Q5_1_RDNA2  8
-#define  MMQ_X_Q5_1_RDNA1  64
-#define  MMQ_Y_Q5_1_RDNA1  64
-#define NWARPS_Q5_1_RDNA1  8
-#if defined(SYCL_USE_XMX)
-#define  MMQ_X_Q5_1_AMPERE 4
-#define  MMQ_Y_Q5_1_AMPERE 32
-#define NWARPS_Q5_1_AMPERE 4
-#else
-#define  MMQ_X_Q5_1_AMPERE 128
-#define  MMQ_Y_Q5_1_AMPERE 64
-#define NWARPS_Q5_1_AMPERE 4
-#endif
-#define  MMQ_X_Q5_1_PASCAL 64
-#define  MMQ_Y_Q5_1_PASCAL 64
-#define NWARPS_Q5_1_PASCAL 8
-
-template <bool need_check> static void
-mul_mat_q5_1(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
-    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_1,
-    sycl::half2 *tile_x_dm_q5_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
-    int   * tile_x_ql = nullptr;
-    sycl::half2 *tile_x_dm = nullptr;
-    int   * tile_x_qh = nullptr;
-    int   * tile_x_sc = nullptr;
-
-//sycl_todo: change according to hardware
-    const int mmq_x  =  MMQ_X_Q5_1_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q5_1_AMPERE;
-    const int nwarps = NWARPS_Q5_1_AMPERE;
-    allocate_tiles_q5_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
-                               tile_x_ql_q5_1, tile_x_dm_q5_1);
-    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
-              load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ,
-              vec_dot_q5_1_q8_1_mul_mat>(
-        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
-        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
-}
-
-#define  MMQ_X_Q8_0_RDNA2  64
-#define  MMQ_Y_Q8_0_RDNA2  128
-#define NWARPS_Q8_0_RDNA2  8
-#define  MMQ_X_Q8_0_RDNA1  64
-#define  MMQ_Y_Q8_0_RDNA1  64
-#define NWARPS_Q8_0_RDNA1  8
-#if defined(SYCL_USE_XMX)
-#define  MMQ_X_Q8_0_AMPERE 4
-#define  MMQ_Y_Q8_0_AMPERE 32
-#define NWARPS_Q8_0_AMPERE 4
-#else
-#define  MMQ_X_Q8_0_AMPERE 128
-#define  MMQ_Y_Q8_0_AMPERE 64
-#define NWARPS_Q8_0_AMPERE 4
-#endif
-#define  MMQ_X_Q8_0_PASCAL 64
-#define  MMQ_Y_Q8_0_PASCAL 64
-#define NWARPS_Q8_0_PASCAL 8
-
-template <bool need_check> static void
-    mul_mat_q8_0(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
-    const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q8_0, float *tile_x_d_q8_0,
-    int *tile_y_qs, sycl::half2 *tile_y_ds) {
-    int   * tile_x_ql = nullptr;
-    sycl::half2 *tile_x_dm = nullptr;
-    int   * tile_x_qh = nullptr;
-    int   * tile_x_sc = nullptr;
-
-//sycl_todo: change according to hardware
-    const int mmq_x  =  MMQ_X_Q8_0_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q8_0_AMPERE;
-    const int nwarps = NWARPS_Q8_0_AMPERE;
-    allocate_tiles_q8_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
-                               tile_x_qs_q8_0, tile_x_d_q8_0);
-    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
-              load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ,
-              vec_dot_q8_0_q8_1_mul_mat>(
-        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
-        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
-}
-
-#define  MMQ_X_Q2_K_RDNA2  64
-#define  MMQ_Y_Q2_K_RDNA2  128
-#define NWARPS_Q2_K_RDNA2  8
-#define  MMQ_X_Q2_K_RDNA1  128
-#define  MMQ_Y_Q2_K_RDNA1  32
-#define NWARPS_Q2_K_RDNA1  8
-#if defined(SYCL_USE_XMX)
-#define  MMQ_X_Q2_K_AMPERE 4
-#define  MMQ_Y_Q2_K_AMPERE 32
-#define NWARPS_Q2_K_AMPERE 4
-#else
-#define  MMQ_X_Q2_K_AMPERE 64
-#define  MMQ_Y_Q2_K_AMPERE 128
-#define NWARPS_Q2_K_AMPERE 4
-#endif
-#define  MMQ_X_Q2_K_PASCAL 64
-#define  MMQ_Y_Q2_K_PASCAL 64
-#define NWARPS_Q2_K_PASCAL 8
-
-template <bool need_check> static void
-mul_mat_q2_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
-    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q2_K,
-    sycl::half2 *tile_x_dm_q2_K, int *tile_x_sc_q2_K, int *tile_y_qs,
-    sycl::half2 *tile_y_ds) {
-    int   * tile_x_ql = nullptr;
-    sycl::half2 *tile_x_dm = nullptr;
-    int   * tile_x_qh = nullptr;
-    int   * tile_x_sc = nullptr;
-
-//sycl_todo: change according to hardware
-    const int mmq_x  =  MMQ_X_Q2_K_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q2_K_AMPERE;
-    const int nwarps = NWARPS_Q2_K_AMPERE;
-    allocate_tiles_q2_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
-                               tile_x_ql_q2_K, tile_x_dm_q2_K, tile_x_sc_q2_K);
-    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
-              load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ,
-              vec_dot_q2_K_q8_1_mul_mat>(
-        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
-        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
-}
-
-#define  MMQ_X_Q3_K_RDNA2  128
-#define  MMQ_Y_Q3_K_RDNA2  64
-#define NWARPS_Q3_K_RDNA2  8
-#define  MMQ_X_Q3_K_RDNA1  32
-#define  MMQ_Y_Q3_K_RDNA1  128
-#define NWARPS_Q3_K_RDNA1  8
-#if defined(SYCL_USE_XMX)
-#define  MMQ_X_Q3_K_AMPERE 4
-#define  MMQ_Y_Q3_K_AMPERE 32
-#define NWARPS_Q3_K_AMPERE 4
-#else
-#define  MMQ_X_Q3_K_AMPERE 128
-#define  MMQ_Y_Q3_K_AMPERE 128
-#define NWARPS_Q3_K_AMPERE 4
-#endif
-#define  MMQ_X_Q3_K_PASCAL 64
-#define  MMQ_Y_Q3_K_PASCAL 64
-#define NWARPS_Q3_K_PASCAL 8
-
-template <bool need_check> static void
-mul_mat_q3_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
-    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q3_K,
-    sycl::half2 *tile_x_dm_q3_K, int *tile_x_qh_q3_K, int *tile_x_sc_q3_K,
-    int *tile_y_qs, sycl::half2 *tile_y_ds) {
-    int   * tile_x_ql = nullptr;
-    sycl::half2 *tile_x_dm = nullptr;
-    int   * tile_x_qh = nullptr;
-    int   * tile_x_sc = nullptr;
-
-//sycl_todo: change according to hardware
-    const int mmq_x  =  MMQ_X_Q3_K_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q3_K_AMPERE;
-    const int nwarps = NWARPS_Q3_K_AMPERE;
-    allocate_tiles_q3_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
-                               tile_x_ql_q3_K, tile_x_dm_q3_K, tile_x_qh_q3_K,
-                               tile_x_sc_q3_K);
-    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
-              load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ,
-              vec_dot_q3_K_q8_1_mul_mat>(
-        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
-        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
-}
-
-#define  MMQ_X_Q4_K_RDNA2  64
-#define  MMQ_Y_Q4_K_RDNA2  128
-#define NWARPS_Q4_K_RDNA2  8
-#define  MMQ_X_Q4_K_RDNA1  32
-#define  MMQ_Y_Q4_K_RDNA1  64
-#define NWARPS_Q4_K_RDNA1  8
-#if defined(SYCL_USE_XMX)
-#define  MMQ_X_Q4_K_AMPERE 4
-#define  MMQ_Y_Q4_K_AMPERE 32
-#define NWARPS_Q4_K_AMPERE 4
-#else
-#define  MMQ_X_Q4_K_AMPERE 64
-#define  MMQ_Y_Q4_K_AMPERE 128
-#define NWARPS_Q4_K_AMPERE 4
-#endif
-#define  MMQ_X_Q4_K_PASCAL 64
-#define  MMQ_Y_Q4_K_PASCAL 64
-#define NWARPS_Q4_K_PASCAL 8
-
-template <bool need_check> static void
-    mul_mat_q4_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
-    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q4_K,
-    sycl::half2 *tile_x_dm_q4_K, int *tile_x_sc_q4_K, int *tile_y_qs,
-    sycl::half2 *tile_y_ds) {
-    int   * tile_x_ql = nullptr;
-    sycl::half2 *tile_x_dm = nullptr;
-    int   * tile_x_qh = nullptr;
-    int   * tile_x_sc = nullptr;
-
-//sycl_todo: change according to hardware
-    const int mmq_x  =  MMQ_X_Q4_K_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q4_K_AMPERE;
-    const int nwarps = NWARPS_Q4_K_AMPERE;
-    allocate_tiles_q4_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
-                               tile_x_ql_q4_K, tile_x_dm_q4_K, tile_x_sc_q4_K);
-    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
-              load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ,
-              vec_dot_q4_K_q8_1_mul_mat>(
-        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
-        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
-}
-
-#define  MMQ_X_Q5_K_RDNA2  64
-#define  MMQ_Y_Q5_K_RDNA2  128
-#define NWARPS_Q5_K_RDNA2  8
-#define  MMQ_X_Q5_K_RDNA1  32
-#define  MMQ_Y_Q5_K_RDNA1  64
-#define NWARPS_Q5_K_RDNA1  8
-#if defined(SYCL_USE_XMX)
-#define  MMQ_X_Q5_K_AMPERE 4
-#define  MMQ_Y_Q5_K_AMPERE 32
-#define NWARPS_Q5_K_AMPERE 4
-#else
-#define  MMQ_X_Q5_K_AMPERE 64
-#define  MMQ_Y_Q5_K_AMPERE 128
-#define NWARPS_Q5_K_AMPERE 4
-#endif
-#define  MMQ_X_Q5_K_PASCAL 64
-#define  MMQ_Y_Q5_K_PASCAL 64
-#define NWARPS_Q5_K_PASCAL 8
-
-template <bool need_check> static void
-mul_mat_q5_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
-    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_K,
-    sycl::half2 *tile_x_dm_q5_K, int *tile_x_sc_q5_K, int *tile_y_qs,
-    sycl::half2 *tile_y_ds) {
-    int   * tile_x_ql = nullptr;
-    sycl::half2 *tile_x_dm = nullptr;
-    int   * tile_x_qh = nullptr;
-    int   * tile_x_sc = nullptr;
-
-//sycl_todo: change according to hardware
-    const int mmq_x  =  MMQ_X_Q5_K_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q5_K_AMPERE;
-    const int nwarps = NWARPS_Q5_K_AMPERE;
-    allocate_tiles_q5_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
-                               tile_x_ql_q5_K, tile_x_dm_q5_K, tile_x_sc_q5_K);
-    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
-              load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ,
-              vec_dot_q5_K_q8_1_mul_mat>(
-        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
-        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
-}
-
-#define  MMQ_X_Q6_K_RDNA2  64
-#define  MMQ_Y_Q6_K_RDNA2  128
-#define NWARPS_Q6_K_RDNA2  8
-#define  MMQ_X_Q6_K_RDNA1  32
-#define  MMQ_Y_Q6_K_RDNA1  64
-#define NWARPS_Q6_K_RDNA1  8
-#if defined(SYCL_USE_XMX)
-#define  MMQ_X_Q6_K_AMPERE 4
-#define  MMQ_Y_Q6_K_AMPERE 32
-#define NWARPS_Q6_K_AMPERE 4
-#else
-#define  MMQ_X_Q6_K_AMPERE 64
-#define  MMQ_Y_Q6_K_AMPERE 64
-#define NWARPS_Q6_K_AMPERE 4
-#endif
-#define  MMQ_X_Q6_K_PASCAL 64
-#define  MMQ_Y_Q6_K_PASCAL 64
-#define NWARPS_Q6_K_PASCAL 8
-
-template <bool need_check> static void
-    mul_mat_q6_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
-    const sycl::nd_item<3> &item_ct1, int *tile_x_ql, sycl::half2 *tile_x_dm,
-    int *tile_x_sc, int *tile_y_qs, sycl::half2 *tile_y_ds) {
-    // int   * tile_x_ql = nullptr;
-    // sycl::half2 *tile_x_dm = nullptr;
-    int   * tile_x_qh = nullptr;
-    // int   * tile_x_sc = nullptr;
-
-//sycl_todo: change according to hardware
-    const int mmq_x  =  MMQ_X_Q6_K_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q6_K_AMPERE;
-    const int nwarps = NWARPS_Q6_K_AMPERE;
-    allocate_tiles_q6_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
-                               tile_x_ql, tile_x_dm, tile_x_sc);
-    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
-              load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ,
-              vec_dot_q6_K_q8_1_mul_mat>(
-        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
-        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
-}
-
-template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
-static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-
-    if (row >= nrows) {
-        return;
-    }
-
-    const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
-    const int qi_vdr = (qi / vdr); // N_threads processing 1 qk block
-
-    // partial sum for each thread
-    float tmp = 0.0f;
-
-    const block_q_t  * x = (const block_q_t  *) vx;
-    const block_q8_1 * y = (const block_q8_1 *) vy;
-
-    for (int i = item_ct1.get_local_id(2) / qi_vdr; i < blocks_per_row;
-         i += blocks_per_warp) {
-      const int ibx = row * blocks_per_row + i; // x block index
-
-      const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
-
-      const int iqs =
-          vdr *
-          (item_ct1.get_local_id(2) -
-           i * qi_vdr); // x block quant index when casting the quants to int
-
-      tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
-    }
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[row] = tmp;
-    }
-}
-
-template <int qk, int qi, typename block_q_t, int vdr>
-static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
-                                       const void *__restrict__ vy,
-                                       float *__restrict__ dst, const int ncols,
-                                       const int nrows,
-                                       const sycl::nd_item<3> &item_ct1) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-
-    if (row >= nrows) {
-        return;
-    }
-
-    const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
-// partial sum for each thread
-    float tmp = 0.0f;
-
-    const block_q_t  * x = (const block_q_t  *) vx;
-    const block_q8_1 * y = (const block_q8_1 *) vy;
-
-    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
-         i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i; // x block index
-
-        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
-
-        const int iqs =
-            vdr *
-            (item_ct1.get_local_id(2) %
-             (qi / vdr)); // x block quant index when casting the quants to int
-
-        tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
-    }
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[row] = tmp;
-    }
-}
-
-template <int qk, int qi, typename block_q_t, int vdr>
-static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
-                                      const void *__restrict__ vy,
-                                      float *__restrict__ dst, const int ncols,
-                                      const int nrows,
-                                      const sycl::nd_item<3> &item_ct1) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-
-    if (row >= nrows) {
-        return;
-    }
-
-    const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
-// partial sum for each thread
-    float tmp = 0.0f;
-
-    const block_q_t  * x = (const block_q_t  *) vx;
-    const block_q8_1 * y = (const block_q8_1 *) vy;
-
-    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
-         i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i; // x block index
-
-        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
-
-        const int iqs =
-            vdr *
-            (item_ct1.get_local_id(2) %
-             (qi / vdr)); // x block quant index when casting the quants to int
-
-        tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid, ksigns64);
-    }
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[row] = tmp;
-    }
-}
-
-template <int qk, int qi, typename block_q_t, int vdr>
-static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
-                                     const void *__restrict__ vy,
-                                     float *__restrict__ dst, const int ncols,
-                                     const int nrows,
-                                     const sycl::nd_item<3> &item_ct1) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-
-    if (row >= nrows) {
-        return;
-    }
-
-    const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
-// partial sum for each thread
-    float tmp = 0.0f;
-
-    const block_q_t  * x = (const block_q_t  *) vx;
-    const block_q8_1 * y = (const block_q8_1 *) vy;
-
-    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
-         i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i; // x block index
-
-        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
-
-        const int iqs =
-            vdr *
-            (item_ct1.get_local_id(2) %
-             (qi / vdr)); // x block quant index when casting the quants to int
-
-        tmp += vec_dot_iq2_s_q8_1(&x[ibx], &y[iby], iqs);
-    }
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[row] = tmp;
-    }
-}
-
-template <int qk, int qi, typename block_q_t, int vdr>
-static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
-                                       const void *__restrict__ vy,
-                                       float *__restrict__ dst, const int ncols,
-                                       const int nrows,
-                                       const sycl::nd_item<3> &item_ct1) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-
-    if (row >= nrows) {
-        return;
-    }
-
-    const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
-// partial sum for each thread
-    float tmp = 0.0f;
-
-    const block_q_t  * x = (const block_q_t  *) vx;
-    const block_q8_1 * y = (const block_q8_1 *) vy;
-
-    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
-         i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i; // x block index
-
-        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
-
-        const int iqs =
-            vdr *
-            (item_ct1.get_local_id(2) %
-             (qi / vdr)); // x block quant index when casting the quants to int
-
-        tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid, ksigns64);
-    }
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[row] = tmp;
-    }
-}
-
-template <int qk, int qi, typename block_q_t, int vdr>
-static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
-                                     const void *__restrict__ vy,
-                                     float *__restrict__ dst, const int ncols,
-                                     const int nrows,
-                                     const sycl::nd_item<3> &item_ct1) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-
-    if (row >= nrows) {
-        return;
-    }
-
-    const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
-// partial sum for each thread
-    float tmp = 0.0f;
-
-    const block_q_t  * x = (const block_q_t  *) vx;
-    const block_q8_1 * y = (const block_q8_1 *) vy;
-
-    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
-         i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i; // x block index
-
-        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
-
-        const int iqs =
-            vdr *
-            (item_ct1.get_local_id(2) %
-             (qi / vdr)); // x block quant index when casting the quants to int
-
-        tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid);
-    }
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[row] = tmp;
-    }
-}
-
-template <int qk, int qi, typename block_q_t, int vdr>
-static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
-                                     const void *__restrict__ vy,
-                                     float *__restrict__ dst, const int ncols,
-                                     const int nrows,
-                                     const sycl::nd_item<3> &item_ct1) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-
-    if (row >= nrows) {
-        return;
-    }
-
-    const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
-// partial sum for each thread
-    float tmp = 0.0f;
-
-    const block_q_t  * x = (const block_q_t  *) vx;
-    const block_q8_1 * y = (const block_q8_1 *) vy;
-
-    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
-         i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i; // x block index
-
-        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
-
-        const int iqs =
-            vdr *
-            (item_ct1.get_local_id(2) %
-             (qi / vdr)); // x block quant index when casting the quants to int
-
-        tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_gpu);
-    }
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[row] = tmp;
-    }
-}
-
-template <int qk, int qi, typename block_q_t, int vdr>
-static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
-                                     const void *__restrict__ vy,
-                                     float *__restrict__ dst, const int ncols,
-                                     const int nrows,
-                                     const sycl::nd_item<3> &item_ct1) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-
-    if (row >= nrows) {
-        return;
-    }
-
-    const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
-// partial sum for each thread
-    float tmp = 0.0f;
-
-    const block_q_t  * x = (const block_q_t  *) vx;
-    const block_q8_1 * y = (const block_q8_1 *) vy;
-
-    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
-         i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i; // x block index
-
-        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
-
-        const int iqs =
-            vdr *
-            (item_ct1.get_local_id(2) %
-             (qi / vdr)); // x block quant index when casting the quants to int
-
-        tmp += vec_dot_iq1_m_q8_1(&x[ibx], &y[iby], iqs);
-    }
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[row] = tmp;
-    }
-}
-
-template <int qk, int qi, typename block_q_t, int vdr>
-static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
-                                      const void *__restrict__ vy,
-                                      float *__restrict__ dst, const int ncols,
-                                      const int nrows,
-                                      const sycl::nd_item<3> &item_ct1) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-
-    if (row >= nrows) {
-        return;
-    }
-
-    const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
-// partial sum for each thread
-    float tmp = 0.0f;
-
-    const block_q_t  * x = (const block_q_t  *) vx;
-    const block_q8_1 * y = (const block_q8_1 *) vy;
-
-    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
-         i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i; // x block index
-
-        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
-
-        const int iqs =
-            vdr *
-            (item_ct1.get_local_id(2) %
-             (qi / vdr)); // x block quant index when casting the quants to int
-
-        tmp += vec_dot_iq4_nl_q8_1(&x[ibx], &y[iby], iqs);
-    }
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[row] = tmp;
-    }
-}
-
-
-template <int qk, int qi, typename block_q_t, int vdr>
-static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
-                                      const void *__restrict__ vy,
-                                      float *__restrict__ dst, const int ncols,
-                                      const int nrows,
-                                      const sycl::nd_item<3> &item_ct1) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-
-    if (row >= nrows) {
-        return;
-    }
-
-    const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
-// partial sum for each thread
-    float tmp = 0.0f;
-
-    const block_q_t  * x = (const block_q_t  *) vx;
-    const block_q8_1 * y = (const block_q8_1 *) vy;
-
-    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
-         i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i; // x block index
-
-        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
-
-        const int iqs =
-            vdr *
-            (item_ct1.get_local_id(2) %
-             (qi / vdr)); // x block quant index when casting the quants to int
-
-        tmp += vec_dot_iq4_xs_q8_1(&x[ibx], &y[iby], iqs);
-    }
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[row] = tmp;
-    }
-}
-
-
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
-                                   const sycl::nd_item<3> &item_ct1) {
-    // qk = quantized weights per x block
-    // qr = number of quantized weights per data value in x block
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-
-    if (row >= nrows) {
-        return;
-    }
-
-    const int tid = item_ct1.get_local_id(2);
-
-    const int iter_stride = 2*GGML_SYCL_DMMV_X;
-    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
-    const int y_offset = qr == 1 ? 1 : qk/2;
-
-// partial sum for each thread
-#ifdef GGML_SYCL_F16
-    sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
-#else
-    float tmp = 0.0f;
-#endif // GGML_SYCL_F16
-
-    for (int i = 0; i < ncols; i += iter_stride) {
-        const int col = i + vals_per_iter*tid;
-        const int ib = (row*ncols + col)/qk; // x block index
-        const int iqs = (col%qk)/qr; // x quant index
-        const int iybs = col - col%qk; // y block start index
-
-// processing >2 values per i iter is faster for fast GPUs
-#pragma unroll
-        for (int j = 0; j < vals_per_iter; j += 2) {
-            // process 2 vals per j iter
-
-            // dequantize
-            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
-            dfloat2 v;
-            dequantize_kernel(vx, ib, iqs + j/qr, v);
-
-            // matrix multiplication
-            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
-#ifdef GGML_SYCL_F16
-            dfloat2 t1{y[iybs + iqs + j / qr + 0],
-                        y[iybs + iqs + j / qr + y_offset]};
-
-            tmp += v * t1;
-#else
-            tmp += v.x() * y[iybs + iqs + j / qr + 0];
-            tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
-#endif // GGML_SYCL_F16
-        }
-    }
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (tid == 0) {
-#ifdef GGML_SYCL_F16
-        dst[row] = tmp.x() + tmp.y();
-#else
-        dst[row] = tmp;
-#endif // GGML_SYCL_F16
-    }
-}
-
-static void mul_mat_p021_f16_f32(
-    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
-    const sycl::nd_item<3> &item_ct1) {
-
-    const sycl::half *x = (const sycl::half *)vx;
-
-    const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                      item_ct1.get_local_id(1);
-    const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                        item_ct1.get_local_id(0);
-    const int channel_x = channel / (nchannels_y / nchannels_x);
-
-    const int nrows_y = ncols_x;
-    const int nrows_dst = nrows_x;
-    const int row_dst = row_x;
-
-    float tmp = 0.0f;
-
-    for (int col_x0 = 0; col_x0 < ncols_x;
-         col_x0 += item_ct1.get_local_range(2)) {
-        const int col_x = col_x0 + item_ct1.get_local_id(2);
-
-        if (col_x >= ncols_x) {
-            break;
-        }
-
-        // x is transposed and permuted
-        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
-        const float xi =
-            sycl::vec<sycl::half, 1>(x[ix])
-                .convert<float, sycl::rounding_mode::automatic>()[0];
-
-        const int row_y = col_x;
-
-
-        // y is not transposed but permuted
-        const int iy = channel*nrows_y + row_y;
-
-        tmp += xi * y[iy];
-    }
-
-    // dst is not transposed and not permuted
-    const int idst = channel*nrows_dst + row_dst;
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[idst] = tmp;
-    }
-}
-
-static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
-    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
-    const sycl::nd_item<3> &item_ct1) {
-
-    const sycl::half *x = (const sycl::half *)vx;
-
-    const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                      item_ct1.get_local_id(1);
-    const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                        item_ct1.get_local_id(0);
-    const int channel_x = channel / channel_x_divisor;
-
-    const int nrows_y   = ncols_x;
-    const int nrows_dst = nrows_x;
-    const int row_dst   = row_x;
-
-    const int idst = channel*nrows_dst + row_dst;
-
-    float tmp = 0.0f;
-
-    for (int col_x0 = 0; col_x0 < ncols_x;
-         col_x0 += item_ct1.get_local_range(2)) {
-        const int col_x = col_x0 + item_ct1.get_local_id(2);
-
-        if (col_x >= ncols_x) {
-            break;
-        }
-
-        const int row_y = col_x;
-
-        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const int iy = channel*nrows_y + row_y;
-
-        const float xi =
-            sycl::vec<sycl::half, 1>(x[ix])
-                .convert<float, sycl::rounding_mode::automatic>()[0];
-
-        tmp += xi * y[iy];
-    }
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[idst] = tmp;
-    }
-}
-
-static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    float * dsti = (float *) cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    sycl::half *dsti = (sycl::half *)cdsti;
-
-    *dsti = sycl::vec<float, 1>(*xi)
-                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-}
-
-static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
-    const sycl::half *xi = (const sycl::half *)cxi;
-    sycl::half *dsti = (sycl::half *)cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
-    const sycl::half *xi = (const sycl::half *)cxi;
-    float * dsti = (float *) cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
-    const int16_t *xi = (const int16_t *)cxi;
-    int16_t *dsti = (int16_t *)cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
-    const int32_t *xi = (const int32_t *)cxi;
-    int32_t *dsti = (int32_t *)cdsti;
-
-    *dsti = *xi;
-}
-
-template <cpy_kernel_t cpy_1>
-static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
-                        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                        const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                        const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= ne) {
-        return;
-    }
-
-    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
-    // then combine those indices with the corresponding byte offsets to get the total offsets
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
-
-    cpy_1(cx + x_offset, cdst + dst_offset);
-}
-
-static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q8_0 * dsti = (block_q8_0 *) cdsti;
-
-    float amax = 0.0f; // absolute max
-
-    for (int j = 0; j < QK8_0; j++) {
-        const float v = xi[j];
-        amax = sycl::fmax(amax, sycl::fabs((float)v));
-    }
-
-    const float d = amax / ((1 << 7) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->d = d;
-
-    for (int j = 0; j < QK8_0; ++j) {
-        const float x0 = xi[j]*id;
-
-        dsti->qs[j] = sycl::round((float)x0);
-    }
-}
-
-static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q4_0 * dsti = (block_q4_0 *) cdsti;
-
-    float amax = 0.0f;
-    float vmax = 0.0f;
-
-    for (int j = 0; j < QK4_0; ++j) {
-        const float v = xi[j];
-        if (amax < sycl::fabs((float)v)) {
-            amax = sycl::fabs((float)v);
-            vmax = v;
-        }
-    }
-
-    const float d  = vmax / -8;
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->d = d;
-
-    for (int j = 0; j < QK4_0/2; ++j) {
-        const float x0 = xi[0       + j]*id;
-        const float x1 = xi[QK4_0/2 + j]*id;
-
-        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f));
-        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f));
-
-        dsti->qs[j]  = xi0;
-        dsti->qs[j] |= xi1 << 4;
-    }
-}
-
-static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q4_1 * dsti = (block_q4_1 *) cdsti;
-
-    float vmin = FLT_MAX;
-    float vmax = -FLT_MAX;
-
-    for (int j = 0; j < QK4_1; ++j) {
-        const float v = xi[j];
-
-        if (v < vmin) vmin = v;
-        if (v > vmax) vmax = v;
-    }
-
-    const float d  = (vmax - vmin) / ((1 << 4) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->dm.x() = d;
-    dsti->dm.y() = vmin;
-
-    for (int j = 0; j < QK4_1/2; ++j) {
-        const float x0 = (xi[0       + j] - vmin)*id;
-        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
-
-        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f));
-        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f));
-
-        dsti->qs[j]  = xi0;
-        dsti->qs[j] |= xi1 << 4;
-    }
-}
-
-template <cpy_kernel_t cpy_blck, int qk>
-static void cpy_f32_q(const char * cx, char * cdst, const int ne,
-                      const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                      const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                      const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
-    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                   item_ct1.get_local_id(2)) *
-                  qk;
-
-    if (i >= ne) {
-        return;
-    }
-
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
-
-    cpy_blck(cx + x_offset, cdst + dst_offset);
-}
-
-static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
-    return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
-}
-
-struct rope_corr_dims {
-    float v[4];
-};
-
-// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
-// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-static void rope_yarn(
-    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta
-) {
-    // Get n-d rotational scaling corrected for extrapolation
-    float theta_interp = freq_scale * theta_extrap;
-    float theta = theta_interp;
-    if (ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
-        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
-        // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
-    }
-    *cos_theta = sycl::cos(theta) * mscale;
-    *sin_theta = sycl::sin(theta) * mscale;
-}
-
-// rope == RoPE == rotary positional embedding
-template<typename T, bool has_pos>
-static void rope(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims
-,
-    const sycl::nd_item<3> &item_ct1) {
-    const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
-
-    if (col >= ncols) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p * dpct::pow(freq_base, -float(col) / ncols);
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + 1];
-
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + 1] = x0*sin_theta + x1*cos_theta;
-}
-
-template<typename T, bool has_pos, bool has_freq_facs>
-static void rope_neox(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
-    const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
-    const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
-
-    if (col >= ncols) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int ib = col / n_dims;
-    const int ic = col % n_dims;
-
-    if (ib > 0) {
-        const int i = row*ncols + ib*n_dims + ic;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
-    const int i  = row*ncols + ib*n_dims + ic/2;
-    const int i2 = row/p_delta_rows;
-
-    float cur_rot = inv_ndims * ic - ib;
-
-    const int p = has_pos ? pos[i2] : 0;
-    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
-
-    const float theta_base =
-        p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims/2];
-
-    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
-}
-
-static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
-                           const sycl::nd_item<3> &item_ct1) {
-    const int row = item_ct1.get_group(1);
-    const int col = item_ct1.get_local_id(2);
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += item_ct1.get_local_range(2)) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum, item_ct1);
-
-    if (col == 0) {
-        dst[row] = sum;
-    }
-}
-
-
-template<typename T>
-static inline void ggml_sycl_swap(T & a, T & b) {
-    T tmp = a;
-    a = b;
-    b = tmp;
-}
-
-template <ggml_sort_order order>
-__dpct_inline__ static void
-k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
-                  const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
-    // bitonic sort
-    int col = item_ct1.get_local_id(2);
-    int row = item_ct1.get_group(1);
-
-    if (col >= ncols_pad) {
-        return;
-    }
-
-    const float * x_row = x + row * ncols;
-    auto dst_row = (int *)dpct_local;
-
-    // initialize indices
-    dst_row[col] = col;
-
-    item_ct1.barrier(sycl::access::fence_space::local_space);
-
-    for (int k = 2; k <= ncols_pad; k *= 2) {
-        for (int j = k / 2; j > 0; j /= 2) {
-            int ixj = col ^ j;
-            if (ixj > col) {
-                if ((col & k) == 0) {
-                    if (dst_row[col] >= ncols ||
-                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
-                    ) {
-                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
-                    }
-                } else {
-                    if (dst_row[ixj] >= ncols ||
-                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
-                    ) {
-                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
-                    }
-                }
-            }
-            /*
-            DPCT1118:1: SYCL group functions and algorithms must be encountered
-            in converged control flow. You may need to adjust the code.
-            */
-            item_ct1.barrier(sycl::access::fence_space::local_space);
-        }
-    }
-
-    // copy the result to dst without the padding
-    if (col < ncols) {
-        dst[row * ncols + col] = dst_row[col];
-    }
-}
-
-
-static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
-                              const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-
-    if (col >= ncols) {
-        return;
-    }
-
-    const int i = row*ncols + col;
-    //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
-    //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
-    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
-}
-
-
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
-                         const int nrows_y, const float scale, const float max_bias, const float m0,
-                         const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
-
-    const int tid = item_ct1.get_local_id(2);
-    const int rowx = item_ct1.get_group(2);
-    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
-
-    const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
-
-    const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-    const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = rowx/nrows_y; // head index
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = sycl::pow(base, float(exp));
-    }
-
-    float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
-    float max_val = -INFINITY;
-
-    for (int col0 = 0; col0 < ncols; col0 += block_size) {
-        const int col = col0 + tid;
-
-        if (ncols_template == 0 && col >= ncols) {
-            break;
-        }
-
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
-
-        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
-
-        vals[col] = val;
-        max_val = sycl::max(max_val, val);
-    }
-
-    // find the max value in the block
-    max_val = warp_reduce_max(max_val, item_ct1);
-    if (block_size > WARP_SIZE) {
-        if (warp_id == 0) {
-            buf[lane_id] = -INFINITY;
-        }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        if (lane_id == 0) {
-            buf[warp_id] = max_val;
-        }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        max_val = buf[lane_id];
-        max_val = warp_reduce_max(max_val, item_ct1);
-    }
-
-    float tmp = 0.f;
-
-#pragma unroll
-    for (int col0 = 0; col0 < ncols; col0 += block_size) {
-        const int col = col0 + tid;
-                if (ncols_template == 0 && col >= ncols) {
-            break;
-        }
-
-        const float val = sycl::native::exp(vals[col] - max_val);
-        tmp += val;
-        vals[col] = val;
-    }
-
-    // find the sum of exps in the block
-    tmp = warp_reduce_sum(tmp, item_ct1);
-    if (block_size > WARP_SIZE) {
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-        if (warp_id == 0) {
-            buf[lane_id] = 0.f;
-        }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        if (lane_id == 0) {
-            buf[warp_id] = tmp;
-        }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        tmp = buf[lane_id];
-        tmp = warp_reduce_sum(tmp, item_ct1);
-    }
-
-    const float inv_sum = 1.f / tmp;
-
-#pragma unroll
-    for (int col0 = 0; col0 < ncols; col0 += block_size) {
-        const int col = col0 + tid;
-
-        if (ncols_template == 0 && col >= ncols) {
-            return;
-        }
-
-        const int idst = rowx*ncols + col;
-        dst[idst] = vals[col] * inv_sum;
-    }
-}
-
-static void scale_f32(const float * x, float * dst, const float scale, const int k,
-                      const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-
-    dst[i] = scale * x[i];
-}
-
-static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
-                      const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-
-    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
-}
-
-template <typename T>
-static void im2col_kernel(const float *x, T *dst, int offset_delta,
-                           int IW, int IH, int OW, int KW, int KH,
-                           int pelements, int CHW, int s0, int s1, int p0,
-                           int p1, int d0, int d1,
-                           const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_id(2) +
-                  item_ct1.get_group(2) * item_ct1.get_local_range(2);
-    if (i >= pelements) {
-        return;
-    }
-
-    const int ksize = OW * (KH > 1 ? KW : 1);
-    const int kx = i / ksize;
-    const int kd = kx * ksize;
-    const int ky = (i - kd) / OW;
-    const int ix = i % OW;
-
-    const int64_t iiw = ix * s0 + kx * d0 - p0;
-    const int64_t iih = item_ct1.get_group(1) * s1 + ky * d1 - p1;
-
-    const int64_t offset_dst =
-        (item_ct1.get_group(1) * OW + ix) * CHW +
-        (item_ct1.get_group(0) * (KW * KH) + ky * KW + kx);
-
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-        dst[offset_dst] =
-            sycl::vec<float, 1>(0.0f)
-                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-    } else {
-        const int64_t offset_src = item_ct1.get_group(0) * offset_delta;
-        dst[offset_dst] =
-            sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
-                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-    }
-}
-
-template <typename Ti, typename To>
-static  void pool2d_nchw_kernel(
-        const int ih, const int iw, const int oh, const int ow,
-        const int kh, const int kw, const int sh, const int sw,
-        const int ph, const int pw, const int parallel_elements,
-        const Ti* src, To* dst, const enum ggml_op_pool op,
-        const sycl::nd_item<3> &item_ct1) {
-        int idx = item_ct1.get_local_id(2) +
-                  item_ct1.get_group(2) * item_ct1.get_local_range(2);
-        if (idx >= parallel_elements) {
-            return;
-        }
-
-        const int I_HW = ih * iw;
-        const int O_HW = oh * ow;
-        const int nc = idx / O_HW;
-        const int cur_oh = idx % O_HW / ow;
-        const int cur_ow = idx % O_HW % ow;
-        const Ti* i_ptr = src + nc * I_HW;
-        To* o_ptr = dst + nc * O_HW;
-        const int start_h = cur_oh * sh - ph;
-        const int bh = sycl::max(0, start_h);
-        const int eh = sycl::min(ih, start_h + kh);
-        const int start_w = cur_ow * sw - pw;
-        const int bw = sycl::max(0, start_w);
-        const int ew = sycl::min(iw, start_w + kw);
-
-        To res = 0;
-
-        switch (op) {
-            case GGML_OP_POOL_AVG: res = 0; break;
-            case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
-        }
-
-        for (int i = bh; i < eh; i += 1) {
-            for (int j = bw; j < ew; j += 1) {
-#if DPCT_COMPATIBILITY_TEMP >= 350
-                /*
-                DPCT1098:106: The '*' expression is used instead of the __ldg
-                call. These two expressions do not provide the exact same
-                functionality. Check the generated code for potential precision
-                and/or performance issues.
-                */
-                Ti cur = *(i_ptr + i * iw + j);
-#else
-                Ti cur = i_ptr[i * iw + j];
-#endif
-                switch (op) {
-                    case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
-                    case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
-                }
-            }
-        }
-        o_ptr[cur_oh * ow + cur_ow] = res;
-}
-
-template <int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                          ggml_tensor *dst, const void *src0_dd,
-                          const int32_t *src1_dd, float *dst_dd,
-                          queue_ptr stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
-    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
-    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
-    // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
-
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
-
-    GGML_ASSERT(ne00 % 2 == 0);
-
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             k_get_rows<qk, qr, dq>(
-                                 src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
-                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
-                         });
-
-    (void) dst;
-}
-
-template <typename src0_t>
-static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const src0_t *src0_dd, const int32_t *src1_dd,
-                                float *dst_dd, queue_ptr stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
-    const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
-    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
-    // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
-
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
-
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
-                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
-            });
-    }
-
-    (void) dst;
-}
-
-template<float (*bin_op)(const float, const float)>
-struct bin_bcast_sycl {
-    template <typename src0_t, typename src1_t, typename dst_t>
-    void operator()(ggml_backend_sycl_context & ctx,
-                    const struct ggml_tensor *src0,
-                    const struct ggml_tensor *src1, struct ggml_tensor *dst,
-                    const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
-                    queue_ptr stream) {
-
-        GGML_TENSOR_BINARY_OP_LOCALS
-
-        int nr0 = ne10/ne0;
-        int nr1 = ne11/ne1;
-        int nr2 = ne12/ne2;
-        int nr3 = ne13/ne3;
-
-        int nr[4] = { nr0, nr1, nr2, nr3 };
-
-        // collapse dimensions until first broadcast dimension
-        int64_t cne0[] = {ne0, ne1, ne2, ne3};
-        int64_t cne1[] = {ne10, ne11, ne12, ne13};
-        size_t cnb0[] = {nb0, nb1, nb2, nb3};
-        size_t cnb1[] = {nb10, nb11, nb12, nb13};
-        auto collapse = [](int64_t cne[]) {
-            cne[0] *= cne[1];
-            cne[1] = cne[2];
-            cne[2] = cne[3];
-            cne[3] = 1;
-        };
-
-        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
-            cnb[1] *= cne[1];
-            cnb[2] *= cne[2];
-            cnb[3] *= cne[3];
-        };
-
-        for (int i = 0; i < 4; i++) {
-            if (nr[i] != 1) {
-                break;
-            }
-            if (i > 0) {
-                collapse_nb(cnb0, cne0);
-                collapse_nb(cnb1, cne1);
-                collapse(cne0);
-                collapse(cne1);
-            }
-        }
-        {
-            int64_t ne0 = cne0[0];
-            int64_t ne1 = cne0[1];
-            int64_t ne2 = cne0[2];
-            int64_t ne3 = cne0[3];
-
-            int64_t ne10 = cne1[0];
-            int64_t ne11 = cne1[1];
-            int64_t ne12 = cne1[2];
-            int64_t ne13 = cne1[3];
-
-            size_t nb0 = cnb0[0];
-            size_t nb1 = cnb0[1];
-            size_t nb2 = cnb0[2];
-            size_t nb3 = cnb0[3];
-
-            size_t nb10 = cnb1[0];
-            size_t nb11 = cnb1[1];
-            size_t nb12 = cnb1[2];
-            size_t nb13 = cnb1[3];
-
-            size_t s0 = nb0 / sizeof(dst_t);
-            size_t s1 = nb1 / sizeof(dst_t);
-            size_t s2 = nb2 / sizeof(dst_t);
-            size_t s3 = nb3 / sizeof(dst_t);
-
-            size_t s10 = nb10 / sizeof(src1_t);
-            size_t s11 = nb11 / sizeof(src1_t);
-            size_t s12 = nb12 / sizeof(src1_t);
-            size_t s13 = nb13 / sizeof(src1_t);
-
-            GGML_ASSERT(s0 == 1);
-            GGML_ASSERT(s10 == 1);
-
-            const int block_size = 128;
-
-            int64_t hne0 = std::max(ne0/2LL, 1LL);
-
-            sycl::range<3> block_dims(1, 1, 1);
-            block_dims[2] = std::min<unsigned int>(hne0, block_size);
-            block_dims[1] = std::min<unsigned int>(
-                ne1, block_size / (unsigned int)block_dims[2]);
-            block_dims[0] = std::min(
-                std::min<unsigned int>(
-                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
-                                   (unsigned int)block_dims[1]),
-                64U);
-
-            sycl::range<3> block_nums(
-                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
-                (ne1 + block_dims[1] - 1) / block_dims[1],
-                (hne0 + block_dims[2] - 1) / block_dims[2]);
-
-            if (block_nums[0] > 65535) {
-                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
-                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
-                {
-                    dpct::has_capability_or_fail(stream->get_device(),
-                                                 {sycl::aspect::fp16});
-
-                    stream->parallel_for(
-                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
-                                              sycl::range<3>(1, 1, block_size),
-                                          sycl::range<3>(1, 1, block_size)),
-                        [=](sycl::nd_item<3> item_ct1) {
-                            k_bin_bcast_unravel<bin_op>(
-                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
-                                ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
-                                s13, item_ct1);
-                        });
-                }
-            } else {
-                /*
-                DPCT1049:16: The work-group size passed to the SYCL kernel may
-                exceed the limit. To get the device limit, query
-                info::device::max_work_group_size. Adjust the work-group size if
-                needed.
-                */
-                dpct::has_capability_or_fail(stream->get_device(),
-                                             {sycl::aspect::fp16});
-
-                stream->parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
-                                            ne2, ne3, ne10, ne11, ne12, ne13,
-                                            s1, s2, s3, s11, s12, s13,
-                                            item_ct1);
-                    });
-            }
-        }
-    }
-};
-
-static void acc_f32_sycl(const float *x, const float *y, float *dst,
-                         const int n_elements, const int ne10, const int ne11,
-                         const int ne12, const int nb1, const int nb2,
-                         const int offset, queue_ptr stream) {
-    int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
-                    item_ct1);
-        });
-}
-
-static void gelu_f32_sycl(const float *x, float *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            gelu_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void silu_f32_sycl(const float *x, float *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            silu_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
-                                queue_ptr stream) {
-    const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            gelu_quick_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void tanh_f32_sycl(const float *x, float *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            tanh_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void relu_f32_sycl(const float *x, float *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            relu_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
-                                 queue_ptr stream) {
-    const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            hardsigmoid_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void hardswish_f32_sycl(const float *x, float *dst, const int k,
-                               queue_ptr stream) {
-    const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            hardswish_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
-                                const float negative_slope,
-                                queue_ptr stream) {
-    const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            leaky_relu_f32(x, dst, k, negative_slope, item_ct1);
-        });
-}
-
-static void sqr_f32_sycl(const float *x, float *dst, const int k,
-                         queue_ptr stream) {
-    const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            sqr_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void norm_f32_sycl(const float *x, float *dst, const int ncols,
-                          const int nrows, const float eps,
-                          queue_ptr stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
-    if (ncols < 1024) {
-        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
-                sycl::range<1>(32), cgh);
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        norm_f32(x, dst, ncols, eps, item_ct1,
-                                            s_sum_acc_ct1.get_pointer(), WARP_SIZE);
-                    });
-        });
-    } else {
-        // FIXME: 1024 from cuda
-        const int work_group_size = GROUP_SIZE;
-        const sycl::range<3> block_dims(1, 1, work_group_size);
-        /*
-        DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
-                sycl::range<1>(32), cgh);
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        norm_f32(x, dst, ncols, eps, item_ct1,
-                                       s_sum_acc_ct1.get_pointer(), work_group_size);
-                    });
-        });
-    }
-}
-
-static void group_norm_f32_sycl(const float *x, float *dst,
-                                const int num_groups, const int group_size,
-                                const int ne_elements, queue_ptr stream) {
-    static const float eps = 1e-6f;
-    if (group_size < 1024) {
-        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
-                                                         cgh);
-
-            const float eps_ct4 = eps;
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        group_norm_f32(
-                            x, dst, group_size, ne_elements, eps_ct4, item_ct1,
-                            s_sum_acc_ct1.get_pointer(), WARP_SIZE);
-                    });
-        });
-    } else {
-        const int work_group_size = GROUP_SIZE;
-        const sycl::range<3> block_dims(1, 1, work_group_size);
-        /*
-        DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
-                                                         cgh);
-
-            const float eps_ct4 = eps;
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        group_norm_f32(x, dst, group_size, ne_elements,
-                                             eps_ct4, item_ct1,
-                                             s_sum_acc_ct1.get_pointer(), work_group_size);
-                    });
-        });
-    }
-}
-
-static void concat_f32_sycl(const float *x, const float *y, float *dst,
-                            const int ne0, int ne1, int ne2, int ne02,
-                            queue_ptr stream) {
-    int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
-    sycl::range<3> gridDim(ne2, ne1, num_blocks);
-    stream->parallel_for(
-        sycl::nd_range<3>(gridDim *
-                              sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            concat_f32(x, y, dst, ne0, ne02, item_ct1);
-        });
-}
-
-static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
-                             const int nb02, const int nb03, const int ne10, const int ne11,
-                             const int ne12, const int ne13, const float sf0, const float sf1,
-                             const float sf2, const float sf3, queue_ptr stream) {
-    int dst_size = ne10 * ne11 * ne12 * ne13;
-    int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
-    sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
-    stream->parallel_for(
-        sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
-        [=](sycl::nd_item<1> item_ct1) {
-            upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
-        });
-}
-
-static void pad_f32_sycl(const float *x, float *dst, const int ne00,
-                         const int ne01, const int ne02, const int ne0,
-                         const int ne1, const int ne2, queue_ptr stream) {
-    int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
-    sycl::range<3> gridDim(ne2, ne1, num_blocks);
-    stream->parallel_for(
-        sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1);
-        });
-}
-
-static void rms_norm_f32_sycl(const float *x, float *dst, const int ncols,
-                              const int nrows, const float eps,
-                              queue_ptr stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
-    // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
-    if (ncols < 1024) {
-        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
-                                                         cgh);
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
-                                                s_sum_acc_ct1.get_pointer(), WARP_SIZE);
-                    });
-        });
-    } else {
-        const int work_group_size = GROUP_SIZE;
-        const sycl::range<3> block_dims(1, 1, work_group_size);
-        /*
-        DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
-                                                         cgh);
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
-                                           s_sum_acc_ct1.get_pointer(), work_group_size);
-                    });
-        });
-    }
-}
-
-static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
-                                   const int ky, const int kx_padded,
-                                   queue_ptr stream) {
-    const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
-    const sycl::range<3> num_blocks(1, ky, block_num_x);
-    const sycl::range<3> block_size(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(num_blocks * block_size, block_size),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-                quantize_q8_1(x, vy, kx, kx_padded, item_ct1);
-            });
-    }
-}
-
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block_sycl(const void *__restrict__ vx,
-                                  dst_t *__restrict__ y, const int k,
-                                  queue_ptr stream) {
-    const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-        stream->parallel_for(
-            sycl::nd_range<3>(
-                sycl::range<3>(1, 1, num_blocks) *
-                    sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
-                sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
-            });
-    }
-}
-
-template <typename dst_t>
-static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
-                                     queue_ptr stream) {
-    const int nb = k / QK_K;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                                   sycl::range<3>(1, 1, 64),
-                                               sycl::range<3>(1, 1, 64)),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_q2_K(vx, y, item_ct1);
-                             });
-    }
-}
-
-template <typename dst_t>
-static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
-                                     queue_ptr stream) {
-    const int nb = k / QK_K;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                                   sycl::range<3>(1, 1, 64),
-                                               sycl::range<3>(1, 1, 64)),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_q3_K(vx, y, item_ct1);
-                             });
-    }
-}
-
-template <typename dst_t>
-static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
-                                     queue_ptr stream) {
-    const int nb32 = k / 32;
-    const int nb = (k + 255) / 256;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                                   sycl::range<3>(1, 1, 32),
-                                               sycl::range<3>(1, 1, 32)),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_q4_0(vx, y, nb32, item_ct1);
-                             });
-    }
-}
-
-template <typename dst_t>
-static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
-                                     queue_ptr stream) {
-    const int nb32 = k / 32;
-    const int nb = (k + 255) / 256;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                                   sycl::range<3>(1, 1, 32),
-                                               sycl::range<3>(1, 1, 32)),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_q4_1(vx, y, nb32, item_ct1);
-                             });
-    }
-}
-
-
-template <typename dst_t>
-static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
-                                     queue_ptr stream) {
-    const int nb = k / QK_K;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                                   sycl::range<3>(1, 1, 32),
-                                               sycl::range<3>(1, 1, 32)),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_q4_K(vx, y, item_ct1);
-                             });
-    }
-}
-
-template <typename dst_t>
-static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
-                                     queue_ptr stream) {
-    const int nb = k / QK_K;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                                   sycl::range<3>(1, 1, 64),
-                                               sycl::range<3>(1, 1, 64)),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_q5_K(vx, y, item_ct1);
-                             });
-    }
-}
-
-template <typename dst_t>
-static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
-                                     queue_ptr stream) {
-    const int nb = k / QK_K;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                                   sycl::range<3>(1, 1, 64),
-                                               sycl::range<3>(1, 1, 64)),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_q6_K(vx, y, item_ct1);
-                             });
-    }
-}
-
-template <typename dst_t>
-static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
-                                        queue_ptr stream) {
-    const int nb = k / QK_K;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->submit([&](sycl::handler &cgh) {
-            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                                   sycl::range<3>(1, 1, 32),
-                                               sycl::range<3>(1, 1, 32)),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_iq1_s(
-                                     vx, y, item_ct1, iq1s_grid_gpu
-                                     );
-                             });
-        });
-    }
-}
-
-template <typename dst_t>
-static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
-                                        queue_ptr stream) {
-    const int nb = k / QK_K;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->submit([&](sycl::handler &cgh) {
-            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                                   sycl::range<3>(1, 1, 32),
-                                               sycl::range<3>(1, 1, 32)),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_iq1_m(
-                                     vx, y, item_ct1, iq1s_grid_gpu
-                                     );
-                             });
-        });
-    }
-}
-
-template <typename dst_t>
-static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
-                                        queue_ptr stream) {
-    const int nb = k / QK_K;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->submit([&](sycl::handler &cgh) {
-            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                                   sycl::range<3>(1, 1, 32),
-                                               sycl::range<3>(1, 1, 32)),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_iq2_xxs(
-                                     vx, y, item_ct1, iq2xxs_grid,
-                                     ksigns_iq2xs, kmask_iq2xs);
-                             });
-        });
-    }
-}
-
-template <typename dst_t>
-static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
-                                       queue_ptr stream) {
-    const int nb = k / QK_K;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->submit([&](sycl::handler &cgh) {
-            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                                   sycl::range<3>(1, 1, 32),
-                                               sycl::range<3>(1, 1, 32)),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_iq2_xs(
-                                     vx, y, item_ct1, iq2xs_grid,
-                                     ksigns_iq2xs, kmask_iq2xs);
-                             });
-        });
-    }
-}
-
-template <typename dst_t>
-static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
-                                      queue_ptr stream) {
-    const int nb = k / QK_K;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->submit([&](sycl::handler &cgh) {
-            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                                   sycl::range<3>(1, 1, 32),
-                                               sycl::range<3>(1, 1, 32)),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_iq2_s(vx, y, item_ct1);
-                             });
-        });
-    }
-}
-
-
-template <typename dst_t>
-static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
-                                        queue_ptr stream) {
-    const int nb = k / QK_K;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->submit([&](sycl::handler &cgh) {
-            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                                   sycl::range<3>(1, 1, 32),
-                                               sycl::range<3>(1, 1, 32)),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_iq3_xxs(
-                                     vx, y, item_ct1, iq3xxs_grid,
-                                     ksigns_iq2xs, kmask_iq2xs);
-                             });
-        });
-    }
-}
-
-template <typename dst_t>
-static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
-                                        queue_ptr stream) {
-    const int nb = k / QK_K;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->submit([&](sycl::handler &cgh) {
-            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                                   sycl::range<3>(1, 1, 32),
-                                               sycl::range<3>(1, 1, 32)),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 dequantize_block_iq3_s(
-                                     vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
-                             });
-        });
-    }
-}
-
-template <typename dst_t>
-static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
-                                       queue_ptr stream) {
-    const int nb = (k + QK_K - 1) / QK_K;
-      {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
-
-            stream->submit([&](sycl::handler &cgh) {
-                  cgh.parallel_for(
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                            sycl::range<3>(1, 1, 32),
-                                        sycl::range<3>(1, 1, 32)),
-                      [=](sycl::nd_item<3> item_ct1) {
-                            dequantize_block_iq4_xs(vx, y, item_ct1);
-                      });
-            });
-      }
-}
-
-
-template <typename dst_t>
-static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
-                                       queue_ptr stream) {
-    const int nb = (k + QK_K - 1) / QK_K;
-      {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
-
-            stream->submit([&](sycl::handler &cgh) {
-                  cgh.parallel_for(
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
-                                            sycl::range<3>(1, 1, 32),
-                                        sycl::range<3>(1, 1, 32)),
-                      [=](sycl::nd_item<3> item_ct1) {
-                            dequantize_block_iq4_nl(vx, y, item_ct1);
-                      });
-            });
-      }
-}
-
-
-
-template <typename src_t, typename dst_t>
-static void convert_unary_sycl(const void *__restrict__ vx,
-                               dst_t *__restrict__ y, const int k,
-                               queue_ptr stream) {
-    const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(
-                sycl::range<3>(1, 1, num_blocks) *
-                    sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
-                sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                convert_unary<src_t>(vx, y, k, item_ct1);
-            });
-    }
-}
-
-
-static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) try {
-    int id;
-    switch (type) {
-        case GGML_TYPE_Q4_0:
-            return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
-        case GGML_TYPE_Q4_1:
-            return dequantize_block_sycl<QK4_1, QR4_1, dequantize_q4_1>;
-        case GGML_TYPE_Q5_0:
-            return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
-        case GGML_TYPE_Q5_1:
-            return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
-        case GGML_TYPE_Q8_0:
-            return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
-        case GGML_TYPE_Q2_K:
-            return dequantize_row_q2_K_sycl;
-        case GGML_TYPE_Q3_K:
-            return dequantize_row_q3_K_sycl;
-        case GGML_TYPE_Q4_K:
-            return dequantize_row_q4_K_sycl;
-        case GGML_TYPE_Q5_K:
-            return dequantize_row_q5_K_sycl;
-        case GGML_TYPE_Q6_K:
-            return dequantize_row_q6_K_sycl;
-        case GGML_TYPE_IQ1_S:
-            return dequantize_row_iq1_s_sycl;
-        case GGML_TYPE_IQ1_M:
-            return dequantize_row_iq1_m_sycl;
-        case GGML_TYPE_IQ2_XXS:
-            return dequantize_row_iq2_xxs_sycl;
-        case GGML_TYPE_IQ2_XS:
-            return dequantize_row_iq2_xs_sycl;
-        case GGML_TYPE_IQ2_S:
-            return dequantize_row_iq2_s_sycl;
-        case GGML_TYPE_IQ3_XXS:
-            return dequantize_row_iq3_xxs_sycl;
-        case GGML_TYPE_IQ3_S:
-            return dequantize_row_iq3_s_sycl;
-        case GGML_TYPE_IQ4_XS:
-            return dequantize_row_iq4_xs_sycl;
-        case GGML_TYPE_IQ4_NL:
-            return dequantize_row_iq4_nl_sycl;
-        case GGML_TYPE_F32:
-            return convert_unary_sycl<float>;
-        default:
-            return nullptr;
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:
-            return dequantize_row_q4_0_sycl;
-        case GGML_TYPE_Q4_1:
-            return dequantize_row_q4_1_sycl;
-        case GGML_TYPE_Q5_0:
-            return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
-        case GGML_TYPE_Q5_1:
-            return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
-        case GGML_TYPE_Q8_0:
-            return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
-        case GGML_TYPE_Q2_K:
-            return dequantize_row_q2_K_sycl;
-        case GGML_TYPE_Q3_K:
-            return dequantize_row_q3_K_sycl;
-        case GGML_TYPE_Q4_K:
-            return dequantize_row_q4_K_sycl;
-        case GGML_TYPE_Q5_K:
-            return dequantize_row_q5_K_sycl;
-        case GGML_TYPE_Q6_K:
-            return dequantize_row_q6_K_sycl;
-        case GGML_TYPE_IQ1_S:
-            return dequantize_row_iq1_s_sycl;
-        case GGML_TYPE_IQ1_M:
-            return dequantize_row_iq1_m_sycl;
-        case GGML_TYPE_IQ2_XXS:
-            return dequantize_row_iq2_xxs_sycl;
-        case GGML_TYPE_IQ2_XS:
-            return dequantize_row_iq2_xs_sycl;
-        case GGML_TYPE_IQ2_S:
-            return dequantize_row_iq2_s_sycl;
-        case GGML_TYPE_IQ3_XXS:
-            return dequantize_row_iq3_xxs_sycl;
-        case GGML_TYPE_IQ3_S:
-            return dequantize_row_iq3_s_sycl;
-        case GGML_TYPE_IQ4_XS:
-            return dequantize_row_iq4_xs_sycl;
-        case GGML_TYPE_IQ4_NL:
-            return dequantize_row_iq4_nl_sycl;
-        case GGML_TYPE_F16:
-            return convert_unary_sycl<sycl::half>;
-        default:
-            return nullptr;
-    }
-}
-
-static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
-                                             float *dst, const int ncols,
-                                             const int nrows,
-                                             queue_ptr stream) {
-    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-                dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
-                    vx, y, dst, ncols, nrows, item_ct1);
-            });
-    }
-}
-
-static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
-                                             float *dst, const int ncols,
-                                             const int nrows,
-                                             queue_ptr stream) {
-    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-                dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
-                    vx, y, dst, ncols, nrows, item_ct1);
-            });
-    }
-}
-
-static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
-                                             float *dst, const int ncols,
-                                             const int nrows,
-                                             queue_ptr stream) {
-    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-                dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
-                    vx, y, dst, ncols, nrows, item_ct1);
-            });
-    }
-}
-
-static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
-                                             float *dst, const int ncols,
-                                             const int nrows,
-                                             queue_ptr stream) {
-    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-                dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
-                    vx, y, dst, ncols, nrows, item_ct1);
-            });
-    }
-}
-
-static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
-                                             float *dst, const int ncols,
-                                             const int nrows,
-                                             queue_ptr stream) {
-    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-                dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
-                    vx, y, dst, ncols, nrows, item_ct1);
-            });
-    }
-}
-
-static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
-                                             float *dst, const int ncols,
-                                             const int nrows,
-                                             queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
-    const int block_num_y = (nrows + ny - 1) / ny;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, 32);
-    stream->parallel_for(
-        sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-            dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
-        });
-}
-
-static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
-                                             float *dst, const int ncols,
-                                             const int nrows,
-                                             queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
-    const int block_num_y = (nrows + ny - 1) / ny;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, 32);
-    stream->parallel_for(
-        sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-            dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
-        });
-}
-
-static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
-                                             float *dst, const int ncols,
-                                             const int nrows,
-                                             queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
-    const int block_num_y = (nrows + ny - 1) / ny;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, 32);
-    stream->parallel_for(
-        sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-            dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
-        });
-}
-
-static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
-                                             float *dst, const int ncols,
-                                             const int nrows,
-                                             queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const sycl::range<3> block_dims(1, 1, 32);
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-            dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
-        });
-}
-
-static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
-                                             float *dst, const int ncols,
-                                             const int nrows,
-                                             queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
-    const int block_num_y = (nrows + ny - 1) / ny;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, 32);
-    stream->parallel_for(
-        sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-            dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
-        });
-}
-
-static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
-                                         float *dst, const int ncols,
-                                         const int nrows,
-                                         queue_ptr stream) {
-    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-                dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
-                                                          nrows, item_ct1);
-            });
-    }
-}
-
-
-static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
-                                       float *dst, const int ncols,
-                                       const int nrows,
-                                       queue_ptr stream) {
-    GGML_ASSERT(ncols % QK4_0 == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
-                                      VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
-                                       float *dst, const int ncols,
-                                       const int nrows,
-                                       queue_ptr stream) {
-    GGML_ASSERT(ncols % QK4_1 == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
-                                      VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
-                                       float *dst, const int ncols,
-                                       const int nrows,
-                                       queue_ptr stream) {
-    GGML_ASSERT(ncols % QK5_0 == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
-                                      VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
-                                       float *dst, const int ncols,
-                                       const int nrows,
-                                       queue_ptr stream) {
-    GGML_ASSERT(ncols % QK5_1 == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
-                                      VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
-                                       float *dst, const int ncols,
-                                       const int nrows,
-                                       queue_ptr stream) {
-    GGML_ASSERT(ncols % QK8_0 == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
-                                      VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
-                                       float *dst, const int ncols,
-                                       const int nrows,
-                                       queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
-                                      VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
-                                       float *dst, const int ncols,
-                                       const int nrows,
-                                       queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
-                                      VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
-                                       float *dst, const int ncols,
-                                       const int nrows,
-                                       queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
-                                      VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
-                                       float *dst, const int ncols,
-                                       const int nrows,
-                                       queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
-                                      VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
-                                       float *dst, const int ncols,
-                                       const int nrows,
-                                       queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
-                                      VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-
-static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
-                                          float *dst, const int ncols,
-                                          const int nrows,
-                                          queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-        stream->submit([&](sycl::handler &cgh) {
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS, block_iq2_xxs, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
-                                         float *dst, const int ncols,
-                                         const int nrows,
-                                         queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-            auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
-            auto ksigns64_ptr_ct1 = &ksigns64[0];
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS, block_iq2_xs, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
-                                         float *dst, const int ncols,
-                                         const int nrows,
-                                         queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-            auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
-            auto ksigns64_ptr_ct1 = &ksigns64[0];
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S, block_iq2_s, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
-                                          float *dst, const int ncols,
-                                          const int nrows,
-                                          queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-            auto iq3xxs_grid_ptr_ct1 = &iq3xxs_grid[0];
-            auto ksigns64_ptr_ct1 = &ksigns64[0];
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS, block_iq3_xxs, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
-                                          float *dst, const int ncols,
-                                          const int nrows,
-                                          queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-            auto iq3s_grid_ptr_ct1 = &iq3s_grid[0];
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_XS, block_iq3_s, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
-                                          float *dst, const int ncols,
-                                          const int nrows,
-                                          queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-            auto iq1s_grid_ptr_ct1 = &iq1s_grid_gpu[0];
-            auto ksigns64_ptr_ct1 = &ksigns64[0];
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
-                                          float *dst, const int ncols,
-                                          const int nrows,
-                                          queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-        stream->submit([&](sycl::handler &cgh) {
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
-                                          float *dst, const int ncols,
-                                          const int nrows,
-                                          queue_ptr stream) {
-    GGML_ASSERT(ncols % QK4_NL == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
-                                          float *dst, const int ncols,
-                                          const int nrows,
-                                          queue_ptr stream) {
-    GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
-    const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
-    {
-
-        stream->submit([&](sycl::handler &cgh) {
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS, block_iq4_xs, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1);
-                    });
-        });
-    }
-}
-
-static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
-                                        float *dst, const int ncols_x,
-                                        const int nrows_x, const int ncols_y,
-                                        const int nrows_y, const int nrows_dst,
-                                        queue_ptr stream) try {
-
-    int id;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(id = get_current_device_id()));
-    const int compute_capability = ggml_sycl_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= VER_GEN13) {
-        mmq_x  =  MMQ_X_Q4_0_RDNA2;
-        mmq_y  =  MMQ_Y_Q4_0_RDNA2;
-        nwarps = NWARPS_Q4_0_RDNA2;
-    } else if (compute_capability >= VER_GEN12) {
-        mmq_x  =  MMQ_X_Q4_0_RDNA1;
-        mmq_y  =  MMQ_Y_Q4_0_RDNA1;
-        nwarps = NWARPS_Q4_0_RDNA1;
-    } else if (compute_capability >= VER_GEN9) {
-        mmq_x  =  MMQ_X_Q4_0_AMPERE;
-        mmq_y  =  MMQ_Y_Q4_0_AMPERE;
-        nwarps = NWARPS_Q4_0_AMPERE;
-    } else if (compute_capability >= VER_4VEC) {
-        mmq_x  =  MMQ_X_Q4_0_PASCAL;
-        mmq_y  =  MMQ_Y_Q4_0_PASCAL;
-        nwarps = NWARPS_Q4_0_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
-    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
-
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        /*
-        DPCT1049:20: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
-
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q4_0<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_qs_q4_0_acc_ct1.get_pointer(),
-                            tile_x_d_q4_0_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
-                    });
-            });
-        }
-    } else {
-        const bool need_check = true;
-        /*
-        DPCT1049:21: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
-
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q4_0<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_qs_q4_0_acc_ct1.get_pointer(),
-                            tile_x_d_q4_0_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
-                    });
-            });
-        }
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
-                                        float *dst, const int ncols_x,
-                                        const int nrows_x, const int ncols_y,
-                                        const int nrows_y, const int nrows_dst,
-                                        queue_ptr stream) try {
-
-    int id;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(id = get_current_device_id()));
-    const int compute_capability = ggml_sycl_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= VER_GEN13) {
-        mmq_x  =  MMQ_X_Q4_1_RDNA2;
-        mmq_y  =  MMQ_Y_Q4_1_RDNA2;
-        nwarps = NWARPS_Q4_1_RDNA2;
-    } else if (compute_capability >= VER_GEN12) {
-        mmq_x  =  MMQ_X_Q4_1_RDNA1;
-        mmq_y  =  MMQ_Y_Q4_1_RDNA1;
-        nwarps = NWARPS_Q4_1_RDNA1;
-    } else if (compute_capability >= VER_GEN9) {
-        mmq_x  =  MMQ_X_Q4_1_AMPERE;
-        mmq_y  =  MMQ_Y_Q4_1_AMPERE;
-        nwarps = NWARPS_Q4_1_AMPERE;
-    } else if (compute_capability >= VER_4VEC) {
-        mmq_x  =  MMQ_X_Q4_1_PASCAL;
-        mmq_y  =  MMQ_Y_Q4_1_PASCAL;
-        nwarps = NWARPS_Q4_1_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
-    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
-
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        /*
-        DPCT1049:22: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
-
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q4_1<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_qs_q4_1_acc_ct1.get_pointer(),
-                            tile_x_dm_q4_1_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
-                    });
-            });
-        }
-    } else {
-        const bool need_check = true;
-        /*
-        DPCT1049:23: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
-
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q4_1<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_qs_q4_1_acc_ct1.get_pointer(),
-                            tile_x_dm_q4_1_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
-                    });
-            });
-        }
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
-                                        float *dst, const int ncols_x,
-                                        const int nrows_x, const int ncols_y,
-                                        const int nrows_y, const int nrows_dst,
-                                        queue_ptr stream) try {
-
-    int id;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(id = get_current_device_id()));
-    const int compute_capability = ggml_sycl_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= VER_GEN13) {
-        mmq_x  =  MMQ_X_Q5_0_RDNA2;
-        mmq_y  =  MMQ_Y_Q5_0_RDNA2;
-        nwarps = NWARPS_Q5_0_RDNA2;
-    } else if (compute_capability >= VER_GEN12) {
-        mmq_x  =  MMQ_X_Q5_0_RDNA1;
-        mmq_y  =  MMQ_Y_Q5_0_RDNA1;
-        nwarps = NWARPS_Q5_0_RDNA1;
-    } else if (compute_capability >= VER_GEN9) {
-        mmq_x  =  MMQ_X_Q5_0_AMPERE;
-        mmq_y  =  MMQ_Y_Q5_0_AMPERE;
-        nwarps = NWARPS_Q5_0_AMPERE;
-    } else if (compute_capability >= VER_4VEC) {
-        mmq_x  =  MMQ_X_Q5_0_PASCAL;
-        mmq_y  =  MMQ_Y_Q5_0_PASCAL;
-        nwarps = NWARPS_Q5_0_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
-    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
-
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        /*
-        DPCT1049:24: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
-
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
-                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q5_0<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_ql_q5_0_acc_ct1.get_pointer(),
-                            tile_x_d_q5_0_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
-                    });
-            });
+        for (int i = 0; i < 4; i++) {
+            if (nr[i] != 1) {
+                break;
+            }
+            if (i > 0) {
+                collapse_nb(cnb0, cne0);
+                collapse_nb(cnb1, cne1);
+                collapse(cne0);
+                collapse(cne1);
+            }
         }
-    } else {
-        const bool need_check = true;
-        /*
-        DPCT1049:25: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
         {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
+            int64_t ne0 = cne0[0];
+            int64_t ne1 = cne0[1];
+            int64_t ne2 = cne0[2];
+            int64_t ne3 = cne0[3];
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
-                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q5_0<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_ql_q5_0_acc_ct1.get_pointer(),
-                            tile_x_d_q5_0_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
-                    });
-            });
-        }
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
+            int64_t ne10 = cne1[0];
+            int64_t ne11 = cne1[1];
+            int64_t ne12 = cne1[2];
+            int64_t ne13 = cne1[3];
 
-static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
-                                        float *dst, const int ncols_x,
-                                        const int nrows_x, const int ncols_y,
-                                        const int nrows_y, const int nrows_dst,
-                                        queue_ptr stream) try {
+            size_t nb0 = cnb0[0];
+            size_t nb1 = cnb0[1];
+            size_t nb2 = cnb0[2];
+            size_t nb3 = cnb0[3];
 
-    int id;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(id = get_current_device_id()));
-    const int compute_capability = ggml_sycl_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= VER_GEN13) {
-        mmq_x  =  MMQ_X_Q5_1_RDNA2;
-        mmq_y  =  MMQ_Y_Q5_1_RDNA2;
-        nwarps = NWARPS_Q5_1_RDNA2;
-    } else if (compute_capability >= VER_GEN12) {
-        mmq_x  =  MMQ_X_Q5_1_RDNA1;
-        mmq_y  =  MMQ_Y_Q5_1_RDNA1;
-        nwarps = NWARPS_Q5_1_RDNA1;
-    } else if (compute_capability >= VER_GEN9) {
-        mmq_x  =  MMQ_X_Q5_1_AMPERE;
-        mmq_y  =  MMQ_Y_Q5_1_AMPERE;
-        nwarps = NWARPS_Q5_1_AMPERE;
-    } else if (compute_capability >= VER_4VEC) {
-        mmq_x  =  MMQ_X_Q5_1_PASCAL;
-        mmq_y  =  MMQ_Y_Q5_1_PASCAL;
-        nwarps = NWARPS_Q5_1_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
+            size_t nb10 = cnb1[0];
+            size_t nb11 = cnb1[1];
+            size_t nb12 = cnb1[2];
+            size_t nb13 = cnb1[3];
 
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
-    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+            size_t s0 = nb0 / sizeof(dst_t);
+            size_t s1 = nb1 / sizeof(dst_t);
+            size_t s2 = nb2 / sizeof(dst_t);
+            size_t s3 = nb3 / sizeof(dst_t);
 
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        /*
-        DPCT1049:26: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
+            size_t s10 = nb10 / sizeof(src1_t);
+            size_t s11 = nb11 / sizeof(src1_t);
+            size_t s12 = nb12 / sizeof(src1_t);
+            size_t s13 = nb13 / sizeof(src1_t);
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
-                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q5_1<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_ql_q5_1_acc_ct1.get_pointer(),
-                            tile_x_dm_q5_1_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
-                    });
-            });
-        }
-    } else {
-        const bool need_check = true;
-        /*
-        DPCT1049:27: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
+            GGML_ASSERT(s0 == 1);
+            GGML_ASSERT(s10 == 1);
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
-                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q5_1<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_ql_q5_1_acc_ct1.get_pointer(),
-                            tile_x_dm_q5_1_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
-                    });
-            });
-        }
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
+            const int block_size = 128;
 
-static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
-                                        float *dst, const int ncols_x,
-                                        const int nrows_x, const int ncols_y,
-                                        const int nrows_y, const int nrows_dst,
-                                        queue_ptr stream) try {
+            int64_t hne0 = std::max(ne0/2LL, 1LL);
 
-    int id;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(id = get_current_device_id()));
-    const int compute_capability = ggml_sycl_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= VER_GEN13) {
-        mmq_x  =  MMQ_X_Q8_0_RDNA2;
-        mmq_y  =  MMQ_Y_Q8_0_RDNA2;
-        nwarps = NWARPS_Q8_0_RDNA2;
-    } else if (compute_capability >= VER_GEN12) {
-        mmq_x  =  MMQ_X_Q8_0_RDNA1;
-        mmq_y  =  MMQ_Y_Q8_0_RDNA1;
-        nwarps = NWARPS_Q8_0_RDNA1;
-    } else if (compute_capability >= VER_GEN9) {
-        mmq_x  =  MMQ_X_Q8_0_AMPERE;
-        mmq_y  =  MMQ_Y_Q8_0_AMPERE;
-        nwarps = NWARPS_Q8_0_AMPERE;
-    } else if (compute_capability >= VER_4VEC) {
-        mmq_x  =  MMQ_X_Q8_0_PASCAL;
-        mmq_y  =  MMQ_Y_Q8_0_PASCAL;
-        nwarps = NWARPS_Q8_0_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
+            sycl::range<3> block_dims(1, 1, 1);
+            block_dims[2] = std::min<unsigned int>(hne0, block_size);
+            block_dims[1] = std::min<unsigned int>(
+                ne1, block_size / (unsigned int)block_dims[2]);
+            block_dims[0] = std::min(
+                std::min<unsigned int>(
+                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
+                                   (unsigned int)block_dims[1]),
+                64U);
 
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
-    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+            sycl::range<3> block_nums(
+                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
+                (ne1 + block_dims[1] - 1) / block_dims[1],
+                (hne0 + block_dims[2] - 1) / block_dims[2]);
 
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        /*
-        DPCT1049:28: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
+            if (block_nums[0] > 65535) {
+                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                {
+                    dpct::has_capability_or_fail(stream->get_device(),
+                                                 {sycl::aspect::fp16});
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q8_0<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_qs_q8_0_acc_ct1.get_pointer(),
-                            tile_x_d_q8_0_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
-                    });
-            });
-        }
-    } else {
-        const bool need_check = true;
-        /*
-        DPCT1049:29: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
+                    stream->parallel_for(
+                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
+                                              sycl::range<3>(1, 1, block_size),
+                                          sycl::range<3>(1, 1, block_size)),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_bin_bcast_unravel<bin_op>(
+                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
+                                ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
+                                s13, item_ct1);
+                        });
+                }
+            } else {
+                /*
+                DPCT1049:16: The work-group size passed to the SYCL kernel may
+                exceed the limit. To get the device limit, query
+                info::device::max_work_group_size. Adjust the work-group size if
+                needed.
+                */
+                dpct::has_capability_or_fail(stream->get_device(),
+                                             {sycl::aspect::fp16});
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
+                stream->parallel_for(
                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
                     [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q8_0<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_qs_q8_0_acc_ct1.get_pointer(),
-                            tile_x_d_q8_0_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
+                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
+                                            ne2, ne3, ne10, ne11, ne12, ne13,
+                                            s1, s2, s3, s11, s12, s13,
+                                            item_ct1);
                     });
-            });
+            }
         }
     }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
+};
 
-static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
-                                        float *dst, const int ncols_x,
-                                        const int nrows_x, const int ncols_y,
-                                        const int nrows_y, const int nrows_dst,
-                                        queue_ptr stream) try {
+static void acc_f32_sycl(const float *x, const float *y, float *dst,
+                         const int n_elements, const int ne10, const int ne11,
+                         const int ne12, const int nb1, const int nb2,
+                         const int offset, queue_ptr stream) {
+    int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
+                    item_ct1);
+        });
+}
 
-    int id;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(id = get_current_device_id()));
-    const int compute_capability = ggml_sycl_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= VER_GEN13) {
-        mmq_x  =  MMQ_X_Q2_K_RDNA2;
-        mmq_y  =  MMQ_Y_Q2_K_RDNA2;
-        nwarps = NWARPS_Q2_K_RDNA2;
-    } else if (compute_capability >= VER_GEN12) {
-        mmq_x  =  MMQ_X_Q2_K_RDNA1;
-        mmq_y  =  MMQ_Y_Q2_K_RDNA1;
-        nwarps = NWARPS_Q2_K_RDNA1;
-    } else if (compute_capability >= VER_GEN9) {
-        mmq_x  =  MMQ_X_Q2_K_AMPERE;
-        mmq_y  =  MMQ_Y_Q2_K_AMPERE;
-        nwarps = NWARPS_Q2_K_AMPERE;
-    } else if (compute_capability >= VER_4VEC) {
-        mmq_x  =  MMQ_X_Q2_K_PASCAL;
-        mmq_y  =  MMQ_Y_Q2_K_PASCAL;
-        nwarps = NWARPS_Q2_K_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
+static void gelu_f32_sycl(const float *x, float *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            gelu_f32(x, dst, k, item_ct1);
+        });
+}
 
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
-    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+static void silu_f32_sycl(const float *x, float *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            silu_f32(x, dst, k, item_ct1);
+        });
+}
 
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        /*
-        DPCT1049:30: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
+static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
+                                queue_ptr stream) {
+    const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            gelu_quick_f32(x, dst, k, item_ct1);
+        });
+}
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q2_K<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_ql_q2_K_acc_ct1.get_pointer(),
-                            tile_x_dm_q2_K_acc_ct1.get_pointer(),
-                            tile_x_sc_q2_K_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
-                    });
-            });
-        }
-    } else {
-        const bool need_check = true;
-        /*
-        DPCT1049:31: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
+static void tanh_f32_sycl(const float *x, float *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            tanh_f32(x, dst, k, item_ct1);
+        });
+}
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q2_K<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_ql_q2_K_acc_ct1.get_pointer(),
-                            tile_x_dm_q2_K_acc_ct1.get_pointer(),
-                            tile_x_sc_q2_K_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
-                    });
-            });
-        }
-    }
+static void relu_f32_sycl(const float *x, float *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            relu_f32(x, dst, k, item_ct1);
+        });
 }
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
+
+static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
+                                 queue_ptr stream) {
+    const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            hardsigmoid_f32(x, dst, k, item_ct1);
+        });
 }
 
-static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
-                                        float *dst, const int ncols_x,
-                                        const int nrows_x, const int ncols_y,
-                                        const int nrows_y, const int nrows_dst,
-                                        queue_ptr stream) try {
+static void hardswish_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            hardswish_f32(x, dst, k, item_ct1);
+        });
+}
 
-    int id;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(id = get_current_device_id()));
-    const int compute_capability = ggml_sycl_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= VER_GEN13) {
-        mmq_x  =  MMQ_X_Q3_K_RDNA2;
-        mmq_y  =  MMQ_Y_Q3_K_RDNA2;
-        nwarps = NWARPS_Q3_K_RDNA2;
-    } else if (compute_capability >= VER_GEN12) {
-        mmq_x  =  MMQ_X_Q3_K_RDNA1;
-        mmq_y  =  MMQ_Y_Q3_K_RDNA1;
-        nwarps = NWARPS_Q3_K_RDNA1;
-    } else if (compute_capability >= VER_GEN9) {
-        mmq_x  =  MMQ_X_Q3_K_AMPERE;
-        mmq_y  =  MMQ_Y_Q3_K_AMPERE;
-        nwarps = NWARPS_Q3_K_AMPERE;
-    } else if (compute_capability >= VER_4VEC) {
-        mmq_x  =  MMQ_X_Q3_K_PASCAL;
-        mmq_y  =  MMQ_Y_Q3_K_PASCAL;
-        nwarps = NWARPS_Q3_K_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
+static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
+                                const float negative_slope,
+                                queue_ptr stream) {
+    const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            leaky_relu_f32(x, dst, k, negative_slope, item_ct1);
+        });
+}
 
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
-    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+static void sqr_f32_sycl(const float *x, float *dst, const int k,
+                         queue_ptr stream) {
+    const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            sqr_f32(x, dst, k, item_ct1);
+        });
+}
 
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        /*
-        DPCT1049:32: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
+static void norm_f32_sycl(const float *x, float *dst, const int ncols,
+                          const int nrows, const float eps,
+                          queue_ptr stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    if (ncols < 1024) {
+        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+        stream->submit([&](sycl::handler &cgh) {
+            sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
+                sycl::range<1>(32), cgh);
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
-                sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q3_K<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_ql_q3_K_acc_ct1.get_pointer(),
-                            tile_x_dm_q3_K_acc_ct1.get_pointer(),
-                            tile_x_qh_q3_K_acc_ct1.get_pointer(),
-                            tile_x_sc_q3_K_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        norm_f32(x, dst, ncols, eps, item_ct1,
+                                            s_sum_acc_ct1.get_pointer(), WARP_SIZE);
                     });
-            });
-        }
+        });
     } else {
-        const bool need_check = true;
+        const int work_group_size = get_work_group_size(stream->get_device());
+        const sycl::range<3> block_dims(1, 1, work_group_size);
         /*
-        DPCT1049:33: The work-group size passed to the SYCL kernel may exceed
+        DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
+        stream->submit([&](sycl::handler &cgh) {
+            sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
+                sycl::range<1>(32), cgh);
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
-                sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q3_K<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_ql_q3_K_acc_ct1.get_pointer(),
-                            tile_x_dm_q3_K_acc_ct1.get_pointer(),
-                            tile_x_qh_q3_K_acc_ct1.get_pointer(),
-                            tile_x_sc_q3_K_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        norm_f32(x, dst, ncols, eps, item_ct1,
+                                       s_sum_acc_ct1.get_pointer(), work_group_size);
                     });
-            });
-        }
+        });
     }
 }
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
-                                        float *dst, const int ncols_x,
-                                        const int nrows_x, const int ncols_y,
-                                        const int nrows_y, const int nrows_dst,
-                                        queue_ptr stream) try {
-
-    int id;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(id = get_current_device_id()));
-    const int compute_capability = ggml_sycl_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= VER_GEN13) {
-        mmq_x  =  MMQ_X_Q4_K_RDNA2;
-        mmq_y  =  MMQ_Y_Q4_K_RDNA2;
-        nwarps = NWARPS_Q4_K_RDNA2;
-    } else if (compute_capability >= VER_GEN12) {
-        mmq_x  =  MMQ_X_Q4_K_RDNA1;
-        mmq_y  =  MMQ_Y_Q4_K_RDNA1;
-        nwarps = NWARPS_Q4_K_RDNA1;
-    } else if (compute_capability >= VER_GEN9) {
-        mmq_x  =  MMQ_X_Q4_K_AMPERE;
-        mmq_y  =  MMQ_Y_Q4_K_AMPERE;
-        nwarps = NWARPS_Q4_K_AMPERE;
-    } else if (compute_capability >= VER_4VEC) {
-        mmq_x  =  MMQ_X_Q4_K_PASCAL;
-        mmq_y  =  MMQ_Y_Q4_K_PASCAL;
-        nwarps = NWARPS_Q4_K_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
 
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
-    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+static void group_norm_f32_sycl(const float *x, float *dst,
+                                const int num_groups, const int group_size,
+                                const int ne_elements, queue_ptr stream) {
+    static const float eps = 1e-6f;
+    if (group_size < 1024) {
+        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+        stream->submit([&](sycl::handler &cgh) {
+            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
+                                                         cgh);
 
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        /*
-        DPCT1049:34: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
+            const float eps_ct4 = eps;
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q4_K<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_ql_q4_K_acc_ct1.get_pointer(),
-                            tile_x_dm_q4_K_acc_ct1.get_pointer(),
-                            tile_x_sc_q4_K_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        group_norm_f32(
+                            x, dst, group_size, ne_elements, eps_ct4, item_ct1,
+                            s_sum_acc_ct1.get_pointer(), WARP_SIZE);
                     });
-            });
-        }
+        });
     } else {
-        const bool need_check = true;
+        const int work_group_size = get_work_group_size(stream->get_device());
+        const sycl::range<3> block_dims(1, 1, work_group_size);
         /*
-        DPCT1049:35: The work-group size passed to the SYCL kernel may exceed
+        DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q4_K<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_ql_q4_K_acc_ct1.get_pointer(),
-                            tile_x_dm_q4_K_acc_ct1.get_pointer(),
-                            tile_x_sc_q4_K_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
+        stream->submit([&](sycl::handler &cgh) {
+            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
+                                                         cgh);
+
+            const float eps_ct4 = eps;
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        group_norm_f32(x, dst, group_size, ne_elements,
+                                             eps_ct4, item_ct1,
+                                             s_sum_acc_ct1.get_pointer(), work_group_size);
                     });
-            });
-        }
+        });
     }
 }
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
 
-static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
-                                        float *dst, const int ncols_x,
-                                        const int nrows_x, const int ncols_y,
-                                        const int nrows_y, const int nrows_dst,
-                                        queue_ptr stream) try {
+static void concat_f32_sycl(const float *x, const float *y, float *dst,
+                            const int ne0, int ne1, int ne2, int ne02,
+                            queue_ptr stream) {
+    int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
+    sycl::range<3> gridDim(ne2, ne1, num_blocks);
+    stream->parallel_for(
+        sycl::nd_range<3>(gridDim *
+                              sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            concat_f32(x, y, dst, ne0, ne02, item_ct1);
+        });
+}
 
-    int id;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(id = get_current_device_id()));
-    const int compute_capability = ggml_sycl_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= VER_GEN13) {
-        mmq_x  =  MMQ_X_Q5_K_RDNA2;
-        mmq_y  =  MMQ_Y_Q5_K_RDNA2;
-        nwarps = NWARPS_Q5_K_RDNA2;
-    } else if (compute_capability >= VER_GEN12) {
-        mmq_x  =  MMQ_X_Q5_K_RDNA1;
-        mmq_y  =  MMQ_Y_Q5_K_RDNA1;
-        nwarps = NWARPS_Q5_K_RDNA1;
-    } else if (compute_capability >= VER_GEN9) {
-        mmq_x  =  MMQ_X_Q5_K_AMPERE;
-        mmq_y  =  MMQ_Y_Q5_K_AMPERE;
-        nwarps = NWARPS_Q5_K_AMPERE;
-    } else if (compute_capability >= VER_4VEC) {
-        mmq_x  =  MMQ_X_Q5_K_PASCAL;
-        mmq_y  =  MMQ_Y_Q5_K_PASCAL;
-        nwarps = NWARPS_Q5_K_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
+static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
+                             const int nb02, const int nb03, const int ne10, const int ne11,
+                             const int ne12, const int ne13, const float sf0, const float sf1,
+                             const float sf2, const float sf3, queue_ptr stream) {
+    int dst_size = ne10 * ne11 * ne12 * ne13;
+    int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
+    sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
+    stream->parallel_for(
+        sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
+        [=](sycl::nd_item<1> item_ct1) {
+            upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
+        });
+}
 
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
-    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+static void pad_f32_sycl(const float *x, float *dst, const int ne00,
+                         const int ne01, const int ne02, const int ne0,
+                         const int ne1, const int ne2, queue_ptr stream) {
+    int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
+    sycl::range<3> gridDim(ne2, ne1, num_blocks);
+    stream->parallel_for(
+        sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1);
+        });
+}
 
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        /*
-        DPCT1049:36: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
+static void rms_norm_f32_sycl(const float *x, float *dst, const int ncols,
+                              const int nrows, const float eps,
+                              queue_ptr stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
+    if (ncols < 1024) {
+        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+        stream->submit([&](sycl::handler &cgh) {
+            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
+                                                         cgh);
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q5_K<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_ql_q5_K_acc_ct1.get_pointer(),
-                            tile_x_dm_q5_K_acc_ct1.get_pointer(),
-                            tile_x_sc_q5_K_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
+                                                s_sum_acc_ct1.get_pointer(), WARP_SIZE);
                     });
-            });
-        }
+        });
     } else {
-        const bool need_check = true;
+        const int work_group_size = get_work_group_size(stream->get_device());
+        const sycl::range<3> block_dims(1, 1, work_group_size);
         /*
-        DPCT1049:37: The work-group size passed to the SYCL kernel may exceed
+        DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
+        stream->submit([&](sycl::handler &cgh) {
+            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
+                                                         cgh);
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q5_K<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_ql_q5_K_acc_ct1.get_pointer(),
-                            tile_x_dm_q5_K_acc_ct1.get_pointer(),
-                            tile_x_sc_q5_K_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
+                                           s_sum_acc_ct1.get_pointer(), work_group_size);
                     });
-            });
-        }
+        });
     }
 }
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
-                                        float *dst, const int ncols_x,
-                                        const int nrows_x, const int ncols_y,
-                                        const int nrows_y, const int nrows_dst,
-                                        queue_ptr stream) try {
-
-    int id;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(id = get_current_device_id()));
-    const int compute_capability = ggml_sycl_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= VER_GEN13) {
-        mmq_x  =  MMQ_X_Q6_K_RDNA2;
-        mmq_y  =  MMQ_Y_Q6_K_RDNA2;
-        nwarps = NWARPS_Q6_K_RDNA2;
-    } else if (compute_capability >= VER_GEN12) {
-        mmq_x  =  MMQ_X_Q6_K_RDNA1;
-        mmq_y  =  MMQ_Y_Q6_K_RDNA1;
-        nwarps = NWARPS_Q6_K_RDNA1;
-    } else if (compute_capability >= VER_GEN9) {
-        mmq_x  =  MMQ_X_Q6_K_AMPERE;
-        mmq_y  =  MMQ_Y_Q6_K_AMPERE;
-        nwarps = NWARPS_Q6_K_AMPERE;
-    } else if (compute_capability >= VER_4VEC) {
-        mmq_x  =  MMQ_X_Q6_K_PASCAL;
-        mmq_y  =  MMQ_Y_Q6_K_PASCAL;
-        nwarps = NWARPS_Q6_K_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
-    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
-
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        /*
-        DPCT1049:38: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
-                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q6_K<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_ql_acc_ct1.get_pointer(),
-                            tile_x_dm_acc_ct1.get_pointer(),
-                            tile_x_sc_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
-                    });
-            });
-        }
-    } else {
-        const bool need_check = true;
-        /*
-        DPCT1049:39: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(stream->get_device(),
-                                         {sycl::aspect::fp16});
+static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
+                                   const int ky, const int kx_padded,
+                                   queue_ptr stream) {
+    const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
+    const sycl::range<3> num_blocks(1, ky, block_num_x);
+    const sycl::range<3> block_size(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE);
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
 
-            stream->submit([&](sycl::handler &cgh) {
-                sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
-                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
-                    cgh);
-                sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
-                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
-                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
-                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
-                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
-
-                cgh.parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        mul_mat_q6_K<need_check>(
-                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
-                            nrows_dst, item_ct1,
-                            tile_x_ql_acc_ct1.get_pointer(),
-                            tile_x_dm_acc_ct1.get_pointer(),
-                            tile_x_sc_acc_ct1.get_pointer(),
-                            tile_y_qs_acc_ct1.get_pointer(),
-                            tile_y_ds_acc_ct1.get_pointer());
-                    });
+        stream->parallel_for(
+            sycl::nd_range<3>(num_blocks * block_size, block_size),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+                quantize_q8_1(x, vy, kx, kx_padded, item_ct1);
             });
-        }
     }
 }
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
 
 static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
                                            float *dst, const int ncols_x,
@@ -9187,7 +2446,7 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
                               const int nrows_y, const float scale, const float max_bias,
                               queue_ptr stream) {
     int nth = WARP_SIZE;
-    int max_block_size = GROUP_SIZE;
+    int max_block_size = get_work_group_size(stream->get_device());
     while (nth < ncols_x && nth < max_block_size) nth *= 2;
     if (nth>max_block_size) nth = max_block_size;
 
@@ -9339,7 +2598,7 @@ void ggml_backend_sycl_print_sycl_devices() {
     }
 }
 
-int get_sycl_env(const char *env_name, int default_val) {
+static inline int get_sycl_env(const char *env_name, int default_val) {
     char *user_device_string = getenv(env_name);
     int user_number = default_val;
 
@@ -9353,10 +2612,9 @@ int get_sycl_env(const char *env_name, int default_val) {
     return user_number;
 }
 
-int get_work_group_size(int user_device_id) {
+static inline int get_work_group_size(const sycl::device& device) {
     dpct::device_info prop;
-    dpct::get_device_info(prop,
-                          dpct::dev_mgr::instance().get_device(user_device_id));
+    dpct::get_device_info(prop, device);
     return prop.get_max_work_group_size();
 }
 
@@ -10042,76 +3300,6 @@ inline void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, const ggml_te
     (void) src1_dd;
 }
 
-inline void ggml_sycl_op_mul_mat_q(
-    ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
-    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
-    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
-    const int64_t src1_ncols, const int64_t src1_padded_row_size,
-    const queue_ptr &stream) try {
-
-    const int64_t ne00 = src0->ne[0];
-
-    const int64_t ne10 = src1->ne[0];
-    GGML_ASSERT(ne10 % QK8_1 == 0);
-
-    const int64_t ne0 = dst->ne[0];
-
-    const int64_t row_diff = row_high - row_low;
-
-    int device_id;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(device_id = get_current_device_id()));
-
-    // the main device has a larger memory buffer to hold the results from all GPUs
-    // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
-    const int64_t nrows_dst = device_id == ctx.device ? ne0 : row_diff;
-
-    switch (src0->type) {
-        case GGML_TYPE_Q4_0:
-            ggml_mul_mat_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q4_1:
-            ggml_mul_mat_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q5_0:
-            ggml_mul_mat_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q5_1:
-            ggml_mul_mat_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q8_0:
-            ggml_mul_mat_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q2_K:
-            ggml_mul_mat_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q3_K:
-            ggml_mul_mat_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q4_K:
-            ggml_mul_mat_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q5_K:
-            ggml_mul_mat_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q6_K:
-            ggml_mul_mat_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
-
-    (void) src1;
-    (void) dst;
-    (void) src1_ddf_i;
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
 static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split) {
     int64_t min_compute_capability = INT_MAX;
     int64_t max_compute_capability = INT_MIN;
@@ -10160,179 +3348,6 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYC
 
 }
 
-inline void ggml_sycl_op_mul_mat_vec_q(
-    ggml_backend_sycl_context & ctx,
-    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
-    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
-    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
-    const int64_t src1_ncols, const int64_t src1_padded_row_size,
-    const queue_ptr &stream) {
-
-    const int64_t ne10 = src1->ne[0];
-    GGML_ASSERT(ne10 % QK8_1 == 0);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t row_diff = row_high - row_low;
-
-    int id;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(id = get_current_device_id()));
-
-    // the main device has a larger memory buffer to hold the results from all GPUs
-    // nrows_dst == nrows of the matrix that the kernel writes into
-    const int64_t nrows_dst = id == ctx.device ? ne00 : row_diff;
-
-    switch (src0->type) {
-        case GGML_TYPE_Q4_0:
-            mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q4_1:
-            mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q5_0:
-            mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q5_1:
-            mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q8_0:
-            mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q2_K:
-            mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q3_K:
-            mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q4_K:
-            mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q5_K:
-            mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q6_K:
-            mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_IQ1_S:
-            mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_IQ1_M:
-            mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_IQ2_XXS:
-            mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_IQ2_XS:
-            mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_IQ2_S:
-            mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_IQ3_XXS:
-            mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_IQ3_S:
-            mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_IQ4_NL:
-            mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_IQ4_XS:
-            mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
-
-    (void) src1;
-    (void) dst;
-    (void) src1_ddf_i;
-    (void) src1_ncols;
-    (void) src1_padded_row_size;
-}
-
-
-inline void ggml_sycl_op_dequantize_mul_mat_vec(
-    ggml_backend_sycl_context & ctx,
-    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
-    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
-    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
-    const int64_t src1_ncols, const int64_t src1_padded_row_size,
-    const queue_ptr &stream) {
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t row_diff = row_high - row_low;
-
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
-#ifdef GGML_SYCL_F16
-    ggml_sycl_pool_alloc<sycl::half> src1_dfloat_a(ctx.pool());
-    sycl::half *src1_dfloat = nullptr; // dfloat == half
-
-    bool src1_convert_f16 =
-        src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
-        src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
-        src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
-
-    if (src1_convert_f16) {
-        src1_dfloat = src1_dfloat_a.alloc(ne00);
-        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
-        GGML_ASSERT(to_fp16_sycl != nullptr);
-        to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream);
-    }
-#else
-    const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
-#endif // GGML_SYCL_F16
-
-    switch (src0->type) {
-        case GGML_TYPE_Q4_0:
-            dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q4_1:
-            dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q5_0:
-            dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q5_1:
-            dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q8_0:
-            dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q2_K:
-            dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q3_K:
-            dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q4_K:
-            dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q5_K:
-            dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_Q6_K:
-            dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
-            break;
-        case GGML_TYPE_F16:
-            convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
-            break;
-        default:
-            printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
-            GGML_ASSERT(false);
-            break;
-    }
-
-    (void) src1;
-    (void) dst;
-    (void) src1_ddq_i;
-    (void) src1_ncols;
-    (void) src1_padded_row_size;
-}
-
 inline void ggml_sycl_op_mul_mat_sycl(
     ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -11606,7 +4621,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
@@ -11897,7 +4912,7 @@ static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
     GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
     GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS01;
 
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     queue_ptr main_stream = ctx.stream();
diff --git a/src/ggml-sycl.h b/src/ggml-sycl.h
deleted file mode 100644 (file)
index 451938f..0000000
+++ /dev/null
@@ -1,40 +0,0 @@
-//
-//  MIT license
-//  Copyright (C) 2024 Intel Corporation
-//  SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "ggml.h"
-#include "ggml-backend.h"
-#include "ggml-sycl/presets.hpp"
-
-#ifdef  __cplusplus
-extern "C" {
-#endif
-
-// backend API
-GGML_API ggml_backend_t ggml_backend_sycl_init(int device);
-
-// devide buffer
-GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
-
-// split tensor buffer that splits matrices by rows across multiple devices
-GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
-
-// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
-GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
-
-GGML_API void   ggml_backend_sycl_print_sycl_devices(void);
-GGML_API GGML_CALL void   ggml_sycl_get_gpu_list(int *id_list, int max_len);
-GGML_API GGML_CALL void   ggml_sycl_get_device_description(int device, char *description, size_t description_size);
-GGML_API GGML_CALL int   ggml_backend_sycl_get_device_count();
-GGML_API GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
-
-// SYCL doesn't support registering host memory, keep here for reference
-// GGML_API GGML_CALL bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
-// GGML_API GGML_CALL void ggml_backend_sycl_unregister_host_buffer(void * buffer);
-#ifdef  __cplusplus
-}
-#endif
index 88bae59678bdd82f37d5cd428c1b774c97566651..2d37e271f90508825d6d7fdae56df4927b0b00b2 100644 (file)
 #define GGML_SYCL_BACKEND_HPP
 
 #include "common.hpp"
+#include "convert.hpp"
+#include "dequantize.hpp"
+#include "dmmv.hpp"
+#include "mmq.hpp"
+#include "mmvq.hpp"
 
 #endif // GGML_SYCL_BACKEND_HPP
index 414c37eed0d5dabf1ae5feef05a2e8eb1a1cc66a..e01f91633a4bff52f99b53939b67d41a83e988a2 100644 (file)
@@ -17,6 +17,7 @@
 #include <iostream>
 
 #include "dpct/helper.hpp"
+#include "ggml-sycl.h"
 #include "presets.hpp"
 
 #define GGML_COMMON_DECL_SYCL
diff --git a/src/ggml-sycl/convert.cpp b/src/ggml-sycl/convert.cpp
new file mode 100644 (file)
index 0000000..ce9de2b
--- /dev/null
@@ -0,0 +1,544 @@
+#include "convert.hpp"
+#include "dequantize.hpp"
+#include "presets.hpp"
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
+                             const sycl::nd_item<3> &item_ct1) {
+    const int i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                       item_ct1.get_local_id(2));
+
+    if (i >= k) {
+        return;
+    }
+
+    const int ib = i/qk; // block index
+    const int iqs = (i%qk)/qr; // quant index
+    const int iybs = i - i%qk; // y block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel(vx, ib, iqs, v);
+
+    y[iybs + iqs + 0] = v.x();
+    y[iybs + iqs + y_offset] = v.y();
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_sycl(const void *__restrict__ vx,
+                                  dst_t *__restrict__ y, const int k,
+                                  dpct::queue_ptr stream) {
+    const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+        stream->parallel_for(
+            sycl::nd_range<3>(
+                sycl::range<3>(1, 1, num_blocks) *
+                    sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
+                sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
+            [=](sycl::nd_item<3> item_ct1) {
+                dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
+            });
+    }
+}
+
+template <typename dst_t>
+static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
+                                     dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+#if QK_K == 256
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 64),
+                                               sycl::range<3>(1, 1, 64)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q2_K(vx, y, item_ct1);
+                             });
+    }
+#else
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q2_K(vx, y, item_ct1);
+                             });
+    }
+
+#endif
+}
+
+template <typename dst_t>
+static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
+                                     dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+#if QK_K == 256
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 64),
+                                               sycl::range<3>(1, 1, 64)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q3_K(vx, y, item_ct1);
+                             });
+    }
+#else
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q3_K(vx, y, item_ct1);
+                             });
+    }
+#endif
+}
+
+template <typename dst_t>
+static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
+                                     dpct::queue_ptr stream) {
+    const int nb32 = k / 32;
+    const int nb = (k + 255) / 256;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q4_0(vx, y, nb32, item_ct1);
+                             });
+    }
+}
+
+template <typename dst_t>
+static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
+                                     dpct::queue_ptr stream) {
+    const int nb32 = k / 32;
+    const int nb = (k + 255) / 256;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q4_1(vx, y, nb32, item_ct1);
+                             });
+    }
+}
+
+
+template <typename dst_t>
+static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
+                                     dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q4_K(vx, y, item_ct1);
+                             });
+    }
+}
+
+template <typename dst_t>
+static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
+                                     dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+#if QK_K == 256
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 64),
+                                               sycl::range<3>(1, 1, 64)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q5_K(vx, y, item_ct1);
+                             });
+    }
+#else
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q5_K(vx, y, item_ct1);
+                             });
+    }
+
+#endif
+}
+
+template <typename dst_t>
+static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
+                                     dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+#if QK_K == 256
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 64),
+                                               sycl::range<3>(1, 1, 64)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q6_K(vx, y, item_ct1);
+                             });
+    }
+#else
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q6_K(vx, y, item_ct1);
+                             });
+    }
+
+#endif
+}
+
+template <typename dst_t>
+static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
+                                        dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq1_s(
+                                     vx, y, item_ct1, iq1s_grid_gpu
+                                     );
+                             });
+        });
+    }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
+                                        dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq1_m(
+                                     vx, y, item_ct1, iq1s_grid_gpu
+                                     );
+                             });
+        });
+    }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
+                                        dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq2_xxs(
+                                     vx, y, item_ct1, iq2xxs_grid,
+                                     ksigns_iq2xs, kmask_iq2xs);
+                             });
+        });
+    }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
+                                       dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq2_xs(
+                                     vx, y, item_ct1, iq2xs_grid,
+                                     ksigns_iq2xs, kmask_iq2xs);
+                             });
+        });
+    }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
+                                      dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq2_s(vx, y, item_ct1);
+                             });
+        });
+    }
+}
+
+
+template <typename dst_t>
+static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
+                                        dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq3_xxs(
+                                     vx, y, item_ct1, iq3xxs_grid,
+                                     ksigns_iq2xs, kmask_iq2xs);
+                             });
+        });
+    }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
+                                        dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq3_s(
+                                     vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
+                             });
+        });
+    }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
+                                       dpct::queue_ptr stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+#if QK_K == 64
+    dequantize_row_iq4_nl_sycl(vx, y, k, stream);
+#else
+      {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                  cgh.parallel_for(
+                      sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                            sycl::range<3>(1, 1, 32),
+                                        sycl::range<3>(1, 1, 32)),
+                      [=](sycl::nd_item<3> item_ct1) {
+                            dequantize_block_iq4_xs(vx, y, item_ct1);
+                      });
+            });
+      }
+#endif
+}
+
+template <typename dst_t>
+static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
+                                       dpct::queue_ptr stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+      {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                  cgh.parallel_for(
+                      sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                            sycl::range<3>(1, 1, 32),
+                                        sycl::range<3>(1, 1, 32)),
+                      [=](sycl::nd_item<3> item_ct1) {
+                            dequantize_block_iq4_nl(vx, y, item_ct1);
+                      });
+            });
+      }
+}
+
+template <typename src_t, typename dst_t>
+static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
+                          const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+
+    const src_t * x = (src_t *) vx;
+
+    y[i] = x[i];
+}
+
+template <typename src_t, typename dst_t>
+static void convert_unary_sycl(const void *__restrict__ vx,
+                               dst_t *__restrict__ y, const int k,
+                               dpct::queue_ptr stream) {
+    const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(
+                sycl::range<3>(1, 1, num_blocks) *
+                    sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
+                sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
+            [=](sycl::nd_item<3> item_ct1) {
+                convert_unary<src_t>(vx, y, k, item_ct1);
+            });
+    }
+}
+
+to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_sycl<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_sycl;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_sycl;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_sycl;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_sycl;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_sycl;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_sycl;
+        case GGML_TYPE_IQ1_M:
+            return dequantize_row_iq1_m_sycl;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_sycl;
+        case GGML_TYPE_IQ2_XS:
+            return dequantize_row_iq2_xs_sycl;
+        case GGML_TYPE_IQ2_S:
+            return dequantize_row_iq2_s_sycl;
+        case GGML_TYPE_IQ3_XXS:
+            return dequantize_row_iq3_xxs_sycl;
+        case GGML_TYPE_IQ3_S:
+            return dequantize_row_iq3_s_sycl;
+        case GGML_TYPE_IQ4_XS:
+            return dequantize_row_iq4_xs_sycl;
+        case GGML_TYPE_IQ4_NL:
+            return dequantize_row_iq4_nl_sycl;
+        case GGML_TYPE_F32:
+            return convert_unary_sycl<float>;
+        default:
+            return nullptr;
+    }
+}
+
+to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_sycl;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_sycl;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_sycl;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_sycl;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_sycl;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_sycl;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_sycl;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_sycl;
+        case GGML_TYPE_IQ1_M:
+            return dequantize_row_iq1_m_sycl;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_sycl;
+        case GGML_TYPE_IQ2_XS:
+            return dequantize_row_iq2_xs_sycl;
+        case GGML_TYPE_IQ2_S:
+            return dequantize_row_iq2_s_sycl;
+        case GGML_TYPE_IQ3_XXS:
+            return dequantize_row_iq3_xxs_sycl;
+        case GGML_TYPE_IQ3_S:
+            return dequantize_row_iq3_s_sycl;
+        case GGML_TYPE_IQ4_XS:
+            return dequantize_row_iq4_xs_sycl;
+        case GGML_TYPE_IQ4_NL:
+            return dequantize_row_iq4_nl_sycl;
+        case GGML_TYPE_F16:
+            return convert_unary_sycl<sycl::half>;
+        default:
+            return nullptr;
+    }
+}
diff --git a/src/ggml-sycl/convert.hpp b/src/ggml-sycl/convert.hpp
new file mode 100644 (file)
index 0000000..b1f10d6
--- /dev/null
@@ -0,0 +1,27 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_CONVERT_HPP
+#define GGML_SYCL_CONVERT_HPP
+
+#include "common.hpp"
+
+template <typename T>
+using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
+                             int k, dpct::queue_ptr stream);
+typedef to_t_sycl_t<float> to_fp32_sycl_t;
+typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
+
+to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type);
+to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type);
+
+#endif // GGML_SYCL_CONVERT_HPP
diff --git a/src/ggml-sycl/dequantize.hpp b/src/ggml-sycl/dequantize.hpp
new file mode 100644 (file)
index 0000000..b6080d8
--- /dev/null
@@ -0,0 +1,690 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_DEQUANTIZE_HPP
+#define GGML_SYCL_DEQUANTIZE_HPP
+
+#include "common.hpp"
+
+typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
+
+static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
+                                            const int iqs, dfloat2 &v) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    const int vui = x[ib].qs[iqs];
+
+    v.x() = vui & 0xF;
+    v.y() = vui >> 4;
+
+#ifdef GGML_SYCL_F16
+    // v = v - {8.0f, 8.0f};
+    // v = v * {d, d};
+    v.s0() = (v.s0() - 8.0f) * d;
+    v.s1() = (v.s1() - 8.0f) * d;
+
+#else
+    v.x() = (v.x() - 8.0f) * d;
+    v.y() = (v.y() - 8.0f) * d;
+#endif // GGML_SYCL_F16
+}
+
+static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
+                                            const int iqs, dfloat2 &v) {
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const dfloat d = x[ib].dm[0];
+    const dfloat m = x[ib].dm[1];
+
+    const int vui = x[ib].qs[iqs];
+
+    v.x() = vui & 0xF;
+    v.y() = vui >> 4;
+
+#ifdef GGML_SYCL_F16
+    // v = v * {d, d};
+    // v = v + {m, m};
+    v.s0() = (v.s0() * d) + m;
+    v.s1() = (v.s1() * d) + m;
+
+#else
+    v.x() = (v.x() * d) + m;
+    v.y() = (v.y() * d) + m;
+#endif // GGML_SYCL_F16
+}
+
+static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
+                                            const int iqs, dfloat2 &v) {
+    const block_q5_0 * x = (const block_q5_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    uint32_t qh;
+    memcpy(&qh, x[ib].qh, sizeof(qh));
+
+    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+
+    v.x() = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y() = ((x[ib].qs[iqs] >> 4) | xh_1);
+
+#ifdef GGML_SYCL_F16
+    // v = v - {16.0f, 16.0f};
+    // v = v * {d, d};
+    v.s0() = (v.s0() - 16.0f) * d;
+    v.s1() = (v.s1() - 16.0f) * d;
+
+#else
+    v.x() = (v.x() - 16.0f) * d;
+    v.y() = (v.y() - 16.0f) * d;
+#endif // GGML_SYCL_F16
+}
+
+static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
+                                            const int iqs, dfloat2 &v) {
+    const block_q5_1 * x = (const block_q5_1 *) vx;
+
+    const dfloat d = x[ib].dm[0];
+    const dfloat m = x[ib].dm[1];
+
+    uint32_t qh;
+    memcpy(&qh, x[ib].qh, sizeof(qh));
+
+    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+
+    v.x() = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y() = ((x[ib].qs[iqs] >> 4) | xh_1);
+
+#ifdef GGML_SYCL_F16
+    // v = v * {d, d};
+    // v = v + {m, m};
+    v.s0() = (v.s0() * d) + m;
+    v.s1() = (v.s1() * d) + m;
+#else
+    v.x() = (v.x() * d) + m;
+    v.y() = (v.y() * d) + m;
+#endif // GGML_SYCL_F16
+}
+
+static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
+                                            const int iqs, dfloat2 &v) {
+    const block_q8_0 * x = (const block_q8_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    v.x() = x[ib].qs[iqs + 0];
+    v.y() = x[ib].qs[iqs + 1];
+
+#ifdef GGML_SYCL_F16
+    // v = v * {d, d};
+    v.s0() *= d;
+    v.s1() *= d;
+#else
+    v.x() *= d;
+    v.y() *= d;
+#endif // GGML_SYCL_F16
+}
+
+template<typename dst_t>
+static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
+                                  const sycl::nd_item<3> &item_ct1) {
+
+    const int i = item_ct1.get_group(2);
+
+    // assume 32 threads
+    const int tid = item_ct1.get_local_id(2);
+    const int il  = tid/8;
+    const int ir  = tid%8;
+    const int ib = 8*i + ir;
+    if (ib >= nb32) {
+        return;
+    }
+
+    dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+    const block_q4_0 * x = (const block_q4_0 *)vx + ib;
+    const float d = sycl::vec<sycl::half, 1>(x->d)
+                        .convert<float, sycl::rounding_mode::automatic>()[0];
+    const float dm = -8*d;
+
+    const uint8_t * q = x->qs + 4*il;
+
+    for (int l = 0; l < 4; ++l) {
+        y[l+ 0] = d * (q[l] & 0xF) + dm;
+        y[l+16] = d * (q[l] >>  4) + dm;
+    }
+}
+
+template<typename dst_t>
+static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
+                                  const sycl::nd_item<3> &item_ct1) {
+
+    const int i = item_ct1.get_group(2);
+
+    // assume 32 threads
+    const int tid = item_ct1.get_local_id(2);
+    const int il  = tid/8;
+    const int ir  = tid%8;
+    const int ib = 8*i + ir;
+    if (ib >= nb32) {
+        return;
+    }
+
+    dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+    const block_q4_1 * x = (const block_q4_1 *)vx + ib;
+    const sycl::float2 d =
+        x->dm.convert<float, sycl::rounding_mode::automatic>();
+
+    const uint8_t * q = x->qs + 4*il;
+
+    for (int l = 0; l < 4; ++l) {
+        y[l + 0] = d.x() * (q[l] & 0xF) + d.y();
+        y[l + 16] = d.x() * (q[l] >> 4) + d.y();
+    }
+}
+
+
+//================================== k-quants
+
+template<typename dst_t>
+static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                  const sycl::nd_item<3> &item_ct1) {
+
+    const int i = item_ct1.get_group(2);
+    const block_q2_K * x = (const block_q2_K *) vx;
+
+    const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+    const int n   = tid/32;
+    const int l   = tid - 32*n;
+    const int is  = 8*n + l/16;
+
+    const uint8_t q = x[i].qs[32*n + l];
+    dst_t * y = yy + i*QK_K + 128*n;
+
+    float dall = x[i].dm[0];
+    float dmin = x[i].dm[1];
+    y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
+    y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
+    y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+#else
+    const int is = tid/16;  // 0 or 1
+    const int il = tid%16;  // 0...15
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    dst_t * y = yy + i*QK_K + 16*is + il;
+
+    float dall = x[i].dm[0];
+    float dmin = x[i].dm[1];
+    y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
+#endif
+
+}
+
+template<typename dst_t>
+static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                  const sycl::nd_item<3> &item_ct1) {
+
+    const int i = item_ct1.get_group(2);
+    const block_q3_K * x = (const block_q3_K *) vx;
+
+#if QK_K == 256
+    const int r = item_ct1.get_local_id(2) / 4;
+    const int tid = r/2;
+    const int is0 = r%2;
+    const int l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
+    const int n = tid / 4;
+    const int j = tid - 4*n;
+
+    uint8_t m = 1 << (4*n + j);
+    int is = 8*n + 2*j + is0;
+    int shift = 2*j;
+
+    int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
+                is <  8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
+                is < 12 ? (x[i].scales[is-8] >>  4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
+                          (x[i].scales[is-8] >>  4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
+    float d_all = x[i].d;
+    float dl = d_all * (us - 32);
+
+    dst_t * y = yy + i*QK_K + 128*n + 32*j;
+    const uint8_t * q = x[i].qs + 32*n;
+    const uint8_t * hm = x[i].hmask;
+
+    for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+#else
+    const int tid = item_ct1.get_local_id(2);
+    const int is  = tid/16;  // 0 or 1
+    const int il  = tid%16;  // 0...15
+    const int im  = il/8;    // 0...1
+    const int in  = il%8;    // 0...7
+
+    dst_t * y = yy + i*QK_K + 16*is + il;
+
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    const uint8_t h = x[i].hmask[in] >> (2*is + im);
+    const float   d = (float)x[i].d;
+
+    if (is == 0) {
+        y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    } else {
+        y[ 0] = d * ((x[i].scales[0] >>  4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] >>  4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    }
+#endif
+
+}
+
+#if QK_K == 256
+static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
+    if (j < 4) {
+        d = q[j] & 63; m = q[j + 4] & 63;
+    } else {
+        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    }
+}
+#endif
+
+template<typename dst_t>
+static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                  const sycl::nd_item<3> &item_ct1) {
+    const block_q4_K * x = (const block_q4_K *) vx;
+
+    const int i = item_ct1.get_group(2);
+
+#if QK_K == 256
+    // assume 32 threads
+    const int tid = item_ct1.get_local_id(2);
+    const int il  = tid/8;
+    const int ir  = tid%8;
+    const int is  = 2*il;
+    const int n   = 4;
+
+    dst_t * y = yy + i*QK_K + 64*il + n*ir;
+
+    const float dall = x[i].dm[0];
+    const float dmin = x[i].dm[1];
+
+    const uint8_t * q = x[i].qs + 32*il + n*ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, sc, m);
+    const float d1 = dall * sc; const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, sc, m);
+    const float d2 = dall * sc; const float m2 = dmin * m;
+    for (int l = 0; l < n; ++l) {
+        y[l + 0] = d1 * (q[l] & 0xF) - m1;
+        y[l +32] = d2 * (q[l] >>  4) - m2;
+    }
+#else
+    const int tid = item_ct1.get_local_id(2);
+    const uint8_t * q = x[i].qs;
+    dst_t * y = yy + i*QK_K;
+    const float d = (float)x[i].dm[0];
+    const float m = (float)x[i].dm[1];
+    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
+    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >>  4) - m * (x[i].scales[1] >> 4);
+#endif
+}
+
+template<typename dst_t>
+static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                  const sycl::nd_item<3> &item_ct1) {
+    const block_q5_K * x = (const block_q5_K *) vx;
+
+    const int i = item_ct1.get_group(2);
+
+#if QK_K == 256
+    // assume 64 threads - this is very slightly better than the one below
+    const int tid = item_ct1.get_local_id(2);
+    const int il  = tid/16;   // il is in 0...3
+    const int ir  = tid%16;   // ir is in 0...15
+    const int is  = 2*il;     // is is in 0...6
+
+    dst_t * y = yy + i*QK_K + 64*il + 2*ir;
+
+    const float dall = x[i].dm[0];
+    const float dmin = x[i].dm[1];
+
+    const uint8_t * ql = x[i].qs + 32*il + 2*ir;
+    const uint8_t * qh = x[i].qh + 2*ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, sc, m);
+    const float d1 = dall * sc; const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, sc, m);
+    const float d2 = dall * sc; const float m2 = dmin * m;
+
+    uint8_t   hm  = 1 << (2*il);
+    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
+    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
+    hm <<= 1;
+    y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
+    y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+#else
+    const int tid = item_ct1.get_local_id(2);
+    const uint8_t q = x[i].qs[tid];
+    const int im = tid/8;  // 0...3
+    const int in = tid%8;  // 0...7
+    const int is = tid/16; // 0 or 1
+    const uint8_t h = x[i].qh[in] >> im;
+    const float d = x[i].d;
+    dst_t * y = yy + i*QK_K + tid;
+    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
+    y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
+#endif
+}
+
+template<typename dst_t>
+static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                  const sycl::nd_item<3> &item_ct1) {
+    const block_q6_K * x = (const block_q6_K *) vx;
+
+    const int i = item_ct1.get_group(2);
+#if QK_K == 256
+
+    // assume 64 threads - this is very slightly better than the one below
+    const int tid = item_ct1.get_local_id(2);
+    const int ip  = tid/32;   // ip is 0 or 1
+    const int il  = tid - 32*ip; // 0...32
+    const int is  = 8*ip + il/16;
+
+    dst_t * y = yy + i*QK_K + 128*ip + il;
+
+    const float d = x[i].d;
+
+    const uint8_t * ql = x[i].ql + 64*ip + il;
+    const uint8_t   qh = x[i].qh[32*ip + il];
+    const int8_t  * sc = x[i].scales + is;
+
+    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
+    y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+    y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+#else
+
+    // assume 32 threads
+    const int tid = item_ct1.get_local_id(2);
+    const int ip  = tid/16;         // 0 or 1
+    const int il  = tid - 16*ip;    // 0...15
+
+    dst_t * y = yy + i*QK_K + 16*ip + il;
+
+    const float d = x[i].d;
+
+    const uint8_t   ql = x[i].ql[16*ip + il];
+    const uint8_t   qh = x[i].qh[il] >> (2*ip);
+    const int8_t  * sc = x[i].scales;
+
+    y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[ip+2] * ((int8_t)((ql  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+#endif
+}
+
+template<typename dst_t>
+static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                     const sycl::nd_item<3> &item_ct1,
+                                     const uint64_t *iq2xxs_grid_ptr,
+                                     const uint8_t *ksigns_iq2xs_ptr,
+                                     const uint8_t *kmask_iq2xs_ptr) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq2_xxs * x = (const block_iq2_xxs  *) vx;
+
+    const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * q2 = x[i].qs + 4*ib;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid_ptr + aux8[il]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
+    const uint8_t signs = ksigns_iq2xs_ptr[(aux32 >> 7*il) & 127];
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs_ptr[j] ? -1.f : 1.f);
+#else
+    assert(false);
+#endif
+
+}
+
+template<typename dst_t>
+static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                    const sycl::nd_item<3> &item_ct1,
+                                    const uint64_t *iq2xs_grid,
+                                    const uint8_t *ksigns_iq2xs,
+                                    const uint8_t *kmask_iq2xs) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq2_xs * x = (const block_iq2_xs *) vx;
+
+    const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * q2 = x[i].qs + 4*ib;
+    const uint8_t  * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
+    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+    const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+#else
+    assert(false);
+#endif
+
+}
+
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
+                       const sycl::nd_item<3> &item_ct1) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq2_s * x = (const block_iq2_s *) vx;
+
+    const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
+    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+    const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
+#pragma unroll
+    for (int j = 0; j < 8; ++j)
+        y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+#else
+    assert(false);
+
+#endif
+
+}
+
+template<typename dst_t>
+static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                     const sycl::nd_item<3> &item_ct1,
+                                     const uint32_t *iq3xxs_grid,
+                                     const uint8_t *ksigns_iq2xs,
+                                     const uint8_t *kmask_iq2xs) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq3_xxs * x = (const block_iq3_xxs  *) vx;
+
+    const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t  * q3 = x[i].qs + 8*ib;
+    const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
+    const uint8_t  * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
+    const uint8_t  * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
+    const uint32_t aux32 = gas[0] | (gas[1] << 16);
+    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
+    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+    for (int j = 0; j < 4; ++j) {
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+    }
+#else
+    assert(false);
+#endif
+
+}
+
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
+                       const sycl::nd_item<3> &item_ct1,
+                       const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq3_s * x = (const block_iq3_s *) vx;
+
+    const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t * qs = x[i].qs + 8*ib;
+    const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
+    const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
+    const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
+    const uint8_t signs = x[i].signs[4*ib + il];
+#pragma unroll
+    for (int j = 0; j < 4; ++j) {
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+    }
+#else
+    assert(false);
+#endif
+
+}
+
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
+                       const sycl::nd_item<3> &item_ct1,
+                       const uint32_t *iq1s_grid_gpu) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq1_s * x = (const block_iq1_s  *) vx;
+
+    const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
+    const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
+    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
+    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+    grid32[0] &= 0x0f0f0f0f;
+#pragma unroll
+    for (int j = 0; j < 8; ++j) {
+        y[j] = d * (q[j] + delta);
+    }
+#else
+    assert(false);
+#endif
+
+}
+
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
+                       const sycl::nd_item<3> &item_ct1,
+                       const uint32_t *iq1s_grid_gpu) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq1_m * x = (const block_iq1_m  *) vx;
+
+    const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * sc = (const uint16_t *)x[i].scales;
+    iq1m_scale_t scale;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
+    const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
+    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
+    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
+    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+    grid32[0] &= 0x0f0f0f0f;
+#pragma unroll
+    for (int j = 0; j < 8; ++j) {
+        y[j] = d * (q[j] + delta);
+    }
+#else
+    assert(false);
+#endif
+
+}
+
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
+                        const sycl::nd_item<3> &item_ct1) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
+
+    const int tid = item_ct1.get_local_id(2);
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t  * q4 = x[ib].qs + 4*il;
+    const float d = (float)x[ib].d;
+#pragma unroll
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
+    }
+
+}
+
+
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
+                        const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_group(2);
+    const block_iq4_xs * x = (const block_iq4_xs *)vx;
+
+    const int tid = item_ct1.get_local_id(2);
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t  * q4 = x[i].qs + 16*ib + 4*il;
+    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
+#pragma unroll
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
+    }
+}
+
+
+#endif // GGML_SYCL_DEQUANTIZE_HPP
diff --git a/src/ggml-sycl/dmmv.cpp b/src/ggml-sycl/dmmv.cpp
new file mode 100644 (file)
index 0000000..3a87d3e
--- /dev/null
@@ -0,0 +1,1022 @@
+#include "convert.hpp"
+#include "dmmv.hpp"
+#include "dequantize.hpp"
+#include "presets.hpp"
+
+static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const sycl::half *x = (const sycl::half *)vx;
+
+    // automatic half -> float type cast if dfloat == float
+    v.x() = x[ib + iqs + 0];
+    v.y() = x[ib + iqs + 1];
+}
+
+static void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const float * x = (const float *) vx;
+
+    // automatic half -> float type cast if dfloat == float
+    v.x() = x[ib + iqs + 0];
+    v.y() = x[ib + iqs + 1];
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
+                                   const sycl::nd_item<3> &item_ct1) {
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int tid = item_ct1.get_local_id(2);
+
+    const int iter_stride = 2*GGML_SYCL_DMMV_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+// partial sum for each thread
+#ifdef GGML_SYCL_F16
+    sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
+#else
+    float tmp = 0.0f;
+#endif // GGML_SYCL_F16
+
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
+        const int iybs = col - col%qk; // y block start index
+
+// processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+            dfloat2 v;
+            dequantize_kernel(vx, ib, iqs + j/qr, v);
+
+            // matrix multiplication
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+#ifdef GGML_SYCL_F16
+            dfloat2 t1{y[iybs + iqs + j / qr + 0],
+                        y[iybs + iqs + j / qr + y_offset]};
+
+            tmp += v * t1;
+#else
+            tmp += v.x() * y[iybs + iqs + j / qr + 0];
+            tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
+#endif // GGML_SYCL_F16
+        }
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (tid == 0) {
+#ifdef GGML_SYCL_F16
+        dst[row] = tmp.x() + tmp.y();
+#else
+        dst[row] = tmp;
+#endif // GGML_SYCL_F16
+    }
+}
+
+static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
+                                         float *dst, const int ncols,
+                                         const int nrows,
+                                         dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+                dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
+                                                          nrows, item_ct1);
+            });
+    }
+}
+
+/*
+DPCT1110:4: The total declared local variable size in device function
+dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
+pressure. Consult with your hardware vendor to find the total register size
+available and adjust the code, or use smaller sub-group size to avoid high
+register pressure.
+*/
+static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
+                                        const float *__restrict__ yy,
+                                        float *__restrict__ dst,
+                                        const int ncols, int nrows,
+                                        const sycl::nd_item<3> &item_ct1) {
+
+    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+    if (row > nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q2_K * x = (const block_q2_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+    const int tid =
+        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15
+    const int ix =
+        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+
+    const int step = 16/K_QUANTS_PER_ITERATION;
+
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
+    const int q_offset = 32*im + l0;
+    const int s_offset = 8*im;
+    const int y_offset = 128*im + l0;
+
+    uint32_t aux[4];
+    const uint8_t * d = (const uint8_t *)aux;
+    const uint8_t * m = (const uint8_t *)(aux + 2);
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + y_offset;
+        const uint8_t * q = x[i].qs + q_offset;
+
+        const float dall = x[i].dm[0];
+        const float dmin = x[i].dm[1];
+
+        const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
+        aux[0] = a[0] & 0x0f0f0f0f;
+        aux[1] = a[1] & 0x0f0f0f0f;
+        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+                  +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
+            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
+
+        }
+        tmp += dall * sum1 - dmin * sum2;
+
+    }
+#else
+    const int tid = item_ct1.get_local_id(2) /
+                    (2 * K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+    const int ix = item_ct1.get_local_id(2) %
+                   (2 * K_QUANTS_PER_ITERATION); // 0....1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION;
+
+    uint32_t uaux[2];
+    const uint8_t * d = (const uint8_t *)uaux;
+
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint32_t * s = (const uint32_t *)x[i].scales;
+
+        uaux[0] = s[0] & 0x0f0f0f0f;
+        uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
+
+        const sycl::float2 dall =
+            x[i].dm.convert<float, sycl::rounding_mode::automatic>();
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t ql = q[l];
+            sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+                  + y[l+16] * d[1] * ((ql >> 2) & 3)
+                  + y[l+32] * d[2] * ((ql >> 4) & 3)
+                  + y[l+48] * d[3] * ((ql >> 6) & 3);
+            sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
+        }
+        tmp += dall.x() * sum1 - dall.y() * sum2;
+    }
+
+#endif
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+/*
+DPCT1110:5: The total declared local variable size in device function
+dequantize_mul_mat_vec_q3_k exceeds 128 bytes and may cause high register
+pressure. Consult with your hardware vendor to find the total register size
+available and adjust the code, or use smaller sub-group size to avoid high
+register pressure.
+*/
+static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
+                                        const float *__restrict__ yy,
+                                        float *__restrict__ dst,
+                                        const int ncols, int nrows,
+                                        const sycl::nd_item<3> &item_ct1) {
+
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+    if (row > nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q3_K * x = (const block_q3_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    const int tid =
+        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+    const int ix =
+        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+
+    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
+    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0....15 or 0...7
+
+    const uint8_t m = 1 << (4*im);
+
+    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
+    const int q_offset =  32*im + l0;
+    const int y_offset = 128*im + l0;
+
+    uint16_t utmp[4];
+    const int8_t * s = (const int8_t *)utmp;
+
+    const uint16_t s_shift = 4*im;
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y  = yy + i * QK_K + y_offset;
+        const uint8_t * q = x[i].qs + q_offset;
+        const uint8_t * h = x[i].hmask + l0;
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+        const float d = x[i].d;
+
+        float sum = 0;
+        for (int l = 0; l < n; ++l) {
+            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+                + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+        }
+        tmp += d * sum;
+
+    }
+#else
+
+    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
+    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION;         // 0...15 or 0...14
+    const int in = offset/8;                                 // 0 or 1
+    const int im = offset%8;                                 // 0...7
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint8_t * s = x[i].scales;
+
+        const float dall = (float)x[i].d;
+
+        float sum = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t hl = x[i].hmask[im+l] >> in;
+            const uint8_t ql = q[l];
+            sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+                 + y[l+16] * dall * ((s[0] >>  4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+                 + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+                 + y[l+48] * dall * ((s[1] >>  4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
+        }
+        tmp += sum;
+    }
+#endif
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+/*
+DPCT1110:6: The total declared local variable size in device function
+dequantize_mul_mat_vec_q4_k exceeds 128 bytes and may cause high register
+pressure. Consult with your hardware vendor to find the total register size
+available and adjust the code, or use smaller sub-group size to avoid high
+register pressure.
+*/
+static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
+                                        const float *__restrict__ yy,
+                                        float *__restrict__ dst,
+                                        const int ncols, int nrows,
+                                        const sycl::nd_item<3> &item_ct1) {
+
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+    if (row > nrows) return;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int tid =
+        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+    const int ix =
+        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+
+    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4
+
+    const int il  = tid/step;                            // 0...3
+    const int ir  = tid - step*il;                       // 0...7 or 0...3
+    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+#if K_QUANTS_PER_ITERATION == 2
+    uint32_t q32[4];
+    const uint8_t * q4 = (const uint8_t *)q32;
+#else
+    uint16_t q16[4];
+    const uint8_t * q4 = (const uint8_t *)q16;
+#endif
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y1 = yy + i*QK_K + y_offset;
+        const float   * y2 = y1 + 128;
+
+        const float dall = x[i].dm[0];
+        const float dmin = x[i].dm[1];
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+#if K_QUANTS_PER_ITERATION == 2
+        const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
+        const uint32_t * q2 = q1 + 16;
+
+        q32[0] = q1[0] & 0x0f0f0f0f;
+        q32[1] = q1[0] & 0xf0f0f0f0;
+        q32[2] = q2[0] & 0x0f0f0f0f;
+        q32[3] = q2[0] & 0xf0f0f0f0;
+
+        sycl::float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < 4; ++l) {
+            s.x() += y1[l] * q4[l + 0]; s.y() += y1[l + 32] * q4[l + 4];
+            s.z() += y2[l] * q4[l + 8]; s.w() += y2[l + 32] * q4[l + 12];
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f +
+                       s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) -
+               dmin * smin;
+#else
+        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
+        const uint16_t * q2 = q1 + 32;
+
+        q16[0] = q1[0] & 0x0f0f;
+        q16[1] = q1[0] & 0xf0f0;
+        q16[2] = q2[0] & 0x0f0f;
+        q16[3] = q2[0] & 0xf0f0;
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < 2; ++l) {
+            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
+            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#endif
+
+    }
+#else
+    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15
+    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
+
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
+    float tmp = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const float   * y = yy + i*QK_K + step;
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+        const float d = (float)x[i].dm[0];
+        const float m = (float)x[i].dm[1];
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+                 + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+                 + y[j+32] * (d * s[1] * (q[j+ 0] >>  4) - m * s[3])
+                 + y[j+48] * (d * s[1] * (q[j+16] >>  4) - m * s[3]);
+        }
+        tmp += sum;
+    }
+
+#endif
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
+/*
+DPCT1110:7: The total declared local variable size in device function
+dequantize_mul_mat_vec_q5_k exceeds 128 bytes and may cause high register
+pressure. Consult with your hardware vendor to find the total register size
+available and adjust the code, or use smaller sub-group size to avoid high
+register pressure.
+*/
+static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
+                                        const float *__restrict__ yy,
+                                        float *__restrict__ dst,
+                                        const int ncols,
+                                        const sycl::nd_item<3> &item_ct1) {
+
+    const int row = item_ct1.get_group(2);
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int tid = item_ct1.get_local_id(2) / 2; // 0...15
+    const int ix = item_ct1.get_local_id(2) % 2;
+
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 2;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1  = 1 << (2*im);
+    const uint8_t hm2  = hm1 << 4;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    uint16_t q16[8];
+    const uint8_t * q4 = (const uint8_t *)q16;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+        const uint8_t * ql1 = x[i].qs + q_offset;
+        const uint8_t * qh  = x[i].qh + l0;
+        const float   * y1  = yy + i*QK_K + y_offset;
+        const float   * y2  = y1 + 128;
+
+        const float dall = x[i].dm[0];
+        const float dmin = x[i].dm[1];
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        sycl::float4 sum = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        const uint16_t * q1 = (const uint16_t *)ql1;
+        const uint16_t * q2 = q1 + 32;
+        q16[0] = q1[0] & 0x0f0f;
+        q16[1] = q1[8] & 0x0f0f;
+        q16[2] = (q1[0] >> 4) & 0x0f0f;
+        q16[3] = (q1[8] >> 4) & 0x0f0f;
+        q16[4] = q2[0] & 0x0f0f;
+        q16[5] = q2[8] & 0x0f0f;
+        q16[6] = (q2[0] >> 4) & 0x0f0f;
+        q16[7] = (q2[8] >> 4) & 0x0f0f;
+        for (int l = 0; l < n; ++l) {
+            sum.x() +=
+                y1[l + 0] * (q4[l + 0] + (qh[l + 0] & (hm1 << 0) ? 16 : 0)) +
+                y1[l + 16] * (q4[l + 2] + (qh[l + 16] & (hm1 << 0) ? 16 : 0));
+            sum.y() +=
+                y1[l + 32] * (q4[l + 4] + (qh[l + 0] & (hm1 << 1) ? 16 : 0)) +
+                y1[l + 48] * (q4[l + 6] + (qh[l + 16] & (hm1 << 1) ? 16 : 0));
+            sum.z() +=
+                y2[l + 0] * (q4[l + 8] + (qh[l + 0] & (hm2 << 0) ? 16 : 0)) +
+                y2[l + 16] * (q4[l + 10] + (qh[l + 16] & (hm2 << 0) ? 16 : 0));
+            sum.w() +=
+                y2[l + 32] * (q4[l + 12] + (qh[l + 0] & (hm2 << 1) ? 16 : 0)) +
+                y2[l + 48] * (q4[l + 14] + (qh[l + 16] & (hm2 << 1) ? 16 : 0));
+            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
+        }
+        tmp += dall * (sum.x() * sc[0] + sum.y() * sc[1] + sum.z() * sc[4] +
+                       sum.w() * sc[5]) -
+               dmin * smin;
+    }
+
+#else
+    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15
+    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
+    const int step = tid * K_QUANTS_PER_ITERATION;
+    const int im = step/8;
+    const int in = step%8;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const int8_t  * s = x[i].scales;
+        const float   * y = yy + i*QK_K + step;
+        const float     d = x[i].d;
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            const uint8_t h = x[i].qh[in+j] >> im;
+            sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+                 + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+                 + y[j+32] * d * s[2] * ((q[j+ 0] >>  4) - ((h >> 4) & 1 ? 0 : 16))
+                 + y[j+48] * d * s[3] * ((q[j+16] >>  4) - ((h >> 6) & 1 ? 0 : 16));
+        }
+        tmp += sum;
+    }
+#endif
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows,
+                                        const sycl::nd_item<3> &item_ct1) {
+
+    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+    if (row > nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q6_K * x = (const block_q6_K *)vx + ib0;
+
+#if QK_K == 256
+
+    const int tid =
+        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+    const int ix =
+        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1
+
+    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
+
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
+    const int is = 0;
+#else
+    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
+    const int is = in / 4;
+#endif
+    const int ql_offset = 64*im + l0;
+    const int qh_offset = 32*im + l0;
+    const int s_offset  =  8*im + is;
+    const int y_offset = 128*im + l0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y  = yy + i * QK_K + y_offset;
+        const uint8_t * ql = x[i].ql + ql_offset;
+        const uint8_t * qh = x[i].qh + qh_offset;
+        const int8_t  * s  = x[i].scales + s_offset;
+
+        const float d = x[i].d;
+
+#if K_QUANTS_PER_ITERATION == 1
+        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
+        tmp += sum;
+#else
+        float sum = 0;
+        for (int l = 0; l < 4; ++l) {
+            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+        }
+        tmp += sum;
+#endif
+
+    }
+
+#else
+
+    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...7
+    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);  // 0...3
+
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y  = yy + i * QK_K + step;
+        const uint8_t * ql = x[i].ql + step;
+        const uint8_t * qh = x[i].qh + step;
+        const int8_t  * s  = x[i].scales;
+
+        const float d = x[i+0].d;
+
+        float sum = 0;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+                 + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+                 + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >>  4) | ((qh[j] & 0x30) >> 0)) - 32)
+                 + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >>  4) | ((qh[j] & 0xc0) >> 2)) - 32);
+        }
+        tmp += sum;
+
+    }
+
+#endif
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
+
+static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
+                                             float *dst, const int ncols,
+                                             const int nrows,
+                                             dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+                dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
+                    vx, y, dst, ncols, nrows, item_ct1);
+            });
+    }
+}
+
+static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
+                                             float *dst, const int ncols,
+                                             const int nrows,
+                                             dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+                dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
+                    vx, y, dst, ncols, nrows, item_ct1);
+            });
+    }
+}
+
+static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
+                                             float *dst, const int ncols,
+                                             const int nrows,
+                                             dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+                dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
+                    vx, y, dst, ncols, nrows, item_ct1);
+            });
+    }
+}
+
+static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
+                                             float *dst, const int ncols,
+                                             const int nrows,
+                                             dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+                dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
+                    vx, y, dst, ncols, nrows, item_ct1);
+            });
+    }
+}
+
+static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
+                                             float *dst, const int ncols,
+                                             const int nrows,
+                                             dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+                dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
+                    vx, y, dst, ncols, nrows, item_ct1);
+            });
+    }
+}
+
+static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
+                                             float *dst, const int ncols,
+                                             const int nrows,
+                                             dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, ny, 32);
+    stream->parallel_for(
+        sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+            dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
+        });
+}
+
+static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
+                                             float *dst, const int ncols,
+                                             const int nrows,
+                                             dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, ny, 32);
+    stream->parallel_for(
+        sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+            dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
+        });
+}
+
+static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
+                                             float *dst, const int ncols,
+                                             const int nrows,
+                                             dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, ny, 32);
+    stream->parallel_for(
+        sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+            dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
+        });
+}
+
+static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
+                                             float *dst, const int ncols,
+                                             const int nrows,
+                                             dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const sycl::range<3> block_dims(1, 1, 32);
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+            dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
+        });
+}
+
+static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
+                                             float *dst, const int ncols,
+                                             const int nrows,
+                                             dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, ny, 32);
+    stream->parallel_for(
+        sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+            dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
+        });
+}
+
+void ggml_sycl_op_dequantize_mul_mat_vec(
+    ggml_backend_sycl_context & ctx,
+    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+    const int64_t src1_ncols, const int64_t src1_padded_row_size,
+    const dpct::queue_ptr &stream) {
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+#ifdef GGML_SYCL_F16
+    ggml_sycl_pool_alloc<sycl::half> src1_dfloat_a(ctx.pool());
+    sycl::half *src1_dfloat = nullptr; // dfloat == half
+
+    bool src1_convert_f16 =
+        src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+        src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+        src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+    if (src1_convert_f16) {
+        src1_dfloat = src1_dfloat_a.alloc(ne00);
+        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+        GGML_ASSERT(to_fp16_sycl != nullptr);
+        to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream);
+    }
+#else
+    const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
+#endif // GGML_SYCL_F16
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_F16:
+            convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        default:
+            printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
+            GGML_ASSERT(false);
+            break;
+    }
+
+    (void) src1;
+    (void) dst;
+    (void) src1_ddq_i;
+    (void) src1_ncols;
+    (void) src1_padded_row_size;
+}
diff --git a/src/ggml-sycl/dmmv.hpp b/src/ggml-sycl/dmmv.hpp
new file mode 100644 (file)
index 0000000..bd83735
--- /dev/null
@@ -0,0 +1,27 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_DMMV_HPP
+#define GGML_SYCL_DMMV_HPP
+
+#include "common.hpp"
+
+
+void ggml_sycl_op_dequantize_mul_mat_vec(
+    ggml_backend_sycl_context & ctx,
+    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+    const int64_t src1_ncols, const int64_t src1_padded_row_size,
+    const dpct::queue_ptr &stream);
+
+#endif // GGML_SYCL_DMMV_HPP
index 017fd6ee1326808ee941b91b08468ea75da6403f..1ff297218c6853d8cb88e0d8aaefe45998eb185b 100644 (file)
@@ -588,266 +588,222 @@ namespace dpct
         out = prop;
     }
 
-    /// dpct device extension
-    class device_ext : public sycl::device
-    {
-        typedef std::mutex mutex_type;
-
-    public:
-        device_ext() : sycl::device(), _ctx(*this) {}
-        ~device_ext()
-        {
-            std::lock_guard<mutex_type> lock(m_mutex);
-            clear_queues();
-        }
-        device_ext(const sycl::device &base) : sycl::device(base), _ctx(*this)
-        {
-            std::lock_guard<mutex_type> lock(m_mutex);
-            init_queues();
-        }
-
-        int is_native_atomic_supported() { return 0; }
-        int get_major_version() const
-        {
-            return dpct::get_major_version(*this);
-        }
-
-        int get_minor_version() const
-        {
-            return dpct::get_minor_version(*this);
-        }
-
-        int get_max_compute_units() const
-        {
-            return get_device_info().get_max_compute_units();
-        }
-
-        /// Return the maximum clock frequency of this device in KHz.
-        int get_max_clock_frequency() const
-        {
-            return get_device_info().get_max_clock_frequency();
-        }
-
-        int get_integrated() const { return get_device_info().get_integrated(); }
-
-        int get_max_sub_group_size() const
-        {
-            return get_device_info().get_max_sub_group_size();
-        }
-
-        int get_max_register_size_per_work_group() const
-        {
-            return get_device_info().get_max_register_size_per_work_group();
-        }
-
-        int get_max_work_group_size() const
-        {
-            return get_device_info().get_max_work_group_size();
-        }
-
-        int get_mem_base_addr_align() const
-        {
-            return get_info<sycl::info::device::mem_base_addr_align>();
-        }
-
-        size_t get_global_mem_size() const
-        {
-            return get_device_info().get_global_mem_size();
-        }
-
-        size_t get_max_mem_alloc_size() const
-        {
-            return get_device_info().get_max_mem_alloc_size();
-        }
-
-        /// Get the number of bytes of free and total memory on the SYCL device.
-        /// \param [out] free_memory The number of bytes of free memory on the SYCL device.
-        /// \param [out] total_memory The number of bytes of total memory on the SYCL device.
-        void get_memory_info(size_t &free_memory, size_t &total_memory)
-        {
-            total_memory = get_device_info().get_global_mem_size();
-            const char *warning_info = "get_memory_info: [warning] ext_intel_free_memory is not "
-                                 "supported (export/set ZES_ENABLE_SYSMAN=1 to support), "
-                                 "use total memory as free memory";
+   /// dpct device extension
+    class device_ext : public sycl::device {
+      typedef std::mutex mutex_type;
+
+     public:
+      device_ext() : sycl::device() {}
+      ~device_ext() {
+        std::lock_guard<mutex_type> lock(m_mutex);
+        clear_queues();
+      }
+      device_ext(const sycl::device &base) : sycl::device(base) {
+        std::lock_guard<mutex_type> lock(m_mutex);
+        init_queues();
+      }
+
+      int is_native_atomic_supported() { return 0; }
+      int get_major_version() const { return dpct::get_major_version(*this); }
+
+      int get_minor_version() const { return dpct::get_minor_version(*this); }
+
+      int get_max_compute_units() const {
+        return get_device_info().get_max_compute_units();
+      }
+
+      /// Return the maximum clock frequency of this device in KHz.
+      int get_max_clock_frequency() const {
+        return get_device_info().get_max_clock_frequency();
+      }
+
+      int get_integrated() const { return get_device_info().get_integrated(); }
+
+      int get_max_sub_group_size() const {
+        return get_device_info().get_max_sub_group_size();
+      }
+
+      int get_max_register_size_per_work_group() const {
+        return get_device_info().get_max_register_size_per_work_group();
+      }
+
+      int get_max_work_group_size() const {
+        return get_device_info().get_max_work_group_size();
+      }
+
+      int get_mem_base_addr_align() const {
+        return get_info<sycl::info::device::mem_base_addr_align>();
+      }
+
+      size_t get_global_mem_size() const {
+        return get_device_info().get_global_mem_size();
+      }
+
+      size_t get_max_mem_alloc_size() const {
+        return get_device_info().get_max_mem_alloc_size();
+      }
+
+      /// Get the number of bytes of free and total memory on the SYCL device.
+      /// \param [out] free_memory The number of bytes of free memory on the
+      /// SYCL device. \param [out] total_memory The number of bytes of total
+      /// memory on the SYCL device.
+      void get_memory_info(size_t &free_memory, size_t &total_memory) {
+        total_memory = get_device_info().get_global_mem_size();
+        const char *warning_info =
+            "get_memory_info: [warning] ext_intel_free_memory is not "
+            "supported (export/set ZES_ENABLE_SYSMAN=1 to support), "
+            "use total memory as free memory";
 #if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105)
-            if (!has(sycl::aspect::ext_intel_free_memory))
-            {
-                std::cerr << warning_info << std::endl;
-                free_memory = total_memory;
-            }
-            else
-            {
-                free_memory = get_info<sycl::ext::intel::info::device::free_memory>();
-            }
+        if (!has(sycl::aspect::ext_intel_free_memory)) {
+          std::cerr << warning_info << std::endl;
+          free_memory = total_memory;
+        } else {
+          free_memory = get_info<sycl::ext::intel::info::device::free_memory>();
+        }
 #else
-            std::cerr << warning_info << std::endl;
-            free_memory = total_memory;
+        std::cerr << warning_info << std::endl;
+        free_memory = total_memory;
 #if defined(_MSC_VER) && !defined(__clang__)
 #pragma message("Querying the number of bytes of free memory is not supported")
 #else
 #warning "Querying the number of bytes of free memory is not supported"
 #endif
 #endif
-        }
+      }
 
-        void get_device_info(device_info &out) const
-        {
-            dpct::get_device_info(out, *this);
-        }
+      void get_device_info(device_info &out) const {
+        dpct::get_device_info(out, *this);
+      }
 
-        device_info get_device_info() const
-        {
-            device_info prop;
-            dpct::get_device_info(prop, *this);
-            return prop;
-        }
+      device_info get_device_info() const {
+        device_info prop;
+        dpct::get_device_info(prop, *this);
+        return prop;
+      }
 
-        void reset()
-        {
-            std::lock_guard<mutex_type> lock(m_mutex);
-            clear_queues();
-            init_queues();
-        }
+      void reset() {
+        std::lock_guard<mutex_type> lock(m_mutex);
+        clear_queues();
+        init_queues();
+      }
 
-        sycl::queue &in_order_queue() { return *_q_in_order; }
+      sycl::queue &in_order_queue() { return _q_in_order; }
 
-        sycl::queue &out_of_order_queue() { return *_q_out_of_order; }
+      sycl::queue &out_of_order_queue() { return _q_out_of_order; }
 
-        sycl::queue &default_queue()
-        {
-            return in_order_queue();
-        }
+      sycl::queue &default_queue() { return in_order_queue(); }
 
-        void queues_wait_and_throw()
-        {
-            std::unique_lock<mutex_type> lock(m_mutex);
-            std::vector<std::shared_ptr<sycl::queue>> current_queues(
-                _queues);
-            lock.unlock();
-            for (const auto &q : current_queues)
-            {
-                q->wait_and_throw();
-            }
-            // Guard the destruct of current_queues to make sure the ref count is safe.
-            lock.lock();
+      void queues_wait_and_throw() {
+        std::unique_lock<mutex_type> lock(m_mutex);
+        lock.unlock();
+        for (auto &q : _queues) {
+          q.wait_and_throw();
         }
+        // Guard the destruct of current_queues to make sure the ref count is
+        // safe.
+        lock.lock();
+      }
 
-        sycl::queue *create_queue(bool enable_exception_handler = false)
-        {
-            return create_in_order_queue(enable_exception_handler);
-        }
+      sycl::queue create_queue(bool enable_exception_handler = false) {
+        return create_in_order_queue(enable_exception_handler);
+      }
 
-        sycl::queue *create_queue(sycl::context context, sycl::device device,
-                                bool enable_exception_handler = false) {
-            return create_in_order_queue(context, device, enable_exception_handler);
-        }
+      sycl::queue create_queue(sycl::device device,
+                               bool enable_exception_handler = false) {
+        return create_in_order_queue(device, enable_exception_handler);
+      }
 
-        sycl::queue *create_in_order_queue(bool enable_exception_handler = false) {
-            std::lock_guard<mutex_type> lock(m_mutex);
-            return create_queue_impl(enable_exception_handler,
-                                    sycl::property::queue::in_order());
-        }
+      sycl::queue create_in_order_queue(bool enable_exception_handler = false) {
+        std::lock_guard<mutex_type> lock(m_mutex);
+        return create_queue_impl(enable_exception_handler,
+                                 sycl::property::queue::in_order());
+      }
 
-        sycl::queue *create_in_order_queue(sycl::context context, sycl::device device,
+      sycl::queue create_in_order_queue(sycl::device device,
                                         bool enable_exception_handler = false) {
-            std::lock_guard<mutex_type> lock(m_mutex);
-            return create_queue_impl(context, device, enable_exception_handler,
-                                    sycl::property::queue::in_order());
-        }
-
-        sycl::queue *create_out_of_order_queue(bool enable_exception_handler = false) {
-            std::lock_guard<mutex_type> lock(m_mutex);
-            return create_queue_impl(enable_exception_handler);
-        }
-
-        void destroy_queue(sycl::queue *&queue)
-        {
-            std::lock_guard<mutex_type> lock(m_mutex);
-            _queues.erase(std::remove_if(_queues.begin(), _queues.end(),
-                                         [=](const std::shared_ptr<sycl::queue> &q) -> bool
-                                         {
-                                             return q.get() == queue;
-                                         }),
-                          _queues.end());
-            queue = nullptr;
-        }
-        void set_saved_queue(sycl::queue *q)
-        {
-            std::lock_guard<mutex_type> lock(m_mutex);
-            _saved_queue = q;
-        }
-        sycl::queue *get_saved_queue() const
-        {
-            std::lock_guard<mutex_type> lock(m_mutex);
-            return _saved_queue;
-        }
-        sycl::context get_context() const { return _ctx; }
-
-    private:
-        void clear_queues()
-        {
-            _queues.clear();
-            _q_in_order = _q_out_of_order = _saved_queue = nullptr;
-        }
-
-        void init_queues()
-        {
-            _q_in_order = create_queue_impl(true, sycl::property::queue::in_order());
-            _q_out_of_order = create_queue_impl(true);
-            _saved_queue = &default_queue();
+        std::lock_guard<mutex_type> lock(m_mutex);
+        return create_queue_impl(device, enable_exception_handler,
+                                 sycl::property::queue::in_order());
+      }
+
+      sycl::queue create_out_of_order_queue(
+          bool enable_exception_handler = false) {
+        std::lock_guard<mutex_type> lock(m_mutex);
+        return create_queue_impl(enable_exception_handler);
+      }
+
+      void destroy_queue(sycl::queue queue) {
+        std::lock_guard<mutex_type> lock(m_mutex);
+        _queues.clear();
+      }
+      void set_saved_queue(sycl::queue q) {
+        std::lock_guard<mutex_type> lock(m_mutex);
+        _saved_queue = q;
+      }
+      sycl::queue get_saved_queue() const {
+        std::lock_guard<mutex_type> lock(m_mutex);
+        return _saved_queue;
+      }
+
+     private:
+      void clear_queues() { _queues.clear(); }
+
+      void init_queues() {
+        _q_in_order =
+            create_queue_impl(true, sycl::property::queue::in_order());
+        _q_out_of_order = create_queue_impl(true);
+        _saved_queue = default_queue();
+      }
+
+      /// Caller should acquire resource \p m_mutex before calling this
+      /// function.
+      template <class... Properties>
+      sycl::queue create_queue_impl(bool enable_exception_handler,
+                                    Properties... properties) {
+        sycl::async_handler eh = {};
+        if (enable_exception_handler) {
+          eh = exception_handler;
         }
-
-        /// Caller should acquire resource \p m_mutex before calling this function.
-        template <class... Properties>
-        sycl::queue *create_queue_impl(bool enable_exception_handler,
-                                       Properties... properties)
-        {
-            sycl::async_handler eh = {};
-            if (enable_exception_handler)
-            {
-                eh = exception_handler;
-            }
-            _queues.push_back(std::make_shared<sycl::queue>(
-                _ctx, *this, eh,
-                sycl::property_list(
+        auto q = sycl::queue(*this, eh,
+                             sycl::property_list(
 #ifdef DPCT_PROFILING_ENABLED
-                    sycl::property::queue::enable_profiling(),
+                                 sycl::property::queue::enable_profiling(),
 #endif
-                    properties...)));
+                                 properties...));
+        _queues.push_back(q);
 
-            return _queues.back().get();
-        }
+        return _queues.back();
+      }
 
-        template <class... Properties>
-        sycl::queue *create_queue_impl(sycl::context context, sycl::device device,
+      template <class... Properties>
+      sycl::queue create_queue_impl(sycl::device device,
                                     bool enable_exception_handler,
                                     Properties... properties) {
-            sycl::async_handler eh = {};
-            if (enable_exception_handler) {
-                eh = exception_handler;
-            }
-            _queues.push_back(std::make_shared<sycl::queue>(
-                context, device, eh,
-                sycl::property_list(
-        #ifdef DPCT_PROFILING_ENABLED
-                    sycl::property::queue::enable_profiling(),
-        #endif
-                    properties...)));
-
-            return _queues.back().get();
+        sycl::async_handler eh = {};
+        if (enable_exception_handler) {
+          eh = exception_handler;
         }
-
-        void get_version(int &major, int &minor) const
-        {
-            detail::get_version(*this, major, minor);
-        }
-        sycl::queue *_q_in_order, *_q_out_of_order;
-        sycl::queue *_saved_queue;
-        sycl::context _ctx;
-        std::vector<std::shared_ptr<sycl::queue>> _queues;
-        mutable mutex_type m_mutex;
+        _queues.push_back(
+            sycl::queue(device, eh,
+                        sycl::property_list(
+#ifdef DPCT_PROFILING_ENABLED
+                            sycl::property::queue::enable_profiling(),
+#endif
+                            properties...)));
+
+        return _queues.back();
+      }
+
+      void get_version(int &major, int &minor) const {
+        detail::get_version(*this, major, minor);
+      }
+      sycl::queue _q_in_order, _q_out_of_order;
+      sycl::queue _saved_queue;
+      std::vector<sycl::queue> _queues;
+      mutable mutex_type m_mutex;
     };
 
+
     /// device manager
     class dev_mgr
     {
diff --git a/src/ggml-sycl/mmq.cpp b/src/ggml-sycl/mmq.cpp
new file mode 100644 (file)
index 0000000..b514f00
--- /dev/null
@@ -0,0 +1,3031 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include "mmq.hpp"
+#include "vecdotq.hpp"
+
+typedef void (*allocate_tiles_sycl_t)(
+    int** x_ql,
+    sycl::half2** x_dm,
+    int** x_qh,
+    int** x_sc);
+typedef void (*load_tiles_sycl_t)(
+    const void* __restrict__ vx,
+    int* __restrict__ x_ql,
+    sycl::half2* __restrict__ x_dm,
+    int* __restrict__ x_qh,
+    int* __restrict__ x_sc,
+    const int& i_offset,
+    const int& i_max,
+    const int& k,
+    const int& blocks_per_row);
+typedef float (*vec_dot_q_mul_mat_sycl_t)(
+    const int* __restrict__ x_ql,
+    const sycl::half2* __restrict__ x_dm,
+    const int* __restrict__ x_qh,
+    const int* __restrict__ x_sc,
+    const int* __restrict__ y_qs,
+    const sycl::half2* __restrict__ y_ms,
+    const int& i,
+    const int& j,
+    const int& k);
+
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q4_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+                    int *tile_x_qs_q4_0, float *tile_x_d_q4_0) {
+    (void)x_qh; (void)x_sc;
+
+    *x_ql = tile_x_qs_q4_0;
+    *x_dm = (sycl::half2 *)tile_x_d_q4_0;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q4_0(const void *__restrict__ vx, int *__restrict__ x_ql,
+                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+                const int &k, const int &blocks_per_row) {
+    (void)x_qh; (void)x_sc;
+    GGML_SYCL_ASSUME(i_offset >= 0);
+    GGML_SYCL_ASSUME(i_offset <  nwarps);
+    GGML_SYCL_ASSUME(k >= 0);
+    GGML_SYCL_ASSUME(k <  WARP_SIZE);
+
+    const int kbx  = k / QI4_0;
+    const int kqsx = k % QI4_0;
+
+    const block_q4_0 * bx0 = (const block_q4_0 *) vx;
+
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
+        // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+    const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
+        int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
+    }
+}
+
+static __dpct_inline__ float vec_dot_q4_0_q8_1_mul_mat(
+    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+    const int &i, const int &j, const int &k) {
+    (void)x_qh; (void)x_sc;
+
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+    const float * x_dmf = (const float *) x_dm;
+
+    int u[2*VDR_Q4_0_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
+    }
+
+    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
+         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q4_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+                    int *tile_x_qs_q4_1, sycl::half2 *tile_x_dm_q4_1) {
+    (void)x_qh; (void)x_sc;
+
+    *x_ql = tile_x_qs_q4_1;
+    *x_dm = tile_x_dm_q4_1;
+}
+
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q4_1(const void *__restrict__ vx, int *__restrict__ x_ql,
+                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+                const int &k, const int &blocks_per_row) {
+    (void)x_qh; (void)x_sc;
+
+    GGML_SYCL_ASSUME(i_offset >= 0);
+    GGML_SYCL_ASSUME(i_offset <  nwarps);
+    GGML_SYCL_ASSUME(k >= 0);
+    GGML_SYCL_ASSUME(k <  WARP_SIZE);
+
+    const int kbx  = k / QI4_1;
+    const int kqsx = k % QI4_1;
+
+    const block_q4_1 * bx0 = (const block_q4_1 *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
+    const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
+        int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
+    }
+}
+
+static __dpct_inline__ float vec_dot_q4_1_q8_1_mul_mat(
+    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+    const int &i, const int &j, const int &k) {
+    (void)x_qh; (void)x_sc;
+
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+
+    int u[2*VDR_Q4_1_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
+    }
+
+    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
+         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q5_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+                    int *tile_x_ql_q5_0, float *tile_x_d_q5_0) {
+    (void)x_qh; (void)x_sc;
+
+    *x_ql = tile_x_ql_q5_0;
+    *x_dm = (sycl::half2 *)tile_x_d_q5_0;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q5_0(const void *__restrict__ vx, int *__restrict__ x_ql,
+                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+                const int &k, const int &blocks_per_row) {
+    (void)x_qh; (void)x_sc;
+
+    GGML_SYCL_ASSUME(i_offset >= 0);
+    GGML_SYCL_ASSUME(i_offset <  nwarps);
+    GGML_SYCL_ASSUME(k >= 0);
+    GGML_SYCL_ASSUME(k <  WARP_SIZE);
+
+    const int kbx  = k / QI5_0;
+    const int kqsx = k % QI5_0;
+
+    const block_q5_0 * bx0 = (const block_q5_0 *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;
+
+        const int ql = get_int_from_uint8(bxi->qs, kqsx);
+        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));
+
+        int qs0 = (ql >>  0)   & 0x0F0F0F0F;
+        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
+        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12
+        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20
+        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
+        qs0 = dpct::vectorized_binary<sycl::char4>(
+            qs0, 0x10101010, dpct::sub_sat()); // subtract 16
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
+
+        int qs1 = (ql >>  4)   & 0x0F0F0F0F;
+        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
+        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
+        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20
+        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
+        qs1 = dpct::vectorized_binary<sycl::char4>(
+            qs1, 0x10101010, dpct::sub_sat()); // subtract 16
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
+    const int kbxd = k % blocks_per_tile_x_row;
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
+        int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
+    }
+}
+
+static __dpct_inline__ float vec_dot_q5_0_q8_1_mul_mat(
+    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+    const int &i, const int &j, const int &k) {
+    (void)x_qh; (void)x_sc;
+
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+    const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
+
+    int u[2*VDR_Q5_0_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
+    }
+
+    return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
+        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q5_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+                    int *tile_x_ql_q5_1, sycl::half2 *tile_x_dm_q5_1) {
+    (void)x_qh; (void)x_sc;
+
+    *x_ql = tile_x_ql_q5_1;
+    *x_dm = tile_x_dm_q5_1;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q5_1(const void *__restrict__ vx, int *__restrict__ x_ql,
+                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+                const int &k, const int &blocks_per_row) {
+    (void)x_qh; (void)x_sc;
+
+    GGML_SYCL_ASSUME(i_offset >= 0);
+    GGML_SYCL_ASSUME(i_offset < nwarps);
+    GGML_SYCL_ASSUME(k >= 0);
+    GGML_SYCL_ASSUME(k <  WARP_SIZE);
+
+    const int kbx  = k / QI5_1;
+    const int kqsx = k % QI5_1;
+
+    const block_q5_1 * bx0 = (const block_q5_1 *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
+
+        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
+
+        int qs0 = (ql >>  0) & 0x0F0F0F0F;
+        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
+        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
+        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
+        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
+
+        int qs1 = (ql >>  4) & 0x0F0F0F0F;
+        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
+        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
+        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
+        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
+    const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
+        int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
+    }
+}
+
+static __dpct_inline__ float vec_dot_q5_1_q8_1_mul_mat(
+    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+    const int &i, const int &j, const int &k) {
+    (void)x_qh; (void)x_sc;
+
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+    const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
+
+    int u[2*VDR_Q5_1_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
+    }
+
+    return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q8_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+                    int *tile_x_qs_q8_0, float *tile_x_d_q8_0) {
+    (void)x_qh; (void)x_sc;
+
+    *x_ql = tile_x_qs_q8_0;
+    *x_dm = (sycl::half2 *)tile_x_d_q8_0;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q8_0(const void *__restrict__ vx, int *__restrict__ x_ql,
+                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+                const int &k, const int &blocks_per_row) {
+    (void)x_qh; (void)x_sc;
+
+    GGML_SYCL_ASSUME(i_offset >= 0);
+    GGML_SYCL_ASSUME(i_offset <  nwarps);
+    GGML_SYCL_ASSUME(k >= 0);
+    GGML_SYCL_ASSUME(k <  WARP_SIZE);
+
+    const int kbx  = k / QI8_0;
+    const int kqsx = k % QI8_0;
+    float * x_dmf = (float *) x_dm;
+
+    const block_q8_0 * bx0 = (const block_q8_0 *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
+    const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
+        int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+    }
+}
+
+static __dpct_inline__ float vec_dot_q8_0_q8_1_mul_mat(
+    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+    const int &i, const int &j, const int &k) {
+    (void)x_qh; (void)x_sc;
+
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
+
+    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
+        (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
+         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q2_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+                    int *tile_x_ql_q2_K, sycl::half2 *tile_x_dm_q2_K,
+                    int *tile_x_sc_q2_K) {
+    (void)x_qh;
+
+    *x_ql = tile_x_ql_q2_K;
+    *x_dm = tile_x_dm_q2_K;
+    *x_sc = tile_x_sc_q2_K;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q2_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+                const int &k, const int &blocks_per_row) {
+    (void)x_qh;
+
+    GGML_SYCL_ASSUME(i_offset >= 0);
+    GGML_SYCL_ASSUME(i_offset <  nwarps);
+    GGML_SYCL_ASSUME(k >= 0);
+    GGML_SYCL_ASSUME(k <  WARP_SIZE);
+
+    const int kbx  = k / QI2_K;
+    const int kqsx = k % QI2_K;
+
+    const block_q2_K * bx0 = (const block_q2_K *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
+    const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
+        int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);
+
+        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
+    }
+}
+
+#define VDR_Q2_K_Q8_1_MMQ  2
+// contiguous u/y values
+static __dpct_inline__ float
+vec_dot_q2_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
+                           const uint8_t *__restrict__ scales,
+                           const sycl::half2 &dm2, const float &d8) {
+
+    int sumi_d = 0;
+    int sumi_m = 0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
+        int sumi_d_sc = 0;
+
+        const int sc = scales[i0 / (QI8_1/2)];
+
+        // fill int with 4x m
+        int m = sc >> 4;
+        m |= m <<  8;
+        m |= m << 16;
+
+#pragma unroll
+        for (int i = i0; i < i0 + QI8_1/2; ++i) {
+            sumi_d_sc = dpct::dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
+            sumi_m = dpct::dp4a(m, u[i],
+                                sumi_m); // multiply sum of q8_1 values with m
+        }
+
+        sumi_d += sumi_d_sc * (sc & 0xF);
+    }
+
+    const sycl::float2 dm2f =
+        dm2.convert<float, sycl::rounding_mode::automatic>();
+
+    return d8 * (dm2f.x() * sumi_d - dm2f.y() * sumi_m);
+}
+
+static __dpct_inline__ float vec_dot_q2_K_q8_1_mul_mat(
+    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+    const int &i, const int &j, const int &k) {
+    (void)x_qh;
+
+    const int kbx = k / QI2_K;
+    const int ky  = (k % QI2_K) * QR2_K;
+    const float * y_df = (const float *) y_ds;
+
+    int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
+
+    const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
+    const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
+
+#pragma unroll
+    for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
+        v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
+    }
+
+    const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
+
+    const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
+    return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q3_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+                    int *tile_x_ql_q3_K, sycl::half2 *tile_x_dm_q3_K,
+                    int *tile_x_qh_q3_K, int *tile_x_sc_q3_K) {
+
+    *x_ql = tile_x_ql_q3_K;
+    *x_dm = tile_x_dm_q3_K;
+    *x_qh = tile_x_qh_q3_K;
+    *x_sc = tile_x_sc_q3_K;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q3_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+                const int &k, const int &blocks_per_row) {
+
+    GGML_SYCL_ASSUME(i_offset >= 0);
+    GGML_SYCL_ASSUME(i_offset <  nwarps);
+    GGML_SYCL_ASSUME(k >= 0);
+    GGML_SYCL_ASSUME(k <  WARP_SIZE);
+
+    const int kbx  = k / QI3_K;
+    const int kqsx = k % QI3_K;
+
+    const block_q3_K * bx0 = (const block_q3_K *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
+    const int kbxd = k % blocks_per_tile_x_row;
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
+        int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
+        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
+
+        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
+
+        const int ksc = k % (QI3_K/4);
+
+        const int ksc_low = ksc % (QI3_K/8);
+        const int shift_low = 4 * (ksc / (QI3_K/8));
+        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+
+        const int ksc_high = QI3_K/8;
+        const int shift_high = 2 * ksc;
+        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+
+        const int sc = dpct::vectorized_binary<sycl::char4>(
+            sc_low | sc_high, 0x20202020, dpct::sub_sat());
+
+        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
+    }
+}
+
+#define VDR_Q3_K_Q8_1_MMQ  2
+// contiguous u/y values
+static __dpct_inline__ float
+vec_dot_q3_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
+                           const int8_t *__restrict__ scales, const float &d3,
+                           const float &d8) {
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
+        int sumi_sc = 0;
+
+        for (int i = i0; i < i0 + QI8_1/2; ++i) {
+            sumi_sc = dpct::dp4a(v[i], u[i], sumi_sc); // SIMD dot product
+        }
+
+        sumi += sumi_sc * scales[i0 / (QI8_1/2)];
+    }
+
+    return d3*d8 * sumi;
+}
+
+static __dpct_inline__ float vec_dot_q3_K_q8_1_mul_mat(
+    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+    const int &i, const int &j, const int &k) {
+
+    const int kbx  = k / QI3_K;
+    const int ky  = (k % QI3_K) * QR3_K;
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
+
+    const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+
+    int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
+        const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
+        const int shift = 2 * ((ky % 32) / 8);
+        const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
+
+        const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
+        const int vlh = (vh << 2) & 0x04040404;
+
+        v[l] = dpct::vectorized_binary<sycl::char4>(vll, vlh, dpct::sub_sat());
+    }
+
+    const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
+    return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q4_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+                    int *tile_x_ql_q4_K, sycl::half2 *tile_x_dm_q4_K,
+                    int *tile_x_sc_q4_K) {
+    (void)x_qh;
+
+    *x_ql = tile_x_ql_q4_K;
+    *x_dm = tile_x_dm_q4_K;
+    *x_sc = tile_x_sc_q4_K;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+                const int &k, const int &blocks_per_row) {
+    (void)x_qh;
+
+    GGML_SYCL_ASSUME(i_offset >= 0);
+    GGML_SYCL_ASSUME(i_offset <  nwarps);
+    GGML_SYCL_ASSUME(k >= 0);
+    GGML_SYCL_ASSUME(k <  WARP_SIZE);
+
+    const int kbx  = k / QI4_K; // == 0 if QK_K == 256
+    const int kqsx = k % QI4_K; // == k if QK_K == 256
+
+    const block_q4_K * bx0 = (const block_q4_K *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
+        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+#if QK_K == 256
+        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
+#else
+        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
+#endif
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
+
+        const int * scales = (const int *) bxi->scales;
+
+        const int ksc = k % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
+        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+    }
+}
+
+
+#define VDR_Q4_K_Q8_1_MMQ  8
+
+// contiguous u/y values
+static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_mmq(
+    const int *__restrict__ v, const int *__restrict__ u,
+    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
+    const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
+        int sumi_d = 0;
+
+#pragma unroll
+        for (int j = 0; j < QI8_1; ++j) {
+            sumi_d = dpct::dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F,
+                                u[i * QI8_1 + j], sumi_d); // SIMD dot product
+        }
+
+        const sycl::float2 ds8f =
+            ds8[i].convert<float, sycl::rounding_mode::automatic>();
+
+        sumf_d += ds8f.x() * (sc[i] * sumi_d);
+        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
+    }
+
+    const sycl::float2 dm4f =
+        dm4.convert<float, sycl::rounding_mode::automatic>();
+
+    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
+}
+
+
+static __dpct_inline__ float vec_dot_q4_K_q8_1_mul_mat(
+    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+    const int &i, const int &j, const int &k) {
+    (void)x_qh;
+
+    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
+
+    const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
+    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
+                                      x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q5_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+                    int *tile_x_ql_q5_K, sycl::half2 *tile_x_dm_q5_K,
+                    int *tile_x_sc_q5_K) {
+    (void)x_qh;
+
+    *x_ql = tile_x_ql_q5_K;
+    *x_dm = tile_x_dm_q5_K;
+    *x_sc = tile_x_sc_q5_K;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+                const int &k, const int &blocks_per_row) {
+    (void)x_qh;
+
+    GGML_SYCL_ASSUME(i_offset >= 0);
+    GGML_SYCL_ASSUME(i_offset <  nwarps);
+    GGML_SYCL_ASSUME(k >= 0);
+    GGML_SYCL_ASSUME(k <  WARP_SIZE);
+
+    const int kbx  = k / QI5_K; // == 0 if QK_K == 256
+    const int kqsx = k % QI5_K; // == k if QK_K == 256
+
+    const block_q5_K * bx0 = (const block_q5_K *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
+        const int ky = QR5_K*kqsx;
+
+        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
+        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+
+        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
+        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);
+
+        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
+        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
+        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+#if QK_K == 256
+        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
+#endif
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
+
+        const int * scales = (const int *) bxi->scales;
+
+        const int ksc = k % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
+        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+    }
+}
+
+#define VDR_Q5_K_Q8_1_MMQ  8
+
+// contiguous u/y values
+static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_mmq(
+    const int *__restrict__ v, const int *__restrict__ u,
+    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
+    const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
+        int sumi_d = 0;
+
+#pragma unroll
+        for (int j = 0; j < QI8_1; ++j) {
+            sumi_d = dpct::dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j],
+                                sumi_d); // SIMD dot product
+        }
+
+        const sycl::float2 ds8f =
+            ds8[i].convert<float, sycl::rounding_mode::automatic>();
+
+        sumf_d += ds8f.x() * (sc[i] * sumi_d);
+        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
+    }
+
+    const sycl::float2 dm4f =
+        dm4.convert<float, sycl::rounding_mode::automatic>();
+
+    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
+}
+
+static __dpct_inline__ float vec_dot_q5_K_q8_1_mul_mat(
+    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+    const int &i, const int &j, const int &k) {
+    (void)x_qh;
+
+    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
+
+    const int index_x = i * (QR5_K*WARP_SIZE + 1) +  QR5_K*k;
+    const int index_y = j * WARP_SIZE             + (QR5_K*k) % WARP_SIZE;
+    return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
+                                      x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q6_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+                    int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_sc) {
+    (void)x_qh;
+
+    *x_ql = tile_x_ql;
+    *x_dm = tile_x_dm;
+    *x_sc = tile_x_sc;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+                const int &k, const int &blocks_per_row) {
+    (void)x_qh;
+
+    GGML_SYCL_ASSUME(i_offset >= 0);
+    GGML_SYCL_ASSUME(i_offset <  nwarps);
+    GGML_SYCL_ASSUME(k >= 0);
+    GGML_SYCL_ASSUME(k <  WARP_SIZE);
+
+    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
+    const int kqsx = k % QI6_K; // == k if QK_K == 256
+
+    const block_q6_K * bx0 = (const block_q6_K *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
+        const int ky = QR6_K*kqsx;
+
+        const int ql = get_int_from_uint8(bxi->ql, kqsx);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
+        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
+        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;
+
+        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
+        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);
+
+        x_ql[i * (2 * WARP_SIZE + 1) + kq0] =
+            dpct::vectorized_binary<sycl::char4>(ql0 | qh0, 0x20202020,
+                                                 dpct::sub_sat());
+        x_ql[i * (2 * WARP_SIZE + 1) + kq1] =
+            dpct::vectorized_binary<sycl::char4>(ql1 | qh1, 0x20202020,
+                                                 dpct::sub_sat());
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
+        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = sycl::min(i, i_max);
+        }
+
+        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
+    }
+}
+
+#define VDR_Q6_K_Q8_1_MMQ  8
+
+// contiguous u/y values
+static __dpct_inline__ float
+vec_dot_q6_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
+                           const int8_t *__restrict__ sc, const float &d6,
+                           const float *__restrict__ d8) {
+
+    float sumf_d = 0.0f;
+
+#pragma unroll
+    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
+        sycl::int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
+
+#pragma unroll
+        for (int i = i0; i < i0 + 2; ++i) {
+            sumi_d.x() = dpct::dp4a(v[2 * i + 0], u[2 * i + 0],
+                                    sumi_d.x()); // SIMD dot product
+            sumi_d.x() = dpct::dp4a(v[2 * i + 1], u[2 * i + 1],
+                                    sumi_d.x()); // SIMD dot product
+
+            sumi_d.y() = dpct::dp4a(v[2 * i + 4], u[2 * i + 4],
+                                    sumi_d.y()); // SIMD dot product
+            sumi_d.y() = dpct::dp4a(v[2 * i + 5], u[2 * i + 5],
+                                    sumi_d.y()); // SIMD dot product
+        }
+
+        sumf_d += d8[i0 / 4] *
+                  (sc[i0 / 2 + 0] * sumi_d.x() + sc[i0 / 2 + 1] * sumi_d.y());
+    }
+
+    return d6 * sumf_d;
+}
+
+static __dpct_inline__ float vec_dot_q6_K_q8_1_mul_mat(
+    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+    const int &i, const int &j, const int &k) {
+    (void)x_qh;
+
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
+
+    const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
+
+    const int index_x = i * (QR6_K*WARP_SIZE + 1) +  QR6_K*k;
+    const int index_y = j * WARP_SIZE             + (QR6_K*k) % WARP_SIZE;
+    return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
+}
+
+template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
+          int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
+          vec_dot_q_mul_mat_sycl_t vec_dot>
+/*
+DPCT1110:8: The total declared local variable size in device function mul_mat_q
+exceeds 128 bytes and may cause high register pressure. Consult with your
+hardware vendor to find the total register size available and adjust the code,
+or use smaller sub-group size to avoid high register pressure.
+*/
+static __dpct_inline__ void
+mul_mat_q(const void *__restrict__ vx, const void *__restrict__ vy,
+          float *__restrict__ dst, const int ncols_x, const int nrows_x,
+          const int ncols_y, const int nrows_y, const int nrows_dst,
+          int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_qh,
+          int *tile_x_sc, const sycl::nd_item<3> &item_ct1, int *tile_y_qs,
+          sycl::half2 *tile_y_ds) {
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    const int blocks_per_row_x = ncols_x / qk;
+    const int blocks_per_col_y = nrows_y / QK8_1;
+    const int blocks_per_warp = WARP_SIZE / qi;
+
+    const int & ncols_dst = ncols_y;
+
+    const int row_dst_0 = item_ct1.get_group(2) * mmq_y;
+    const int & row_x_0 = row_dst_0;
+
+    const int col_dst_0 = item_ct1.get_group(1) * mmq_x;
+    const int & col_y_0 = col_dst_0;
+
+    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
+
+    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
+
+        load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm,
+                   tile_x_qh, tile_x_sc, item_ct1.get_local_id(1),
+                   nrows_x - row_x_0 - 1, item_ct1.get_local_id(2),
+                   blocks_per_row_x);
+
+#pragma unroll
+        for (int ir = 0; ir < qr; ++ir) {
+            const int kqs = ir * WARP_SIZE + item_ct1.get_local_id(2);
+            const int kbxd = kqs / QI8_1;
+
+#pragma unroll
+            for (int i = 0; i < mmq_x; i += nwarps) {
+                const int col_y_eff = dpct::min(
+                    (unsigned int)(col_y_0 + item_ct1.get_local_id(1) + i),
+                    ncols_y - 1); // to prevent out-of-bounds memory accesses
+
+                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
+
+                const int index_y = (item_ct1.get_local_id(1) + i) * WARP_SIZE +
+                                    kqs % WARP_SIZE;
+                tile_y_qs[index_y] = get_int_from_int8_aligned(
+                    by0->qs, item_ct1.get_local_id(2) % QI8_1);
+            }
+
+#pragma unroll
+            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
+                const int ids =
+                    (ids0 + item_ct1.get_local_id(1) * QI8_1 +
+                     item_ct1.get_local_id(2) / (WARP_SIZE / QI8_1)) %
+                    mmq_x;
+                const int kby = item_ct1.get_local_id(2) % (WARP_SIZE / QI8_1);
+                const int col_y_eff = sycl::min(col_y_0 + ids, ncols_y - 1);
+
+                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
+                const sycl::half2 *dsi_src =
+                    &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) +
+                       ir * (WARP_SIZE / QI8_1) + kby]
+                         .ds;
+                sycl::half2 *dsi_dst =
+                    &tile_y_ds[ids * (WARP_SIZE / QI8_1) + kby];
+                if (need_sum) {
+                    *dsi_dst = *dsi_src;
+                } else {
+                    float * dfi_dst = (float *) dsi_dst;
+                    *dfi_dst = (*dsi_src)[0];
+                }
+            }
+
+            /*
+            DPCT1118:9: SYCL group functions and algorithms must be encountered
+            in converged control flow. You may need to adjust the code.
+            */
+            /*
+            DPCT1065:56: Consider replacing sycl::nd_item::barrier() with
+            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
+            better performance if there is no access to global memory.
+            */
+            item_ct1.barrier();
+
+// #pragma unroll // unrolling this loop causes too much register pressure
+            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
+#pragma unroll
+                for (int j = 0; j < mmq_x; j += nwarps) {
+#pragma unroll
+                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+                        sum[i / WARP_SIZE][j / nwarps] += vec_dot(
+                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
+                            tile_y_qs, tile_y_ds, item_ct1.get_local_id(2) + i,
+                            item_ct1.get_local_id(1) + j, k);
+                    }
+                }
+            }
+
+            /*
+            DPCT1118:10: SYCL group functions and algorithms must be encountered
+            in converged control flow. You may need to adjust the code.
+            */
+            /*
+            DPCT1065:57: Consider replacing sycl::nd_item::barrier() with
+            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
+            better performance if there is no access to global memory.
+            */
+            item_ct1.barrier();
+        }
+    }
+
+#pragma unroll
+    for (int j = 0; j < mmq_x; j += nwarps) {
+        const int col_dst = col_dst_0 + j + item_ct1.get_local_id(1);
+
+        if (col_dst >= ncols_dst) {
+            return;
+        }
+
+#pragma unroll
+        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+            const int row_dst = row_dst_0 + item_ct1.get_local_id(2) + i;
+
+            if (row_dst >= nrows_dst) {
+                continue;
+            }
+
+            dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
+        }
+    }
+}
+
+#define  MMQ_X_Q4_0_RDNA2  64
+#define  MMQ_Y_Q4_0_RDNA2  128
+#define NWARPS_Q4_0_RDNA2  8
+#define  MMQ_X_Q4_0_RDNA1  64
+#define  MMQ_Y_Q4_0_RDNA1  64
+#define NWARPS_Q4_0_RDNA1  8
+#if defined(SYCL_USE_XMX)
+#define  MMQ_X_Q4_0_AMPERE 4
+#define  MMQ_Y_Q4_0_AMPERE 32
+#define NWARPS_Q4_0_AMPERE 4
+#else
+#define  MMQ_X_Q4_0_AMPERE 64
+#define  MMQ_Y_Q4_0_AMPERE 128
+#define NWARPS_Q4_0_AMPERE 4
+#endif
+#define  MMQ_X_Q4_0_PASCAL 64
+#define  MMQ_Y_Q4_0_PASCAL 64
+#define NWARPS_Q4_0_PASCAL 8
+
+template <bool need_check> static void
+    mul_mat_q4_0(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+    const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_0, float *tile_x_d_q4_0,
+    int *tile_y_qs, sycl::half2 *tile_y_ds) {
+    int   * tile_x_ql = nullptr;
+    sycl::half2 *tile_x_dm = nullptr;
+    int   * tile_x_qh = nullptr;
+    int   * tile_x_sc = nullptr;
+
+//sycl_todo: change according to hardware
+
+    const int mmq_x  =  MMQ_X_Q4_0_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q4_0_AMPERE;
+    const int nwarps = NWARPS_Q4_0_AMPERE;
+    allocate_tiles_q4_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                               tile_x_qs_q4_0, tile_x_d_q4_0);
+    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
+              load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ,
+              vec_dot_q4_0_q8_1_mul_mat>(
+        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define  MMQ_X_Q4_1_RDNA2  64
+#define  MMQ_Y_Q4_1_RDNA2  128
+#define NWARPS_Q4_1_RDNA2  8
+#define  MMQ_X_Q4_1_RDNA1  64
+#define  MMQ_Y_Q4_1_RDNA1  64
+#define NWARPS_Q4_1_RDNA1  8
+#if defined(SYCL_USE_XMX)
+#define  MMQ_X_Q4_1_AMPERE 4
+#define  MMQ_Y_Q4_1_AMPERE 32
+#define NWARPS_Q4_1_AMPERE 4
+#else
+#define  MMQ_X_Q4_1_AMPERE 64
+#define  MMQ_Y_Q4_1_AMPERE 128
+#define NWARPS_Q4_1_AMPERE 4
+#endif
+#define  MMQ_X_Q4_1_PASCAL 64
+#define  MMQ_Y_Q4_1_PASCAL 64
+#define NWARPS_Q4_1_PASCAL 8
+
+template <bool need_check> static void
+    mul_mat_q4_1(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+    const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_1,
+    sycl::half2 *tile_x_dm_q4_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
+    int   * tile_x_ql = nullptr;
+    sycl::half2 *tile_x_dm = nullptr;
+    int   * tile_x_qh = nullptr;
+    int   * tile_x_sc = nullptr;
+
+//sycl_todo: change according to hardware
+    const int mmq_x  =  MMQ_X_Q4_1_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q4_1_AMPERE;
+    const int nwarps = NWARPS_Q4_1_AMPERE;
+    allocate_tiles_q4_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                               tile_x_qs_q4_1, tile_x_dm_q4_1);
+    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
+              load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ,
+              vec_dot_q4_1_q8_1_mul_mat>(
+        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define  MMQ_X_Q5_0_RDNA2  64
+#define  MMQ_Y_Q5_0_RDNA2  128
+#define NWARPS_Q5_0_RDNA2  8
+#define  MMQ_X_Q5_0_RDNA1  64
+#define  MMQ_Y_Q5_0_RDNA1  64
+#define NWARPS_Q5_0_RDNA1  8
+#if defined(SYCL_USE_XMX)
+#define  MMQ_X_Q5_0_AMPERE 4
+#define  MMQ_Y_Q5_0_AMPERE 32
+#define NWARPS_Q5_0_AMPERE 4
+#else
+#define  MMQ_X_Q5_0_AMPERE 128
+#define  MMQ_Y_Q5_0_AMPERE 64
+#define NWARPS_Q5_0_AMPERE 4
+#endif
+#define  MMQ_X_Q5_0_PASCAL 64
+#define  MMQ_Y_Q5_0_PASCAL 64
+#define NWARPS_Q5_0_PASCAL 8
+
+template <bool need_check> static void
+    mul_mat_q5_0(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_0, float *tile_x_d_q5_0,
+    int *tile_y_qs, sycl::half2 *tile_y_ds) {
+    int   * tile_x_ql = nullptr;
+    sycl::half2 *tile_x_dm = nullptr;
+    int   * tile_x_qh = nullptr;
+    int   * tile_x_sc = nullptr;
+
+//sycl_todo: change according to hardware
+    const int mmq_x  =  MMQ_X_Q5_0_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q5_0_AMPERE;
+    const int nwarps = NWARPS_Q5_0_AMPERE;
+    allocate_tiles_q5_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                               tile_x_ql_q5_0, tile_x_d_q5_0);
+    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
+              load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ,
+              vec_dot_q5_0_q8_1_mul_mat>(
+        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define  MMQ_X_Q5_1_RDNA2  64
+#define  MMQ_Y_Q5_1_RDNA2  128
+#define NWARPS_Q5_1_RDNA2  8
+#define  MMQ_X_Q5_1_RDNA1  64
+#define  MMQ_Y_Q5_1_RDNA1  64
+#define NWARPS_Q5_1_RDNA1  8
+#if defined(SYCL_USE_XMX)
+#define  MMQ_X_Q5_1_AMPERE 4
+#define  MMQ_Y_Q5_1_AMPERE 32
+#define NWARPS_Q5_1_AMPERE 4
+#else
+#define  MMQ_X_Q5_1_AMPERE 128
+#define  MMQ_Y_Q5_1_AMPERE 64
+#define NWARPS_Q5_1_AMPERE 4
+#endif
+#define  MMQ_X_Q5_1_PASCAL 64
+#define  MMQ_Y_Q5_1_PASCAL 64
+#define NWARPS_Q5_1_PASCAL 8
+
+template <bool need_check> static void
+mul_mat_q5_1(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_1,
+    sycl::half2 *tile_x_dm_q5_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
+    int   * tile_x_ql = nullptr;
+    sycl::half2 *tile_x_dm = nullptr;
+    int   * tile_x_qh = nullptr;
+    int   * tile_x_sc = nullptr;
+
+//sycl_todo: change according to hardware
+    const int mmq_x  =  MMQ_X_Q5_1_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q5_1_AMPERE;
+    const int nwarps = NWARPS_Q5_1_AMPERE;
+    allocate_tiles_q5_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                               tile_x_ql_q5_1, tile_x_dm_q5_1);
+    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
+              load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ,
+              vec_dot_q5_1_q8_1_mul_mat>(
+        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define  MMQ_X_Q8_0_RDNA2  64
+#define  MMQ_Y_Q8_0_RDNA2  128
+#define NWARPS_Q8_0_RDNA2  8
+#define  MMQ_X_Q8_0_RDNA1  64
+#define  MMQ_Y_Q8_0_RDNA1  64
+#define NWARPS_Q8_0_RDNA1  8
+#if defined(SYCL_USE_XMX)
+#define  MMQ_X_Q8_0_AMPERE 4
+#define  MMQ_Y_Q8_0_AMPERE 32
+#define NWARPS_Q8_0_AMPERE 4
+#else
+#define  MMQ_X_Q8_0_AMPERE 128
+#define  MMQ_Y_Q8_0_AMPERE 64
+#define NWARPS_Q8_0_AMPERE 4
+#endif
+#define  MMQ_X_Q8_0_PASCAL 64
+#define  MMQ_Y_Q8_0_PASCAL 64
+#define NWARPS_Q8_0_PASCAL 8
+
+template <bool need_check> static void
+    mul_mat_q8_0(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+    const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q8_0, float *tile_x_d_q8_0,
+    int *tile_y_qs, sycl::half2 *tile_y_ds) {
+    int   * tile_x_ql = nullptr;
+    sycl::half2 *tile_x_dm = nullptr;
+    int   * tile_x_qh = nullptr;
+    int   * tile_x_sc = nullptr;
+
+//sycl_todo: change according to hardware
+    const int mmq_x  =  MMQ_X_Q8_0_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q8_0_AMPERE;
+    const int nwarps = NWARPS_Q8_0_AMPERE;
+    allocate_tiles_q8_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                               tile_x_qs_q8_0, tile_x_d_q8_0);
+    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
+              load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ,
+              vec_dot_q8_0_q8_1_mul_mat>(
+        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define  MMQ_X_Q2_K_RDNA2  64
+#define  MMQ_Y_Q2_K_RDNA2  128
+#define NWARPS_Q2_K_RDNA2  8
+#define  MMQ_X_Q2_K_RDNA1  128
+#define  MMQ_Y_Q2_K_RDNA1  32
+#define NWARPS_Q2_K_RDNA1  8
+#if defined(SYCL_USE_XMX)
+#define  MMQ_X_Q2_K_AMPERE 4
+#define  MMQ_Y_Q2_K_AMPERE 32
+#define NWARPS_Q2_K_AMPERE 4
+#else
+#define  MMQ_X_Q2_K_AMPERE 64
+#define  MMQ_Y_Q2_K_AMPERE 128
+#define NWARPS_Q2_K_AMPERE 4
+#endif
+#define  MMQ_X_Q2_K_PASCAL 64
+#define  MMQ_Y_Q2_K_PASCAL 64
+#define NWARPS_Q2_K_PASCAL 8
+
+template <bool need_check> static void
+mul_mat_q2_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q2_K,
+    sycl::half2 *tile_x_dm_q2_K, int *tile_x_sc_q2_K, int *tile_y_qs,
+    sycl::half2 *tile_y_ds) {
+    int   * tile_x_ql = nullptr;
+    sycl::half2 *tile_x_dm = nullptr;
+    int   * tile_x_qh = nullptr;
+    int   * tile_x_sc = nullptr;
+
+//sycl_todo: change according to hardware
+    const int mmq_x  =  MMQ_X_Q2_K_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q2_K_AMPERE;
+    const int nwarps = NWARPS_Q2_K_AMPERE;
+    allocate_tiles_q2_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                               tile_x_ql_q2_K, tile_x_dm_q2_K, tile_x_sc_q2_K);
+    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
+              load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ,
+              vec_dot_q2_K_q8_1_mul_mat>(
+        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define  MMQ_X_Q3_K_RDNA2  128
+#define  MMQ_Y_Q3_K_RDNA2  64
+#define NWARPS_Q3_K_RDNA2  8
+#define  MMQ_X_Q3_K_RDNA1  32
+#define  MMQ_Y_Q3_K_RDNA1  128
+#define NWARPS_Q3_K_RDNA1  8
+#if defined(SYCL_USE_XMX)
+#define  MMQ_X_Q3_K_AMPERE 4
+#define  MMQ_Y_Q3_K_AMPERE 32
+#define NWARPS_Q3_K_AMPERE 4
+#else
+#define  MMQ_X_Q3_K_AMPERE 128
+#define  MMQ_Y_Q3_K_AMPERE 128
+#define NWARPS_Q3_K_AMPERE 4
+#endif
+#define  MMQ_X_Q3_K_PASCAL 64
+#define  MMQ_Y_Q3_K_PASCAL 64
+#define NWARPS_Q3_K_PASCAL 8
+
+template <bool need_check> static void
+mul_mat_q3_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q3_K,
+    sycl::half2 *tile_x_dm_q3_K, int *tile_x_qh_q3_K, int *tile_x_sc_q3_K,
+    int *tile_y_qs, sycl::half2 *tile_y_ds) {
+    int   * tile_x_ql = nullptr;
+    sycl::half2 *tile_x_dm = nullptr;
+    int   * tile_x_qh = nullptr;
+    int   * tile_x_sc = nullptr;
+
+//sycl_todo: change according to hardware
+    const int mmq_x  =  MMQ_X_Q3_K_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q3_K_AMPERE;
+    const int nwarps = NWARPS_Q3_K_AMPERE;
+    allocate_tiles_q3_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                               tile_x_ql_q3_K, tile_x_dm_q3_K, tile_x_qh_q3_K,
+                               tile_x_sc_q3_K);
+    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
+              load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ,
+              vec_dot_q3_K_q8_1_mul_mat>(
+        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define  MMQ_X_Q4_K_RDNA2  64
+#define  MMQ_Y_Q4_K_RDNA2  128
+#define NWARPS_Q4_K_RDNA2  8
+#define  MMQ_X_Q4_K_RDNA1  32
+#define  MMQ_Y_Q4_K_RDNA1  64
+#define NWARPS_Q4_K_RDNA1  8
+#if defined(SYCL_USE_XMX)
+#define  MMQ_X_Q4_K_AMPERE 4
+#define  MMQ_Y_Q4_K_AMPERE 32
+#define NWARPS_Q4_K_AMPERE 4
+#else
+#define  MMQ_X_Q4_K_AMPERE 64
+#define  MMQ_Y_Q4_K_AMPERE 128
+#define NWARPS_Q4_K_AMPERE 4
+#endif
+#define  MMQ_X_Q4_K_PASCAL 64
+#define  MMQ_Y_Q4_K_PASCAL 64
+#define NWARPS_Q4_K_PASCAL 8
+
+template <bool need_check> static void
+    mul_mat_q4_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q4_K,
+    sycl::half2 *tile_x_dm_q4_K, int *tile_x_sc_q4_K, int *tile_y_qs,
+    sycl::half2 *tile_y_ds) {
+    int   * tile_x_ql = nullptr;
+    sycl::half2 *tile_x_dm = nullptr;
+    int   * tile_x_qh = nullptr;
+    int   * tile_x_sc = nullptr;
+
+//sycl_todo: change according to hardware
+    const int mmq_x  =  MMQ_X_Q4_K_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q4_K_AMPERE;
+    const int nwarps = NWARPS_Q4_K_AMPERE;
+    allocate_tiles_q4_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                               tile_x_ql_q4_K, tile_x_dm_q4_K, tile_x_sc_q4_K);
+    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
+              load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ,
+              vec_dot_q4_K_q8_1_mul_mat>(
+        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define  MMQ_X_Q5_K_RDNA2  64
+#define  MMQ_Y_Q5_K_RDNA2  128
+#define NWARPS_Q5_K_RDNA2  8
+#define  MMQ_X_Q5_K_RDNA1  32
+#define  MMQ_Y_Q5_K_RDNA1  64
+#define NWARPS_Q5_K_RDNA1  8
+#if defined(SYCL_USE_XMX)
+#define  MMQ_X_Q5_K_AMPERE 4
+#define  MMQ_Y_Q5_K_AMPERE 32
+#define NWARPS_Q5_K_AMPERE 4
+#else
+#define  MMQ_X_Q5_K_AMPERE 64
+#define  MMQ_Y_Q5_K_AMPERE 128
+#define NWARPS_Q5_K_AMPERE 4
+#endif
+#define  MMQ_X_Q5_K_PASCAL 64
+#define  MMQ_Y_Q5_K_PASCAL 64
+#define NWARPS_Q5_K_PASCAL 8
+
+template <bool need_check> static void
+mul_mat_q5_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_K,
+    sycl::half2 *tile_x_dm_q5_K, int *tile_x_sc_q5_K, int *tile_y_qs,
+    sycl::half2 *tile_y_ds) {
+    int   * tile_x_ql = nullptr;
+    sycl::half2 *tile_x_dm = nullptr;
+    int   * tile_x_qh = nullptr;
+    int   * tile_x_sc = nullptr;
+
+//sycl_todo: change according to hardware
+    const int mmq_x  =  MMQ_X_Q5_K_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q5_K_AMPERE;
+    const int nwarps = NWARPS_Q5_K_AMPERE;
+    allocate_tiles_q5_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                               tile_x_ql_q5_K, tile_x_dm_q5_K, tile_x_sc_q5_K);
+    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
+              load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ,
+              vec_dot_q5_K_q8_1_mul_mat>(
+        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define  MMQ_X_Q6_K_RDNA2  64
+#define  MMQ_Y_Q6_K_RDNA2  128
+#define NWARPS_Q6_K_RDNA2  8
+#define  MMQ_X_Q6_K_RDNA1  32
+#define  MMQ_Y_Q6_K_RDNA1  64
+#define NWARPS_Q6_K_RDNA1  8
+#if defined(SYCL_USE_XMX)
+#define  MMQ_X_Q6_K_AMPERE 4
+#define  MMQ_Y_Q6_K_AMPERE 32
+#define NWARPS_Q6_K_AMPERE 4
+#else
+#define  MMQ_X_Q6_K_AMPERE 64
+#define  MMQ_Y_Q6_K_AMPERE 64
+#define NWARPS_Q6_K_AMPERE 4
+#endif
+#define  MMQ_X_Q6_K_PASCAL 64
+#define  MMQ_Y_Q6_K_PASCAL 64
+#define NWARPS_Q6_K_PASCAL 8
+
+template <bool need_check> static void
+    mul_mat_q6_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+    const sycl::nd_item<3> &item_ct1, int *tile_x_ql, sycl::half2 *tile_x_dm,
+    int *tile_x_sc, int *tile_y_qs, sycl::half2 *tile_y_ds) {
+    // int   * tile_x_ql = nullptr;
+    // sycl::half2 *tile_x_dm = nullptr;
+    int   * tile_x_qh = nullptr;
+    // int   * tile_x_sc = nullptr;
+
+//sycl_todo: change according to hardware
+    const int mmq_x  =  MMQ_X_Q6_K_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q6_K_AMPERE;
+    const int nwarps = NWARPS_Q6_K_AMPERE;
+    allocate_tiles_q6_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+                               tile_x_ql, tile_x_dm, tile_x_sc);
+    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
+              load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ,
+              vec_dot_q6_K_q8_1_mul_mat>(
+        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols_x,
+                                        const int nrows_x, const int ncols_y,
+                                        const int nrows_y, const int nrows_dst,
+                                        dpct::queue_ptr stream) try {
+
+    int id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(id = get_current_device_id()));
+    const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+    int mmq_x, mmq_y, nwarps;
+    if (compute_capability >= VER_GEN13) {
+        mmq_x  =  MMQ_X_Q4_0_RDNA2;
+        mmq_y  =  MMQ_Y_Q4_0_RDNA2;
+        nwarps = NWARPS_Q4_0_RDNA2;
+    } else if (compute_capability >= VER_GEN12) {
+        mmq_x  =  MMQ_X_Q4_0_RDNA1;
+        mmq_y  =  MMQ_Y_Q4_0_RDNA1;
+        nwarps = NWARPS_Q4_0_RDNA1;
+    } else if (compute_capability >= VER_GEN9) {
+        mmq_x  =  MMQ_X_Q4_0_AMPERE;
+        mmq_y  =  MMQ_Y_Q4_0_AMPERE;
+        nwarps = NWARPS_Q4_0_AMPERE;
+    } else if (compute_capability >= VER_4VEC) {
+        mmq_x  =  MMQ_X_Q4_0_PASCAL;
+        mmq_y  =  MMQ_Y_Q4_0_PASCAL;
+        nwarps = NWARPS_Q4_0_PASCAL;
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        /*
+        DPCT1049:20: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q4_0<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_qs_q4_0_acc_ct1.get_pointer(),
+                            tile_x_d_q4_0_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    } else {
+        const bool need_check = true;
+        /*
+        DPCT1049:21: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q4_0<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_qs_q4_0_acc_ct1.get_pointer(),
+                            tile_x_d_q4_0_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    }
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
+
+static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols_x,
+                                        const int nrows_x, const int ncols_y,
+                                        const int nrows_y, const int nrows_dst,
+                                        dpct::queue_ptr stream) try {
+
+    int id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(id = get_current_device_id()));
+    const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+    int mmq_x, mmq_y, nwarps;
+    if (compute_capability >= VER_GEN13) {
+        mmq_x  =  MMQ_X_Q4_1_RDNA2;
+        mmq_y  =  MMQ_Y_Q4_1_RDNA2;
+        nwarps = NWARPS_Q4_1_RDNA2;
+    } else if (compute_capability >= VER_GEN12) {
+        mmq_x  =  MMQ_X_Q4_1_RDNA1;
+        mmq_y  =  MMQ_Y_Q4_1_RDNA1;
+        nwarps = NWARPS_Q4_1_RDNA1;
+    } else if (compute_capability >= VER_GEN9) {
+        mmq_x  =  MMQ_X_Q4_1_AMPERE;
+        mmq_y  =  MMQ_Y_Q4_1_AMPERE;
+        nwarps = NWARPS_Q4_1_AMPERE;
+    } else if (compute_capability >= VER_4VEC) {
+        mmq_x  =  MMQ_X_Q4_1_PASCAL;
+        mmq_y  =  MMQ_Y_Q4_1_PASCAL;
+        nwarps = NWARPS_Q4_1_PASCAL;
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        /*
+        DPCT1049:22: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q4_1<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_qs_q4_1_acc_ct1.get_pointer(),
+                            tile_x_dm_q4_1_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    } else {
+        const bool need_check = true;
+        /*
+        DPCT1049:23: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q4_1<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_qs_q4_1_acc_ct1.get_pointer(),
+                            tile_x_dm_q4_1_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    }
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
+
+static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols_x,
+                                        const int nrows_x, const int ncols_y,
+                                        const int nrows_y, const int nrows_dst,
+                                        dpct::queue_ptr stream) try {
+
+    int id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(id = get_current_device_id()));
+    const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+    int mmq_x, mmq_y, nwarps;
+    if (compute_capability >= VER_GEN13) {
+        mmq_x  =  MMQ_X_Q5_0_RDNA2;
+        mmq_y  =  MMQ_Y_Q5_0_RDNA2;
+        nwarps = NWARPS_Q5_0_RDNA2;
+    } else if (compute_capability >= VER_GEN12) {
+        mmq_x  =  MMQ_X_Q5_0_RDNA1;
+        mmq_y  =  MMQ_Y_Q5_0_RDNA1;
+        nwarps = NWARPS_Q5_0_RDNA1;
+    } else if (compute_capability >= VER_GEN9) {
+        mmq_x  =  MMQ_X_Q5_0_AMPERE;
+        mmq_y  =  MMQ_Y_Q5_0_AMPERE;
+        nwarps = NWARPS_Q5_0_AMPERE;
+    } else if (compute_capability >= VER_4VEC) {
+        mmq_x  =  MMQ_X_Q5_0_PASCAL;
+        mmq_y  =  MMQ_Y_Q5_0_PASCAL;
+        nwarps = NWARPS_Q5_0_PASCAL;
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        /*
+        DPCT1049:24: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
+                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q5_0<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_ql_q5_0_acc_ct1.get_pointer(),
+                            tile_x_d_q5_0_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    } else {
+        const bool need_check = true;
+        /*
+        DPCT1049:25: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
+                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q5_0<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_ql_q5_0_acc_ct1.get_pointer(),
+                            tile_x_d_q5_0_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    }
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
+
+static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols_x,
+                                        const int nrows_x, const int ncols_y,
+                                        const int nrows_y, const int nrows_dst,
+                                        dpct::queue_ptr stream) try {
+
+    int id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(id = get_current_device_id()));
+    const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+    int mmq_x, mmq_y, nwarps;
+    if (compute_capability >= VER_GEN13) {
+        mmq_x  =  MMQ_X_Q5_1_RDNA2;
+        mmq_y  =  MMQ_Y_Q5_1_RDNA2;
+        nwarps = NWARPS_Q5_1_RDNA2;
+    } else if (compute_capability >= VER_GEN12) {
+        mmq_x  =  MMQ_X_Q5_1_RDNA1;
+        mmq_y  =  MMQ_Y_Q5_1_RDNA1;
+        nwarps = NWARPS_Q5_1_RDNA1;
+    } else if (compute_capability >= VER_GEN9) {
+        mmq_x  =  MMQ_X_Q5_1_AMPERE;
+        mmq_y  =  MMQ_Y_Q5_1_AMPERE;
+        nwarps = NWARPS_Q5_1_AMPERE;
+    } else if (compute_capability >= VER_4VEC) {
+        mmq_x  =  MMQ_X_Q5_1_PASCAL;
+        mmq_y  =  MMQ_Y_Q5_1_PASCAL;
+        nwarps = NWARPS_Q5_1_PASCAL;
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        /*
+        DPCT1049:26: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
+                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q5_1<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_ql_q5_1_acc_ct1.get_pointer(),
+                            tile_x_dm_q5_1_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    } else {
+        const bool need_check = true;
+        /*
+        DPCT1049:27: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
+                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q5_1<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_ql_q5_1_acc_ct1.get_pointer(),
+                            tile_x_dm_q5_1_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    }
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
+
+static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols_x,
+                                        const int nrows_x, const int ncols_y,
+                                        const int nrows_y, const int nrows_dst,
+                                        dpct::queue_ptr stream) try {
+
+    int id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(id = get_current_device_id()));
+    const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+    int mmq_x, mmq_y, nwarps;
+    if (compute_capability >= VER_GEN13) {
+        mmq_x  =  MMQ_X_Q8_0_RDNA2;
+        mmq_y  =  MMQ_Y_Q8_0_RDNA2;
+        nwarps = NWARPS_Q8_0_RDNA2;
+    } else if (compute_capability >= VER_GEN12) {
+        mmq_x  =  MMQ_X_Q8_0_RDNA1;
+        mmq_y  =  MMQ_Y_Q8_0_RDNA1;
+        nwarps = NWARPS_Q8_0_RDNA1;
+    } else if (compute_capability >= VER_GEN9) {
+        mmq_x  =  MMQ_X_Q8_0_AMPERE;
+        mmq_y  =  MMQ_Y_Q8_0_AMPERE;
+        nwarps = NWARPS_Q8_0_AMPERE;
+    } else if (compute_capability >= VER_4VEC) {
+        mmq_x  =  MMQ_X_Q8_0_PASCAL;
+        mmq_y  =  MMQ_Y_Q8_0_PASCAL;
+        nwarps = NWARPS_Q8_0_PASCAL;
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        /*
+        DPCT1049:28: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q8_0<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_qs_q8_0_acc_ct1.get_pointer(),
+                            tile_x_d_q8_0_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    } else {
+        const bool need_check = true;
+        /*
+        DPCT1049:29: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q8_0<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_qs_q8_0_acc_ct1.get_pointer(),
+                            tile_x_d_q8_0_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    }
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
+
+static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols_x,
+                                        const int nrows_x, const int ncols_y,
+                                        const int nrows_y, const int nrows_dst,
+                                        dpct::queue_ptr stream) try {
+
+    int id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(id = get_current_device_id()));
+    const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+    int mmq_x, mmq_y, nwarps;
+    if (compute_capability >= VER_GEN13) {
+        mmq_x  =  MMQ_X_Q2_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q2_K_RDNA2;
+        nwarps = NWARPS_Q2_K_RDNA2;
+    } else if (compute_capability >= VER_GEN12) {
+        mmq_x  =  MMQ_X_Q2_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q2_K_RDNA1;
+        nwarps = NWARPS_Q2_K_RDNA1;
+    } else if (compute_capability >= VER_GEN9) {
+        mmq_x  =  MMQ_X_Q2_K_AMPERE;
+        mmq_y  =  MMQ_Y_Q2_K_AMPERE;
+        nwarps = NWARPS_Q2_K_AMPERE;
+    } else if (compute_capability >= VER_4VEC) {
+        mmq_x  =  MMQ_X_Q2_K_PASCAL;
+        mmq_y  =  MMQ_Y_Q2_K_PASCAL;
+        nwarps = NWARPS_Q2_K_PASCAL;
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        /*
+        DPCT1049:30: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q2_K<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_ql_q2_K_acc_ct1.get_pointer(),
+                            tile_x_dm_q2_K_acc_ct1.get_pointer(),
+                            tile_x_sc_q2_K_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    } else {
+        const bool need_check = true;
+        /*
+        DPCT1049:31: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q2_K<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_ql_q2_K_acc_ct1.get_pointer(),
+                            tile_x_dm_q2_K_acc_ct1.get_pointer(),
+                            tile_x_sc_q2_K_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    }
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
+
+static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols_x,
+                                        const int nrows_x, const int ncols_y,
+                                        const int nrows_y, const int nrows_dst,
+                                        dpct::queue_ptr stream) try {
+
+#if QK_K == 256
+
+    int id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(id = get_current_device_id()));
+    const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+    int mmq_x, mmq_y, nwarps;
+    if (compute_capability >= VER_GEN13) {
+        mmq_x  =  MMQ_X_Q3_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q3_K_RDNA2;
+        nwarps = NWARPS_Q3_K_RDNA2;
+    } else if (compute_capability >= VER_GEN12) {
+        mmq_x  =  MMQ_X_Q3_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q3_K_RDNA1;
+        nwarps = NWARPS_Q3_K_RDNA1;
+    } else if (compute_capability >= VER_GEN9) {
+        mmq_x  =  MMQ_X_Q3_K_AMPERE;
+        mmq_y  =  MMQ_Y_Q3_K_AMPERE;
+        nwarps = NWARPS_Q3_K_AMPERE;
+    } else if (compute_capability >= VER_4VEC) {
+        mmq_x  =  MMQ_X_Q3_K_PASCAL;
+        mmq_y  =  MMQ_Y_Q3_K_PASCAL;
+        nwarps = NWARPS_Q3_K_PASCAL;
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        /*
+        DPCT1049:32: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
+                sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q3_K<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_ql_q3_K_acc_ct1.get_pointer(),
+                            tile_x_dm_q3_K_acc_ct1.get_pointer(),
+                            tile_x_qh_q3_K_acc_ct1.get_pointer(),
+                            tile_x_sc_q3_K_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    } else {
+        const bool need_check = true;
+        /*
+        DPCT1049:33: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
+                sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q3_K<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_ql_q3_K_acc_ct1.get_pointer(),
+                            tile_x_dm_q3_K_acc_ct1.get_pointer(),
+                            tile_x_qh_q3_K_acc_ct1.get_pointer(),
+                            tile_x_sc_q3_K_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    }
+#endif
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
+
+static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols_x,
+                                        const int nrows_x, const int ncols_y,
+                                        const int nrows_y, const int nrows_dst,
+                                        dpct::queue_ptr stream) try {
+
+    int id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(id = get_current_device_id()));
+    const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+    int mmq_x, mmq_y, nwarps;
+    if (compute_capability >= VER_GEN13) {
+        mmq_x  =  MMQ_X_Q4_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q4_K_RDNA2;
+        nwarps = NWARPS_Q4_K_RDNA2;
+    } else if (compute_capability >= VER_GEN12) {
+        mmq_x  =  MMQ_X_Q4_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q4_K_RDNA1;
+        nwarps = NWARPS_Q4_K_RDNA1;
+    } else if (compute_capability >= VER_GEN9) {
+        mmq_x  =  MMQ_X_Q4_K_AMPERE;
+        mmq_y  =  MMQ_Y_Q4_K_AMPERE;
+        nwarps = NWARPS_Q4_K_AMPERE;
+    } else if (compute_capability >= VER_4VEC) {
+        mmq_x  =  MMQ_X_Q4_K_PASCAL;
+        mmq_y  =  MMQ_Y_Q4_K_PASCAL;
+        nwarps = NWARPS_Q4_K_PASCAL;
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        /*
+        DPCT1049:34: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q4_K<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_ql_q4_K_acc_ct1.get_pointer(),
+                            tile_x_dm_q4_K_acc_ct1.get_pointer(),
+                            tile_x_sc_q4_K_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    } else {
+        const bool need_check = true;
+        /*
+        DPCT1049:35: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q4_K<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_ql_q4_K_acc_ct1.get_pointer(),
+                            tile_x_dm_q4_K_acc_ct1.get_pointer(),
+                            tile_x_sc_q4_K_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    }
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
+
+static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols_x,
+                                        const int nrows_x, const int ncols_y,
+                                        const int nrows_y, const int nrows_dst,
+                                        dpct::queue_ptr stream) try {
+
+    int id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(id = get_current_device_id()));
+    const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+    int mmq_x, mmq_y, nwarps;
+    if (compute_capability >= VER_GEN13) {
+        mmq_x  =  MMQ_X_Q5_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q5_K_RDNA2;
+        nwarps = NWARPS_Q5_K_RDNA2;
+    } else if (compute_capability >= VER_GEN12) {
+        mmq_x  =  MMQ_X_Q5_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q5_K_RDNA1;
+        nwarps = NWARPS_Q5_K_RDNA1;
+    } else if (compute_capability >= VER_GEN9) {
+        mmq_x  =  MMQ_X_Q5_K_AMPERE;
+        mmq_y  =  MMQ_Y_Q5_K_AMPERE;
+        nwarps = NWARPS_Q5_K_AMPERE;
+    } else if (compute_capability >= VER_4VEC) {
+        mmq_x  =  MMQ_X_Q5_K_PASCAL;
+        mmq_y  =  MMQ_Y_Q5_K_PASCAL;
+        nwarps = NWARPS_Q5_K_PASCAL;
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        /*
+        DPCT1049:36: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q5_K<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_ql_q5_K_acc_ct1.get_pointer(),
+                            tile_x_dm_q5_K_acc_ct1.get_pointer(),
+                            tile_x_sc_q5_K_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    } else {
+        const bool need_check = true;
+        /*
+        DPCT1049:37: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q5_K<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_ql_q5_K_acc_ct1.get_pointer(),
+                            tile_x_dm_q5_K_acc_ct1.get_pointer(),
+                            tile_x_sc_q5_K_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    }
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
+
+static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols_x,
+                                        const int nrows_x, const int ncols_y,
+                                        const int nrows_y, const int nrows_dst,
+                                        dpct::queue_ptr stream) try {
+
+    int id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(id = get_current_device_id()));
+    const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+    int mmq_x, mmq_y, nwarps;
+    if (compute_capability >= VER_GEN13) {
+        mmq_x  =  MMQ_X_Q6_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q6_K_RDNA2;
+        nwarps = NWARPS_Q6_K_RDNA2;
+    } else if (compute_capability >= VER_GEN12) {
+        mmq_x  =  MMQ_X_Q6_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q6_K_RDNA1;
+        nwarps = NWARPS_Q6_K_RDNA1;
+    } else if (compute_capability >= VER_GEN9) {
+        mmq_x  =  MMQ_X_Q6_K_AMPERE;
+        mmq_y  =  MMQ_Y_Q6_K_AMPERE;
+        nwarps = NWARPS_Q6_K_AMPERE;
+    } else if (compute_capability >= VER_4VEC) {
+        mmq_x  =  MMQ_X_Q6_K_PASCAL;
+        mmq_y  =  MMQ_Y_Q6_K_PASCAL;
+        nwarps = NWARPS_Q6_K_PASCAL;
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        /*
+        DPCT1049:38: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
+                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q6_K<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_ql_acc_ct1.get_pointer(),
+                            tile_x_dm_acc_ct1.get_pointer(),
+                            tile_x_sc_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    } else {
+        const bool need_check = true;
+        /*
+        DPCT1049:39: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            stream->submit([&](sycl::handler &cgh) {
+                sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
+                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
+                    cgh);
+                sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
+                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        mul_mat_q6_K<need_check>(
+                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+                            nrows_dst, item_ct1,
+                            tile_x_ql_acc_ct1.get_pointer(),
+                            tile_x_dm_acc_ct1.get_pointer(),
+                            tile_x_sc_acc_ct1.get_pointer(),
+                            tile_y_qs_acc_ct1.get_pointer(),
+                            tile_y_ds_acc_ct1.get_pointer());
+                    });
+            });
+        }
+    }
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
+
+void ggml_sycl_op_mul_mat_q(
+    ggml_backend_sycl_context & ctx,
+    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+    const int64_t src1_ncols, const int64_t src1_padded_row_size,
+    const dpct::queue_ptr &stream) try {
+
+    const int64_t ne00 = src0->ne[0];
+
+    const int64_t ne10 = src1->ne[0];
+    GGML_ASSERT(ne10 % QK8_1 == 0);
+
+    const int64_t ne0 = dst->ne[0];
+
+    const int64_t row_diff = row_high - row_low;
+
+    int device_id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(device_id = get_current_device_id()));
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
+    const int64_t nrows_dst = device_id == ctx.device ? ne0 : row_diff;
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            ggml_mul_mat_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            ggml_mul_mat_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            ggml_mul_mat_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            ggml_mul_mat_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            ggml_mul_mat_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            ggml_mul_mat_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            ggml_mul_mat_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            ggml_mul_mat_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            ggml_mul_mat_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            ggml_mul_mat_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+
+    (void) src1;
+    (void) dst;
+    (void) src1_ddf_i;
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
diff --git a/src/ggml-sycl/mmq.hpp b/src/ggml-sycl/mmq.hpp
new file mode 100644 (file)
index 0000000..3f5297a
--- /dev/null
@@ -0,0 +1,33 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_MMQ_HPP
+#define GGML_SYCL_MMQ_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_mul_mat_q(
+    ggml_backend_sycl_context & ctx,
+    const ggml_tensor* src0,
+    const ggml_tensor* src1,
+    ggml_tensor* dst,
+    const char* src0_dd_i,
+    const float* src1_ddf_i,
+    const char* src1_ddq_i,
+    float* dst_dd_i,
+    const int64_t row_low,
+    const int64_t row_high,
+    const int64_t src1_ncols,
+    const int64_t src1_padded_row_size,
+    const dpct::queue_ptr& stream);
+
+#endif // GGML_SYCL_MMQ_HPP
diff --git a/src/ggml-sycl/mmvq.cpp b/src/ggml-sycl/mmvq.cpp
new file mode 100644 (file)
index 0000000..2322764
--- /dev/null
@@ -0,0 +1,1024 @@
+#include "mmvq.hpp"
+#include "vecdotq.hpp"
+
+
+template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
+static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
+                          const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
+                                       const void *__restrict__ vy,
+                                       float *__restrict__ dst, const int ncols,
+                                       const int nrows,
+                                       const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
+                                      const void *__restrict__ vy,
+                                      float *__restrict__ dst, const int ncols,
+                                      const int nrows,
+                                      const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid, ksigns64);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
+                                     const void *__restrict__ vy,
+                                     float *__restrict__ dst, const int ncols,
+                                     const int nrows,
+                                     const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq2_s_q8_1(&x[ibx], &y[iby], iqs);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
+                                       const void *__restrict__ vy,
+                                       float *__restrict__ dst, const int ncols,
+                                       const int nrows,
+                                       const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid, ksigns64);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
+                                     const void *__restrict__ vy,
+                                     float *__restrict__ dst, const int ncols,
+                                     const int nrows,
+                                     const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
+                                     const void *__restrict__ vy,
+                                     float *__restrict__ dst, const int ncols,
+                                     const int nrows,
+                                     const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_gpu);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
+                                     const void *__restrict__ vy,
+                                     float *__restrict__ dst, const int ncols,
+                                     const int nrows,
+                                     const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq1_m_q8_1(&x[ibx], &y[iby], iqs);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
+                                      const void *__restrict__ vy,
+                                      float *__restrict__ dst, const int ncols,
+                                      const int nrows,
+                                      const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq4_nl_q8_1(&x[ibx], &y[iby], iqs);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
+                                      const void *__restrict__ vy,
+                                      float *__restrict__ dst, const int ncols,
+                                      const int nrows,
+                                      const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq4_xs_q8_1(&x[ibx], &y[iby], iqs);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
+                                       float *dst, const int ncols,
+                                       const int nrows,
+                                       dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK4_0 == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
+                                      VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
+                                       float *dst, const int ncols,
+                                       const int nrows,
+                                       dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK4_1 == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
+                                      VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
+                                       float *dst, const int ncols,
+                                       const int nrows,
+                                       dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK5_0 == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
+                                      VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
+                                       float *dst, const int ncols,
+                                       const int nrows,
+                                       dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK5_1 == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
+                                      VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
+                                       float *dst, const int ncols,
+                                       const int nrows,
+                                       dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK8_0 == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
+                                      VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
+                                       float *dst, const int ncols,
+                                       const int nrows,
+                                       dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
+                                      VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
+                                       float *dst, const int ncols,
+                                       const int nrows,
+                                       dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
+                                      VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
+                                       float *dst, const int ncols,
+                                       const int nrows,
+                                       dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
+                                      VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
+                                       float *dst, const int ncols,
+                                       const int nrows,
+                                       dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
+                                      VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
+                                       float *dst, const int ncols,
+                                       const int nrows,
+                                       dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
+                                      VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+
+static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
+                                          float *dst, const int ncols,
+                                          const int nrows,
+                                          dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS, block_iq2_xxs, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
+                                         float *dst, const int ncols,
+                                         const int nrows,
+                                         dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+            auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
+            auto ksigns64_ptr_ct1 = &ksigns64[0];
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS, block_iq2_xs, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
+                                         float *dst, const int ncols,
+                                         const int nrows,
+                                         dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+            auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
+            auto ksigns64_ptr_ct1 = &ksigns64[0];
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S, block_iq2_s, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
+                                          float *dst, const int ncols,
+                                          const int nrows,
+                                          dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+            auto iq3xxs_grid_ptr_ct1 = &iq3xxs_grid[0];
+            auto ksigns64_ptr_ct1 = &ksigns64[0];
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS, block_iq3_xxs, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
+                                          float *dst, const int ncols,
+                                          const int nrows,
+                                          dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+            auto iq3s_grid_ptr_ct1 = &iq3s_grid[0];
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_XS, block_iq3_s, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
+                                          float *dst, const int ncols,
+                                          const int nrows,
+                                          dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+            auto iq1s_grid_ptr_ct1 = &iq1s_grid_gpu[0];
+            auto ksigns64_ptr_ct1 = &ksigns64[0];
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
+                                          float *dst, const int ncols,
+                                          const int nrows,
+                                          dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
+                                          float *dst, const int ncols,
+                                          const int nrows,
+                                          dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK4_NL == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
+                                          float *dst, const int ncols,
+                                          const int nrows,
+                                          dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS, block_iq4_xs, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+void ggml_sycl_op_mul_mat_vec_q(
+    ggml_backend_sycl_context & ctx,
+    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+    const int64_t src1_ncols, const int64_t src1_padded_row_size,
+    const dpct::queue_ptr &stream) {
+
+    const int64_t ne10 = src1->ne[0];
+    GGML_ASSERT(ne10 % QK8_1 == 0);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    int id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(id = get_current_device_id()));
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
+    const int64_t nrows_dst = id == ctx.device ? ne00 : row_diff;
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ1_M:
+            mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ2_S:
+            mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ3_XXS:
+            mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+
+    (void) src1;
+    (void) dst;
+    (void) src1_ddf_i;
+    (void) src1_ncols;
+    (void) src1_padded_row_size;
+}
diff --git a/src/ggml-sycl/mmvq.hpp b/src/ggml-sycl/mmvq.hpp
new file mode 100644 (file)
index 0000000..049b43d
--- /dev/null
@@ -0,0 +1,27 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_MMVQ_HPP
+#define GGML_SYCL_MMVQ_HPP
+
+#include "common.hpp"
+
+
+void ggml_sycl_op_mul_mat_vec_q(
+    ggml_backend_sycl_context & ctx,
+    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+    const int64_t src1_ncols, const int64_t src1_padded_row_size,
+    const dpct::queue_ptr &stream);
+
+#endif // GGML_SYCL_MMVQ_HPP
index dcf0261102e914021244f5f98a12e789d0c9f23c..fe9d41770b76a42d9c3a98d6baa1eb6c98d67382 100644 (file)
 
 #define GGML_SYCL_MAX_STREAMS       8
 #define GGML_SYCL_MAX_BUFFERS       256
-#define GGML_SYCL_MAX_DEVICES       48
-#define GGML_SYCL_NAME "SYCL"
 
-// FIXME: 1024 from cuda
-#define GROUP_SIZE 1024
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
diff --git a/src/ggml-sycl/vecdotq.hpp b/src/ggml-sycl/vecdotq.hpp
new file mode 100644 (file)
index 0000000..5e2e825
--- /dev/null
@@ -0,0 +1,1161 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_VECDOTQ_HPP
+#define GGML_SYCL_VECDOTQ_HPP
+
+#include "dpct/helper.hpp"
+
+typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
+
+static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
+  const uint16_t* x16 =
+      (const uint16_t*)(x8 + sizeof(int) * i32); // assume at least 2 byte
+                                                 // alignment
+
+  int x32 = 0;
+  x32 |= x16[0] << 0;
+  x32 |= x16[1] << 16;
+
+  return x32;
+}
+
+static __dpct_inline__ int get_int_from_uint8(
+    const uint8_t* x8,
+    const int& i32) {
+  const uint16_t* x16 =
+      (const uint16_t*)(x8 + sizeof(int) * i32); // assume at least 2 byte
+                                                 // alignment
+
+  int x32 = 0;
+  x32 |= x16[0] << 0;
+  x32 |= x16[1] << 16;
+
+  return x32;
+}
+
+static __dpct_inline__ int get_int_from_int8_aligned(
+    const int8_t* x8,
+    const int& i32) {
+  return *(
+      (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+}
+
+static __dpct_inline__ int get_int_from_uint8_aligned(
+    const uint8_t* x8,
+    const int& i32) {
+  return *(
+      (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+}
+
+static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4,
+                                                  const uint8_t *values,
+                                                  int &val1, int &val2) {
+
+    uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
+    aux32 = q4 & 0x0f0f0f0f;
+    uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
+    uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
+    val1 = v1 | (v2 << 16);
+    aux32 = (q4 >> 4) & 0x0f0f0f0f;
+    v1 = values[q8[0]] | (values[q8[1]] << 8);
+    v2 = values[q8[2]] | (values[q8[3]] << 8);
+    val2 = v1 | (v2 << 16);
+}
+
+#define VDR_Q2_K_Q8_1_MMVQ 1
+
+// contiguous v/x values
+static __dpct_inline__ float vec_dot_q2_K_q8_1_impl_mmvq(
+    const int &v, const int *__restrict__ u, const uint8_t *__restrict__ scales,
+    const sycl::half2 &dm2, const float *__restrict__ d8) {
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR2_K; ++i) {
+        const int sc = scales[2*i];
+
+        const int vi = (v >> (2*i)) & 0x03030303;
+
+        sumf_d +=
+            d8[i] * (dpct::dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
+
+        // fill int with 4x m
+        int m = sc >> 4;
+        m |= m <<  8;
+        m |= m << 16;
+        sumf_m += d8[i] *
+                  dpct::dp4a(
+                      m, u[i],
+                      0); // multiply constant q2_K part with sum of q8_1 values
+    }
+
+    const sycl::float2 dm2f =
+        dm2.convert<float, sycl::rounding_mode::automatic>();
+
+    return dm2f.x() * sumf_d - dm2f.y() * sumf_m;
+}
+
+
+#define VDR_Q3_K_Q8_1_MMVQ 1
+
+// contiguous v/x values
+static __dpct_inline__ float vec_dot_q3_K_q8_1_impl_mmvq(
+    const int &vl, const int &vh, const int *__restrict__ u,
+    const uint8_t *__restrict__ scales, const int &scale_offset,
+    const float &d3, const float *__restrict__ d8) {
+
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR3_K; ++i) {
+        const int isc = scale_offset + 2*i;
+
+        const int isc_low = isc % (QK_K/32);
+        const int sc_shift_low = 4 * (isc / (QK_K/32));
+        const int sc_low  = (scales[isc_low] >> sc_shift_low) & 0xF;
+
+        const int isc_high = isc % (QK_K/64);
+        const int sc_shift_high = 2 * (isc / (QK_K/64));
+        const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+
+        const int sc = (sc_low | sc_high) - 32;
+
+        const int vil = (vl >> (2*i)) & 0x03030303;
+
+        const int vih = ((vh >> i) << 2) & 0x04040404;
+
+        const int vi =
+            dpct::vectorized_binary<sycl::char4>(vil, vih, dpct::sub_sat());
+
+        sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
+    }
+
+    return d3 * sumf;
+}
+
+#define VDR_Q4_K_Q8_1_MMVQ 2
+
+// contiguous v/x values
+static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_vmmq(
+    const int *__restrict__ v, const int *__restrict__ u,
+    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
+    const sycl::half2 &dm4, const float *__restrict__ d8) {
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR4_K; ++i) {
+        const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
+        const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
+
+        const int dot1 =
+            dpct::dp4a(v1i, u[2 * i + 1],
+                       dpct::dp4a(v0i, u[2 * i + 0], 0)); // SIMD dot product
+        const int dot2 =
+            dpct::dp4a(0x01010101, u[2 * i + 1],
+                       dpct::dp4a(0x01010101, u[2 * i + 0], 0)); // sum of u
+
+        sumf_d += d8[i] * (dot1 * sc[i]);
+        sumf_m += d8[i] * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
+    }
+
+    const sycl::float2 dm4f =
+        dm4.convert<float, sycl::rounding_mode::automatic>();
+
+    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
+}
+
+
+#define VDR_Q5_K_Q8_1_MMVQ 2
+
+// contiguous v/x values
+static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_vmmq(
+    const int *__restrict__ vl, const int *__restrict__ vh,
+    const int *__restrict__ u, const uint8_t *__restrict__ sc,
+    const uint8_t *__restrict__ m, const sycl::half2 &dm5,
+    const float *__restrict__ d8) {
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K; ++i) {
+        const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
+        const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
+
+        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
+        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
+
+        const int v0i = vl0i | vh0i;
+        const int v1i = vl1i | vh1i;
+
+        const int dot1 =
+            dpct::dp4a(v0i, u[2 * i + 0],
+                       dpct::dp4a(v1i, u[2 * i + 1], 0)); // SIMD dot product
+        const int dot2 =
+            dpct::dp4a(0x01010101, u[2 * i + 0],
+                       dpct::dp4a(0x01010101, u[2 * i + 1], 0)); // sum of u
+
+        sumf_d += d8[i] * (dot1 * sc[i]);
+        sumf_m += d8[i] * (dot2 * m[i]);
+
+    }
+
+    const sycl::float2 dm5f =
+        dm5.convert<float, sycl::rounding_mode::automatic>();
+
+    return dm5f.x() * sumf_d - dm5f.y() * sumf_m;
+}
+
+
+#define VDR_Q6_K_Q8_1_MMVQ 1
+
+// contiguous v/x values
+static __dpct_inline__ float
+vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh,
+                            const int *__restrict__ u,
+                            const int8_t *__restrict__ scales, const float &d,
+                            const float *__restrict__ d8) {
+
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR6_K; ++i) {
+        const int sc = scales[4*i];
+
+        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+        const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
+
+        const int vi = dpct::vectorized_binary<sycl::char4>(
+            (vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32
+
+        sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
+    }
+
+    return d*sumf;
+}
+
+// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
+// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
+
+#define VDR_Q4_0_Q8_1_MMVQ 2
+#define VDR_Q4_0_Q8_1_MMQ  4
+
+template <int vdr>
+static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u,
+                                                    const float &d4,
+                                                    const sycl::half2 &ds8) {
+    int sumi = 0;
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+        // SIMD dot product of quantized values
+        sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);
+        sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
+    }
+
+    const sycl::float2 ds8f =
+        ds8.convert<float, sycl::rounding_mode::automatic>();
+
+    // second part effectively subtracts 8 from each quant value
+    return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y());
+}
+
+#define VDR_Q4_1_Q8_1_MMVQ 2
+#define VDR_Q4_1_Q8_1_MMQ  4
+
+template <int vdr>
+static __dpct_inline__ float vec_dot_q4_1_q8_1_impl(const int *v, const int *u,
+                                                    const sycl::half2 &dm4,
+                                                    const sycl::half2 &ds8) {
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+        // SIMD dot product of quantized values
+        sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);
+        sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
+    }
+
+#ifdef GGML_SYCL_F16
+    const sycl::float2 tmp =
+        (dm4 * ds8).convert<float, sycl::rounding_mode::automatic>();
+    const float d4d8 = tmp.x();
+    const float m4s8 = tmp.y();
+#else
+    const sycl::float2 dm4f =
+        dm4.convert<float, sycl::rounding_mode::automatic>();
+    const sycl::float2 ds8f =
+        ds8.convert<float, sycl::rounding_mode::automatic>();
+    const float d4d8 = dm4f.x() * ds8f.x();
+    const float m4s8 = dm4f.y() * ds8f.y();
+#endif // GGML_SYCL_F16
+
+    // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
+    return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
+}
+
+#define VDR_Q5_0_Q8_1_MMVQ 2
+#define VDR_Q5_0_Q8_1_MMQ  4
+
+template <int vdr>
+static __dpct_inline__ float
+vec_dot_q5_0_q8_1_impl(const int *vl, const int *vh, const int *u,
+                       const float &d5, const sycl::half2 &ds8) {
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
+        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+        sumi = dpct::dp4a(vi0, u[2 * i + 0],
+                          sumi); // SIMD dot product of quantized values
+
+        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
+        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12
+        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20
+        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28
+        sumi = dpct::dp4a(vi1, u[2 * i + 1],
+                          sumi); // SIMD dot product of quantized values
+    }
+
+    const sycl::float2 ds8f =
+        ds8.convert<float, sycl::rounding_mode::automatic>();
+
+    // second part effectively subtracts 16 from each quant value
+    return d5 * (sumi * ds8f.x() - (16 * vdr / QI5_0) * ds8f.y());
+}
+
+#define VDR_Q5_1_Q8_1_MMVQ 2
+#define VDR_Q5_1_Q8_1_MMQ  4
+
+template <int vdr>
+static __dpct_inline__ float
+vec_dot_q5_1_q8_1_impl(const int *vl, const int *vh, const int *u,
+                       const sycl::half2 &dm5, const sycl::half2 &ds8) {
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
+        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+        sumi = dpct::dp4a(vi0, u[2 * i + 0],
+                          sumi); // SIMD dot product of quantized values
+
+        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
+        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12
+        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20
+        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28
+        sumi = dpct::dp4a(vi1, u[2 * i + 1],
+                          sumi); // SIMD dot product of quantized values
+    }
+
+#ifdef GGML_SYCL_F16
+     const sycl::float2 tmp =
+        (dm5 * ds8).convert<float, sycl::rounding_mode::automatic>();
+    const float d5d8 = tmp.x();
+    const float m5s8 = tmp.y();
+
+
+#else
+    const sycl::float2 dm5f =
+        dm5.convert<float, sycl::rounding_mode::automatic>();
+    const sycl::float2 ds8f =
+        ds8.convert<float, sycl::rounding_mode::automatic>();
+    const float d5d8 = dm5f.x() * ds8f.x();
+    const float m5s8 = dm5f.y() * ds8f.y();
+#endif // GGML_SYCL_F16
+
+    // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
+    return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
+}
+
+#define VDR_Q8_0_Q8_1_MMVQ 2
+#define VDR_Q8_0_Q8_1_MMQ 8
+
+template <int vdr>
+static __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int *v, const int *u,
+                                                    const float &d8_0,
+                                                    const float &d8_1) {
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        // SIMD dot product of quantized values
+        sumi = dpct::dp4a(v[i], u[i], sumi);
+    }
+
+    return d8_0*d8_1 * sumi;
+}
+
+template <int vdr>
+static __dpct_inline__ float vec_dot_q8_1_q8_1_impl(const int *v, const int *u,
+                                                    const sycl::half2 &dm8,
+                                                    const sycl::half2 &ds8) {
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        // SIMD dot product of quantized values
+        sumi = dpct::dp4a(v[i], u[i], sumi);
+    }
+
+#ifdef GGML_SYCL_F16
+    const sycl::float2 tmp =
+        (dm8 * ds8).convert<float, sycl::rounding_mode::automatic>();
+    const float d8d8 = tmp.x();
+    const float m8s8 = tmp.y();
+#else
+    const sycl::float2 dm8f =
+        dm8.convert<float, sycl::rounding_mode::automatic>();
+    const sycl::float2 ds8f =
+        ds8.convert<float, sycl::rounding_mode::automatic>();
+    const float d8d8 = dm8f.x() * ds8f.x();
+    const float m8s8 = dm8f.y() * ds8f.y();
+#endif // GGML_SYCL_F16
+
+    // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
+    return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
+}
+
+static __dpct_inline__ float
+vec_dot_q4_0_q8_1(const void *__restrict__ vbq,
+                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
+
+    int v[VDR_Q4_0_Q8_1_MMVQ];
+    int u[2*VDR_Q4_0_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
+        v[i]     = get_int_from_uint8(bq4_0->qs, iqs + i);
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
+    }
+
+    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
+}
+
+static __dpct_inline__ float
+vec_dot_q4_1_q8_1(const void *__restrict__ vbq,
+                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+
+    int v[VDR_Q4_1_Q8_1_MMVQ];
+    int u[2*VDR_Q4_1_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
+        v[i]    = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
+    }
+
+    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
+}
+
+static __dpct_inline__ float
+vec_dot_q5_0_q8_1(const void *__restrict__ vbq,
+                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
+
+    int vl[VDR_Q5_0_Q8_1_MMVQ];
+    int vh[VDR_Q5_0_Q8_1_MMVQ];
+    int  u[2*VDR_Q5_0_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
+        vl[i]    = get_int_from_uint8(bq5_0->qs, iqs + i);
+        vh[i]    = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i));
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0);
+    }
+
+    return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
+}
+
+static __dpct_inline__ float
+vec_dot_q5_1_q8_1(const void *__restrict__ vbq,
+                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
+
+    int vl[VDR_Q5_1_Q8_1_MMVQ];
+    int vh[VDR_Q5_1_Q8_1_MMVQ];
+    int  u[2*VDR_Q5_1_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
+        vl[i]   = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);
+        vh[i]   = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
+    }
+
+    return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
+}
+
+static __dpct_inline__ float
+vec_dot_q8_0_q8_1(const void *__restrict__ vbq,
+                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
+
+    int v[VDR_Q8_0_Q8_1_MMVQ];
+    int u[VDR_Q8_0_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
+        v[i] = get_int_from_int8(bq8_0->qs, iqs + i);
+        u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+    }
+
+    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d,
+                                                      bq8_1->ds[0]);
+}
+
+static __dpct_inline__ float
+vec_dot_q2_K_q8_1(const void *__restrict__ vbq,
+                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+    const block_q2_K * bq2_K = (const block_q2_K *) vbq;
+
+    const int bq8_offset = QR2_K * (iqs / QI8_1);
+    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+    const uint8_t * scales = bq2_K->scales + scale_offset;
+
+    const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
+    int    u[QR2_K];
+    float d8[QR2_K];
+
+#pragma unroll
+    for (int i = 0; i < QR2_K; ++ i) {
+        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+        d8[i] = bq8_1[bq8_offset + i].ds[0];
+    }
+
+    return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
+}
+
+static __dpct_inline__ float
+vec_dot_q3_K_q8_1(const void *__restrict__ vbq,
+                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+    const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+
+    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
+    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+    const float d = bq3_K->d;
+
+    const int vl = get_int_from_uint8(bq3_K->qs, iqs);
+
+    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+    const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;
+
+    int    u[QR3_K];
+    float d8[QR3_K];
+
+#pragma unroll
+    for (int i = 0; i < QR3_K; ++i) {
+        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+        d8[i] = bq8_1[bq8_offset + i].ds[0];
+    }
+
+    return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
+}
+
+static __dpct_inline__ float
+vec_dot_q4_K_q8_1(const void *__restrict__ vbq,
+                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+#ifndef GGML_QKK_64
+    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+    int    v[2];
+    int    u[2*QR4_K];
+    float d8[QR4_K];
+
+    // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
+    const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
+
+    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
+    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
+    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
+    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
+
+    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+    v[0] = q4[0];
+    v[1] = q4[4];
+
+    const uint16_t * scales = (const uint16_t *)bq4_K->scales;
+    uint16_t aux[2];
+    const int j = bq8_offset/2;
+    if (j < 2) {
+        aux[0] = scales[j+0] & 0x3f3f;
+        aux[1] = scales[j+2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+    }
+    const uint8_t * sc = (const uint8_t *)aux;
+    const uint8_t * m  = sc + 2;
+
+    for (int i = 0; i < QR4_K; ++i) {
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        d8[i] = bq8i->ds[0];
+
+        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+        u[2*i+0] = q8[0];
+        u[2*i+1] = q8[4];
+    }
+
+    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
+
+#else
+
+#if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics
+    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
+    const uint16_t * a = (const uint16_t *)bq4_K->scales;
+    aux16[0] = a[0] & 0x0f0f;
+    aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+    const float dall = bq4_K->dm[0];
+    const float dmin = bq4_K->dm[1];
+
+    const float d8_1 = bq8_1[0].ds[0];
+    const float d8_2 = bq8_1[1].ds[1];
+
+    const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
+    const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
+    const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
+
+    const int * q4 = (const int *)bq4_K->qs + (iqs/2);
+    const int v1 = q4[0];
+    const int v2 = q4[4];
+
+    const int dot1 = dpct::dp4a(ui2, v2 & 0x0f0f0f0f, dpct::dp4a(ui1, v1 & 0x0f0f0f0f, 0));
+    const int dot2 = dpct::dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, dpct::dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
+    const int dot3 = dpct::dp4a(0x01010101, ui2, dpct::dp4a(0x01010101, ui1, 0));
+    const int dot4 = dpct::dp4a(0x01010101, ui4, dpct::dp4a(0x01010101, ui3, 0));
+
+    sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
+    sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
+
+    return dall * sumf_d - dmin * sumf_m;
+
+#else
+    bad_arch();
+#endif // __SYCL_ARCH__ >= VER_4VEC
+
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_q5_K_q8_1(const void *__restrict__ vbq,
+                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+#ifndef GGML_QKK_64
+    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+
+    int   vl[2];
+    int   vh[2];
+    int    u[2*QR5_K];
+    float d8[QR5_K];
+
+    const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
+    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+    const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));
+
+    vl[0] = ql[0];
+    vl[1] = ql[4];
+
+    vh[0] = qh[0] >> bq8_offset;
+    vh[1] = qh[4] >> bq8_offset;
+
+    const uint16_t * scales = (const uint16_t *)bq5_K->scales;
+    uint16_t aux[2];
+    const int j = bq8_offset/2;
+    if (j < 2) {
+        aux[0] = scales[j+0] & 0x3f3f;
+        aux[1] = scales[j+2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+    }
+    const uint8_t * sc = (const uint8_t *)aux;
+    const uint8_t * m  = sc + 2;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K; ++i) {
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        d8[i] = bq8i->ds[0];
+
+        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+        u[2*i+0] = q8[0];
+        u[2*i+1] = q8[4];
+    }
+
+    return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
+
+#else
+
+#if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics
+    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+
+    const int8_t * s = bq5_K->scales;
+
+    const float d = bq5_K->d;
+
+    const float d8_1 = bq8_1[0].ds[0];
+    const float d8_2 = bq8_1[1].ds[1];
+
+    const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
+    const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
+    const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
+
+    const int * ql = (const int *)bq5_K->qs + (iqs/2);
+    const int vl1 = ql[0];
+    const int vl2 = ql[4];
+
+    const int step = 4 * (iqs/2); // 0, 4, 8, 12
+    const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6
+    const int in = step%8; // 0, 4, 0, 4
+    const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
+
+    const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
+    const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
+    const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
+    const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
+
+    const float sumf_d = d8_1 * (dpct::dp4a(ui1, v1, 0) * s[0] + dpct::dp4a(ui2, v2, 0) * s[1])
+                       + d8_2 * (dpct::dp4a(ui3, v3, 0) * s[2] + dpct::dp4a(ui4, v4, 0) * s[3]);
+
+    return d * sumf_d;
+
+#else
+    bad_arch();
+#endif // __SYCL_ARCH__ >= VER_4VEC
+
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_q6_K_q8_1(const void *__restrict__ vbq,
+                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+    const block_q6_K * bq6_K = (const block_q6_K *) vbq;
+
+    const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
+    const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
+    const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
+
+    const int vl = get_int_from_uint8(bq6_K->ql, iqs);
+    const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;
+
+    const int8_t * scales = bq6_K->scales + scale_offset;
+
+    int    u[QR6_K];
+    float d8[QR6_K];
+
+#pragma unroll
+    for (int i = 0; i < QR6_K; ++i) {
+        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
+        d8[i] = bq8_1[bq8_offset + 2 * i].ds[0];
+    }
+
+    return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
+}
+
+
+static __dpct_inline__ float
+vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq,
+                     const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                     const uint64_t *iq2xxs_grid, const uint8_t *ksigns_iq2xs,
+                     const uint8_t *kmask_iq2xs) {
+#if QK_K == 256
+    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+
+#if QR2_XXS == 8
+    const int ib32 = iqs;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    uint32_t aux32 = q2[2] | (q2[3] << 16);
+    int sumi = 0;
+    for (int l = 0; l < 4; ++l) {
+        const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+        const uint8_t  signs = ksigns_iq2xs[aux32 & 127];
+        for (int j = 0; j < 8; ++j) {
+            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+        }
+        q8 += 8;
+        aux32 >>= 7;
+    }
+    const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.25f;
+    return d * sumi;
+#else
+    // iqs is 0...15
+    const int ib32 = iqs/2;
+    const int il = iqs%2;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
+    const uint8_t  * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * bq8_1[ib32].ds[0] * 0.25f;
+    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
+    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
+    int sumi1 = 0, sumi2 = 0;
+    for (int j = 0; j < 8; ++j) {
+        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
+        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
+    }
+    return d * (sumi1 + sumi2);
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq,
+                    const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                    const uint64_t *iq2xs_grid, const uint64_t *ksigns64) {
+#if DPCT_COMPATIBILITY_TEMP >=                                                 \
+    MIN_CC_DP4A // lowest compute capability for integer intrinsics
+#if QK_K == 256
+    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
+
+    const int ib32 = iqs;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    const uint8_t ls1 = bq2->scales[ib32] & 0xf;
+    const uint8_t ls2 = bq2->scales[ib32] >>  4;
+    int sumi1 = 0;
+    for (int l = 0; l < 2; ++l) {
+        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
+        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+            grid[0] ^ signs[0], signs[0], std::minus<>());
+        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+            grid[1] ^ signs[1], signs[1], std::minus<>());
+        sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
+        sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
+        q8 += 8;
+    }
+    int sumi2 = 0;
+    for (int l = 2; l < 4; ++l) {
+        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
+        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+            grid[0] ^ signs[0], signs[0], std::minus<>());
+        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+            grid[1] ^ signs[1], signs[1], std::minus<>());
+        sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
+        sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
+        q8 += 8;
+    }
+    const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
+    return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
+#else
+    assert(false);
+    return 0.f;
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+#if QK_K == 256
+    const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
+
+    const int ib32 = iqs;
+    const int8_t  * q8 = bq8_1[ib32].qs;
+    const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;
+    const uint8_t ls1 = bq2->scales[ib32] & 0xf;
+    const uint8_t ls2 = bq2->scales[ib32] >>  4;
+    int sumi1 = 0;
+    for (int l = 0; l < 2; ++l) {
+        const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
+        const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+            ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
+            std::equal_to<>());
+        const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+            ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
+            std::equal_to<>());
+        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+            grid[0] ^ signs0, signs0, std::minus<>());
+        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+            grid[1] ^ signs1, signs1, std::minus<>());
+        sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
+        sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
+        q8 += 8;
+    }
+    int sumi2 = 0;
+    for (int l = 2; l < 4; ++l) {
+        const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
+        const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+            ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
+            std::equal_to<>());
+        const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+            ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
+            std::equal_to<>());
+        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+            grid[0] ^ signs0, signs0, std::minus<>());
+        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+            grid[1] ^ signs1, signs1, std::minus<>());
+        sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
+        sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
+        q8 += 8;
+    }
+    const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
+    return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
+#else
+    assert(false);
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
+                     const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                     const uint32_t *iq3xxs_grid, const uint64_t *ksigns64) {
+#if DPCT_COMPATIBILITY_TEMP >=                                                 \
+    MIN_CC_DP4A // lowest compute capability for integer intrinsics
+#if QK_K == 256
+    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
+
+    const int ib32 = iqs;
+    const uint8_t  * q3 = bq2->qs + 8*ib32;
+    const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    uint32_t aux32 = gas[0] | (gas[1] << 16);
+    int sumi = 0;
+    for (int l = 0; l < 4; ++l) {
+        const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0];
+        const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1];
+        const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127));
+        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+            grid1[0] ^ signs[0], signs[0], std::minus<>());
+        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+            grid2[0] ^ signs[1], signs[1], std::minus<>());
+        sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
+        sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
+        q8 += 8;
+        aux32 >>= 7;
+    }
+    const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.5f;
+    return d * sumi;
+#else
+    assert(false);
+    return 0.f;
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                   const uint32_t *iq3s_grid) {
+#if QK_K == 256
+    const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
+
+    const int ib32 = iqs;
+    const uint8_t  * qs = bq2->qs + 8*ib32;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    int sumi = 0;
+    for (int l = 0; l < 4; ++l) {
+        const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
+        const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
+        uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+            ((bq2->signs[4 * ib32 + l] & 0xf) * 0x01010101) & 0x08040201,
+            0x08040201, std::equal_to<>());
+        uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+            ((bq2->signs[4 * ib32 + l] >> 4) * 0x01010101) & 0x08040201,
+            0x08040201, std::equal_to<>());
+        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+            grid1[0] ^ signs0, signs0, std::minus<>());
+        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+            grid2[0] ^ signs1, signs1, std::minus<>());
+        sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
+        sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
+        q8 += 8;
+    }
+    const float d =
+        (float)bq2->d *
+        (1 + 2 * ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) *
+        bq8_1[ib32].ds[0];
+    return d * sumi;
+#else
+    assert(false);
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                   const uint32_t *iq1s_grid_gpu) {
+#if QK_K == 256
+    const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
+
+    const int ib32 = iqs;
+    int sumi = 0;
+    const int * q8 = (const int *)bq8_1[ib32].qs;
+    for (int l = 0; l < 4; ++l) {
+        const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
+        int grid0 = grid[0] & 0x0f0f0f0f;
+        int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
+        sumi = dpct::dp4a(q8[2 * l + 1], grid1,
+                          dpct::dp4a(q8[2 * l + 0], grid0, sumi));
+    }
+
+    const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA;
+    const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);
+    const float d = d1q * bq8_1[ib32].ds[0];
+    const float m = d1q * bq8_1[ib32].ds[1];
+    return d * sumi + m * delta;
+#else
+    assert(false);
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq1_m_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+#if QK_K == 256
+    const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
+
+    const int ib32 = iqs;
+    int   sumi[2] = {0, 0};
+    float sumf[2] = {0.f, 0.f};
+
+    const int * q8 = (const int *)bq8_1[ib32].qs;
+    for (int l = 0; l < 4; ++l) {
+        const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8)));
+        int grid0 = grid[0] & 0x0f0f0f0f;
+        int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
+        sumi[l / 2] = dpct::dp4a(q8[2 * l + 1], grid1,
+                                 dpct::dp4a(q8[2 * l + 0], grid0, sumi[l / 2]));
+        const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA;
+        const int sumy = dpct::dp4a(q8[2 * l + 1], 0x01010101,
+                                    dpct::dp4a(q8[2 * l + 0], 0x01010101, 0));
+        sumf[l/2] += delta*sumy;
+    }
+
+    iq1m_scale_t scale;
+    const uint16_t * sc = (const uint16_t *)bq1->scales;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const float d = (float)scale.f16 * bq8_1[ib32].ds[0];
+    return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
+#else
+    assert(false);
+#endif
+}
+
+
+static __dpct_inline__ float
+vec_dot_iq4_nl_q8_1(const void *__restrict__ vbq,
+                    const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+    const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
+
+    const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
+    const int32_t  * q8 = (const int32_t  *)bq8_1->qs + iqs;
+
+    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
+
+    int v1, v2;
+    int sumi1 = 0, sumi2 = 0;
+    for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
+        const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
+        get_int_from_table_16(aux, values, v1, v2);
+        sumi1 = dpct::dp4a(v1, q8[l + 0], sumi1);
+        sumi2 = dpct::dp4a(v2, q8[l + 4], sumi2);
+    }
+
+    const float d = (float)bq->d * bq8_1->ds[0];
+    return d * (sumi1 + sumi2);
+}
+
+
+static __dpct_inline__ float
+vec_dot_iq4_xs_q8_1(const void *__restrict__ vbq,
+                    const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+#if QK_K == 256
+    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
+    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
+
+    // iqs is 0...7
+    const int ib32 = iqs;
+    const int32_t  * q8 = (const int *)bq8_1[ib32].qs;
+    const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
+    const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
+    const float d = (float)bq4->d * (ls - 32) * bq8_1[ib32].ds[0];
+    int v1, v2;
+    int sumi1 = 0, sumi2 = 0;
+    for (int j = 0; j < 4; ++j) {
+        get_int_from_table_16(q4[j], values, v1, v2);
+        sumi1 = dpct::dp4a(v1, q8[j + 0], sumi1);
+        sumi2 = dpct::dp4a(v2, q8[j + 4], sumi2);
+    }
+    return d * (sumi1 + sumi2);
+#else
+    assert(false);
+#endif
+}
+
+#endif // GGML_SYCL_VECDOTQ_HPP
index f389934ead3ed3b7c82dbe34503ceaeb6c568f46..101781ede4b4fd303c95acd72ed1ebdb9858ec73 100644 (file)
@@ -17,6 +17,8 @@
 #include <memory>
 #include <limits>
 #include <map>
+#include <memory>
+#include <mutex>
 
 #include "ggml.h"
 #include "ggml-backend-impl.h"
@@ -103,7 +105,41 @@ struct vk_matmul_pipeline_struct {
 
 typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
 
-struct vk_device {
+struct vk_device_struct;
+typedef std::shared_ptr<vk_device_struct> vk_device;
+typedef std::weak_ptr<vk_device_struct> vk_device_ref;
+
+struct vk_buffer_struct;
+typedef std::shared_ptr<vk_buffer_struct> vk_buffer;
+typedef std::weak_ptr<vk_buffer_struct> vk_buffer_ref;
+
+struct ggml_backend_vk_buffer_type_context {
+    std::string name;
+    vk_device device;
+};
+
+GGML_CALL static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft);
+GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
+GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft);
+GGML_CALL static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft);
+GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor);
+static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_vk_buffer_type_name,
+    /* .alloc_buffer     = */ ggml_backend_vk_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_vk_buffer_type_get_alignment,
+    /* .get_max_size     = */ ggml_backend_vk_buffer_type_get_max_size,
+    /* .get_alloc_size   = */ ggml_backend_vk_buffer_type_get_alloc_size,
+    /* .is_host          = */ NULL,
+};
+
+#ifdef GGML_VULKAN_MEMORY_DEBUG
+class vk_memory_logger;
+#endif
+static void ggml_vk_destroy_buffer(vk_buffer& buf);
+
+struct vk_device_struct {
+    std::mutex mutex;
+
     vk::PhysicalDevice physical_device;
     vk::PhysicalDeviceProperties properties;
     std::string name;
@@ -118,7 +154,6 @@ struct vk_device {
     uint32_t subgroup_size;
     bool uma;
 
-    bool initialized;
     size_t idx;
 
     vk_matmul_pipeline pipeline_matmul_f32;
@@ -165,8 +200,24 @@ struct vk_device {
 
     std::vector<vk_pipeline_ref> pipelines;
 
-    ~vk_device() {
+    std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
+
+    vk::Fence fence;
+    vk_buffer sync_staging;
+
+    ggml_backend_buffer_type buffer_type;
+
+#ifdef GGML_VULKAN_MEMORY_DEBUG
+    std::unique_ptr<vk_memory_logger> memory_logger;
+#endif
+
+    ~vk_device_struct() {
         VK_LOG_DEBUG("destroy device " << name);
+
+        device.destroyFence(fence);
+
+        ggml_vk_destroy_buffer(sync_staging);
+
         device.destroyCommandPool(compute_queue.pool);
         if (!single_queue) {
             device.destroyCommandPool(transfer_queue.pool);
@@ -193,9 +244,7 @@ struct vk_buffer_struct {
     void * ptr;
     size_t size = 0;
 
-    ggml_backend_vk_context * ctx;
-
-    std::shared_ptr<vk_device> device;
+    vk_device device;
 
     ~vk_buffer_struct() {
         if (size == 0) {
@@ -208,9 +257,6 @@ struct vk_buffer_struct {
     }
 };
 
-typedef std::shared_ptr<vk_buffer_struct> vk_buffer;
-typedef std::weak_ptr<vk_buffer_struct> vk_buffer_ref;
-
 struct vk_subbuffer {
     vk_buffer buffer;
     uint64_t offset;
@@ -359,8 +405,6 @@ struct ggml_vk_garbage_collector {
 };
 
 #if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG)
-#include <mutex>
-
 #define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl
 
 static std::string format_size(size_t size) {
@@ -404,31 +448,21 @@ private:
 struct ggml_backend_vk_context {
     std::string name;
 
-    std::shared_ptr<vk_device> device;
+    vk_device device;
 
     size_t semaphore_idx, event_idx;
     ggml_vk_garbage_collector gc;
-    std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
     size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
     vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
     vk::Fence fence;
     vk_buffer staging;
     size_t staging_size;
     size_t staging_offset;
-    vk_buffer sync_staging;
 
     vk_buffer buffer_pool[MAX_VK_BUFFERS];
 
     vk_context * compute_ctx;
     vk_context * transfer_ctx;
-
-    bool initialized;
-
-    size_t idx;
-
-#ifdef GGML_VULKAN_MEMORY_DEBUG
-    vk_memory_logger memory_logger;
-#endif
 };
 
 #ifdef GGML_VULKAN_MEMORY_DEBUG
@@ -440,7 +474,7 @@ void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {
     allocations[buf->buffer] = size;
     total_device += device ? size : 0;
     total_host += device ? 0 : size;
-    VK_LOG_MEMORY("VULKAN" << buf->ctx->idx << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
+    VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
 }
 
 void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
@@ -456,10 +490,10 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
     total_device -= device ? it->second : 0;
     total_host -= device ? 0 : it->second;
     if (it != allocations.end()) {
-        VK_LOG_MEMORY("VULKAN" << buf->ctx->idx << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
+        VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
         allocations.erase(it);
     } else {
-        VK_LOG_MEMORY("ERROR VULKAN" << buf->ctx->idx << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer);
+        VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer);
     }
 }
 #endif // GGML_VULKAN_MEMORY_DEBUG
@@ -468,49 +502,32 @@ struct vk_instance_t {
     vk::Instance instance;
 
     std::vector<size_t> device_indices;
-
-    ggml_backend_t backends[GGML_VK_MAX_DEVICES];
-    ggml_backend_vk_context contexts[GGML_VK_MAX_DEVICES];
-    ggml_backend_buffer_type buffer_types[GGML_VK_MAX_DEVICES];
-    bool initialized[GGML_VK_MAX_DEVICES];
+    vk_device devices[GGML_VK_MAX_DEVICES];
 };
 
-static std::shared_ptr<vk_device> ggml_vk_get_device(size_t idx) {
-    VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
-    static std::weak_ptr<vk_device> devices[GGML_VK_MAX_DEVICES];
-
-    if (devices[idx].expired()) {
-        VK_LOG_DEBUG("Initializing new vk_device");
-        std::shared_ptr<vk_device> device = std::make_shared<vk_device>();
-        device->initialized = false;
-        devices[idx] = device;
-        return device;
-    }
-
-    return devices[idx].lock();
-}
+static bool vk_instance_initialized = false;
+static vk_instance_t vk_instance;
 
 #ifdef GGML_VULKAN_CHECK_RESULTS
 static size_t vk_skip_checks;
 static size_t vk_output_tensor;
 
 static void ggml_vk_print_tensor(ggml_backend * ctx, const ggml_tensor * tensor, const char * name);
-static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor);
-static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor);
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor);
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor);
 #endif
 
 typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 
-static bool vk_instance_initialized = false;
-static vk_instance_t vk_instance;
-
 GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend);
 
-static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
-    VK_LOG_DEBUG("ggml_vk_create_pipeline(" << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
+static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
+    VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
     GGML_ASSERT(parameter_count > 0);
     GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
 
+    std::lock_guard<std::mutex> guard(device->mutex);
+
     pipeline = std::make_shared<vk_pipeline_struct>();
     pipeline->name = name;
     pipeline->parameter_count = parameter_count;
@@ -519,7 +536,7 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
     pipeline->align = align;
 
     vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
-    pipeline->shader_module = ctx->device->device.createShaderModule(shader_module_create_info);
+    pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
 
     std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
     std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
@@ -540,17 +557,17 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
         {},
         dsl_binding);
     descriptor_set_layout_create_info.setPNext(&dslbfci);
-    pipeline->dsl = ctx->device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
+    pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
 
     // Check if device supports multiple descriptors per pool
-    if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN) {
+    if (device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN) {
         const uint32_t alloc_count = 2;
 
         // Try allocating multiple sets from one pool
         // This fails on AMD for some reason, so add a fall back to allocating one pool per set
         vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
         vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, alloc_count, descriptor_pool_size);
-        vk::DescriptorPool pool = ctx->device->device.createDescriptorPool(descriptor_pool_create_info);
+        vk::DescriptorPool pool = device->device.createDescriptorPool(descriptor_pool_create_info);
 
         std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
         for (uint32_t i = 0; i < alloc_count; i++) {
@@ -558,24 +575,24 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
         }
         try {
             vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pool, alloc_count, layouts.data());
-            std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
+            std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
         } catch(vk::OutOfPoolMemoryError const&) {
-            ctx->device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE;
+            device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE;
         }
 
-        ctx->device->device.destroyDescriptorPool(pool);
+        device->device.destroyDescriptorPool(pool);
     }
 
-    if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
+    if (device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
         vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
         vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 128, descriptor_pool_size);
-        pipeline->descriptor_pools.push_back(ctx->device->device.createDescriptorPool(descriptor_pool_create_info));
+        pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
     }
 
     pipeline->descriptor_set_idx = 0;
 
     vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr);
-    pipeline->layout = ctx->device->device.createPipelineLayout(pipeline_layout_create_info);
+    pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info);
 
     std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size());
 
@@ -602,9 +619,9 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
         vk::PipelineCreateFlags(),
         pipeline_shader_create_info,
         pipeline->layout);
-    pipeline->pipeline = ctx->device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
+    pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
 
-    ctx->device->pipelines.push_back(pipeline);
+    device->pipelines.push_back(pipeline);
 }
 
 static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
@@ -625,14 +642,16 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline)
     device.destroyPipeline(pipeline->pipeline);
 }
 
-static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx, vk_pipeline& pipeline, uint32_t n) {
+static void ggml_pipeline_allocate_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) {
     VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")");
     if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) {
         // Enough descriptors are available
         return;
     }
 
-    if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
+    std::lock_guard<std::mutex> guard(device->mutex);
+
+    if (device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
         const uint32_t alloc_count = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size();
 
         std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
@@ -640,16 +659,16 @@ static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx
             layouts[i] = pipeline->dsl;
         }
         vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[0], alloc_count, layouts.data());
-        std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
+        std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
         pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end());
     } else {
         for (uint32_t i = pipeline->descriptor_sets.size(); i < pipeline->descriptor_set_idx + n; i++) {
             vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
             vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 1, descriptor_pool_size);
-            pipeline->descriptor_pools.push_back(ctx->device->device.createDescriptorPool(descriptor_pool_create_info));
+            pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
 
             vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[i], 1, &pipeline->dsl);
-            std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
+            std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
             pipeline->descriptor_sets.push_back(sets[0]);
         }
     }
@@ -660,8 +679,10 @@ static void ggml_pipeline_cleanup(vk_pipeline& pipeline) {
     pipeline->descriptor_set_idx = 0;
 }
 
-static vk::CommandBuffer ggml_vk_create_cmd_buffer(ggml_backend_vk_context * ctx, vk_queue& q) {
+static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_queue& q) {
     VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
+    std::lock_guard<std::mutex> guard(device->mutex);
+
     if (q.cmd_buffers.size() > q.cmd_buffer_idx) {
         // Reuse command buffer
         return q.cmd_buffers[q.cmd_buffer_idx++];
@@ -671,7 +692,7 @@ static vk::CommandBuffer ggml_vk_create_cmd_buffer(ggml_backend_vk_context * ctx
         q.pool,
         vk::CommandBufferLevel::ePrimary,
         1);
-    const std::vector<vk::CommandBuffer> cmd_buffers = ctx->device->device.allocateCommandBuffers(command_buffer_alloc_info);
+    const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
     auto buf = cmd_buffers.front();
 
     q.cmd_buffers.push_back(buf);
@@ -680,10 +701,10 @@ static vk::CommandBuffer ggml_vk_create_cmd_buffer(ggml_backend_vk_context * ctx
     return buf;
 }
 
-static vk_submission ggml_vk_create_submission(ggml_backend_vk_context * ctx, vk_queue& q, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
+static vk_submission ggml_vk_create_submission(vk_device& device, vk_queue& q, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
     VK_LOG_DEBUG("ggml_vk_create_submission()");
     vk_submission s;
-    s.buffer = ggml_vk_create_cmd_buffer(ctx, q);
+    s.buffer = ggml_vk_create_cmd_buffer(device, q);
     s.wait_semaphores = std::move(wait_semaphores);
     s.signal_semaphores = std::move(signal_semaphores);
     return s;
@@ -809,16 +830,18 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyPrope
     abort();
 }
 
-static void ggml_vk_create_queue(ggml_backend_vk_context * ctx, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags) {
+static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags) {
     VK_LOG_DEBUG("ggml_vk_create_queue()");
+    std::lock_guard<std::mutex> guard(device->mutex);
+
     q.queue_family_index = queue_family_index;
 
     vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index);
-    q.pool = ctx->device->device.createCommandPool(command_pool_create_info_compute);
+    q.pool = device->device.createCommandPool(command_pool_create_info_compute);
 
     q.cmd_buffer_idx = 0;
 
-    q.queue = ctx->device->device.getQueue(queue_family_index, queue_index);
+    q.queue = device->device.getQueue(queue_family_index, queue_index);
 
     q.stage_flags = stage_flags;
 }
@@ -833,6 +856,15 @@ static vk_context * ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_que
     return result;
 }
 
+static vk_context * ggml_vk_create_temporary_context(vk_queue& q) {
+    VK_LOG_DEBUG("ggml_vk_create_temporary_context()");
+    vk_context * result = new vk_context;
+    memset((void *) result, 0, sizeof(vk_context));
+    result->idx = 0;
+    result->q = &q;
+    return result;
+}
+
 static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) {
     VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
     vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 };
@@ -862,11 +894,12 @@ static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {
     return ctx->gc.events[ctx->event_idx++];
 }
 
-static void ggml_vk_queue_cleanup(ggml_backend_vk_context * ctx, vk_queue& q) {
+static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) {
     VK_LOG_DEBUG("ggml_vk_queue_cleanup()");
-    // Requires command buffers to be done
+    std::lock_guard<std::mutex> guard(device->mutex);
 
-    ctx->device->device.resetCommandPool(q.pool);
+    // Requires command buffers to be done
+    device->device.resetCommandPool(q.pool);
     q.cmd_buffer_idx = 0;
 }
 
@@ -882,8 +915,10 @@ static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_pr
     return UINT32_MAX;
 }
 
-static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
-    VK_LOG_DEBUG("ggml_vk_create_buffer(device " << ctx->idx << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")");
+static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
+    VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")");
+    std::lock_guard<std::mutex> guard(device->mutex);
+
     vk_buffer buf = std::make_shared<vk_buffer_struct>();
 
     if (size == 0) {
@@ -901,11 +936,11 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
         nullptr,
     };
 
-    buf->buffer = ctx->device->device.createBuffer(buffer_create_info);
+    buf->buffer = device->device.createBuffer(buffer_create_info);
 
-    vk::MemoryRequirements mem_req = ctx->device->device.getBufferMemoryRequirements(buf->buffer);
+    vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
 
-    vk::PhysicalDeviceMemoryProperties mem_props = ctx->device->physical_device.getMemoryProperties();
+    vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
 
     uint32_t memory_type_index = UINT32_MAX;
 
@@ -918,41 +953,39 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
     }
 
     if (memory_type_index == UINT32_MAX) {
-        ctx->device->device.destroyBuffer(buf->buffer);
+        device->device.destroyBuffer(buf->buffer);
         buf->size = 0;
         throw vk::OutOfDeviceMemoryError("No suitable memory type found");
     }
 
     try {
-        buf->device_memory = ctx->device->device.allocateMemory({ mem_req.size, memory_type_index });
+        buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
     } catch (const vk::SystemError& e) {
         // Out of Host/Device memory, clean up buffer
-        ctx->device->device.destroyBuffer(buf->buffer);
+        device->device.destroyBuffer(buf->buffer);
         buf->size = 0;
         throw e;
     }
     buf->ptr = nullptr;
 
     if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
-        buf->ptr = ctx->device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
+        buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
     }
 
-    ctx->device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
-
-    buf->ctx = ctx;
+    device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
 
-    buf->device = ctx->device;
+    buf->device = device;
 
 #ifdef GGML_VULKAN_MEMORY_DEBUG
-    ctx->memory_logger.log_allocation(buf, size);
+    device->memory_logger->log_allocation(buf, size);
 #endif
 
     return buf;
 }
 
-static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
+static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
     try {
-        return ggml_vk_create_buffer(ctx, size, req_flags, fallback_flags);
+        return ggml_vk_create_buffer(device, size, req_flags, fallback_flags);
     } catch (const vk::SystemError& e) {
         std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
         std::cerr << "ggml_vulkan: " << e.what() << std::endl;
@@ -960,14 +993,14 @@ static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size
     }
 }
 
-static vk_buffer ggml_vk_create_buffer_device(ggml_backend_vk_context * ctx, size_t size) {
+static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
     vk_buffer buf;
     try {
-        if (ctx->device->uma) {
+        if (device->uma) {
             // Fall back to host memory type
-            buf = ggml_vk_create_buffer(ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+            buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
         } else {
-            buf = ggml_vk_create_buffer(ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal);
+            buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal);
         }
     } catch (const vk::SystemError& e) {
         std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
@@ -984,7 +1017,9 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf) {
     }
 
 #ifdef GGML_VULKAN_MEMORY_DEBUG
-    buf->ctx->memory_logger.log_deallocation(buf);
+    if (buf->device != nullptr) {
+        buf->device->memory_logger->log_deallocation(buf);
+    }
 #endif
 
     buf.reset();
@@ -1043,10 +1078,8 @@ static bool ggml_vk_build_shader(ggml_type type) {
     }
 }
 
-static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
-    VK_LOG_DEBUG("ggml_vk_load_shaders(" << ctx->name << ")");
-
-    const std::shared_ptr<vk_device> device = ctx->device;
+static void ggml_vk_load_shaders(vk_device& device) {
+    VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
 
     // mulmat
     std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
@@ -1065,534 +1098,707 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
     uint32_t m_align =  64;
     uint32_t s_align =  32;
 
-    ctx->device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_matmul_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_matmul_f16 = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
-
-    ctx->device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_matmul_id_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_matmul_id_f16 = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_matmul_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_matmul_f16 = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
+
+    device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_matmul_id_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_matmul_id_f16 = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
+    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
 
     if (device->fp16) {
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_f16->l, "matmul_f32_f16_l", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_f16->m, "matmul_f32_f16_m", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_f16->s, "matmul_f32_f16_s", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_f16->a_l, "matmul_f32_f16_aligned_l", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_f16->a_m, "matmul_f32_f16_aligned_m", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_f16->a_s, "matmul_f32_f16_aligned_s", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_1_f32_l", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_1_f32_m", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_1_f32_s", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_1_f32_aligned_l", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_1_f32_aligned_m", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_1_f32_aligned_s", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->l, "matmul_f32_f16_l", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->m, "matmul_f32_f16_m", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->s, "matmul_f32_f16_s", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_l, "matmul_f32_f16_aligned_l", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_m, "matmul_f32_f16_aligned_m", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_s, "matmul_f32_f16_aligned_s", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_1_f32_l", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_1_f32_m", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_1_f32_s", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_1_f32_aligned_l", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_1_f32_aligned_m", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_1_f32_aligned_s", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
     } else {
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_f16->l, "matmul_f32_f16_l", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_f16->m, "matmul_f32_f16_m", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_f16->s, "matmul_f32_f16_s", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_f16->a_l, "matmul_f32_f16_aligned_l", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_f16->a_m, "matmul_f32_f16_aligned_m", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32_f16->a_s, "matmul_f32_f16_aligned_s", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_1_f32_l", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_1_f32_m", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_1_f32_s", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_1_f32_aligned_l", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_1_f32_aligned_m", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_1_f32_aligned_s", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->l, "matmul_f32_f16_l", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->m, "matmul_f32_f16_m", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->s, "matmul_f32_f16_s", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_l, "matmul_f32_f16_aligned_l", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_m, "matmul_f32_f16_aligned_m", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_s, "matmul_f32_f16_aligned_s", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_1_f32_l", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_1_f32_m", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_1_f32_s", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_1_f32_aligned_l", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_1_f32_aligned_m", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_1_f32_aligned_s", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
     }
 
     // mul mat vec
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32",  mul_mat_vec_f32_f32_f32_len,  mul_mat_vec_f32_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32",  mul_mat_vec_f16_f32_f32_len,  mul_mat_vec_f16_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32",  mul_mat_vec_f32_f16_f32_len,  mul_mat_vec_f32_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32",  mul_mat_vec_f16_f16_f32_len,  mul_mat_vec_f16_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32",  mul_mat_vec_id_f32_f32_len,  mul_mat_vec_id_f32_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32",  mul_mat_vec_id_f16_f32_len,  mul_mat_vec_id_f16_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32",  mul_mat_vec_f32_f32_f32_len,  mul_mat_vec_f32_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32",  mul_mat_vec_f16_f32_f32_len,  mul_mat_vec_f16_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32",  mul_mat_vec_f32_f16_f32_len,  mul_mat_vec_f32_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32",  mul_mat_vec_f16_f16_f32_len,  mul_mat_vec_f16_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32",  mul_mat_vec_id_f32_f32_len,  mul_mat_vec_id_f32_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32",  mul_mat_vec_id_f16_f32_len,  mul_mat_vec_id_f16_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     // dequant shaders
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16",   dequant_f32_len,  dequant_f32_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16",   dequant_f32_len,  dequant_f32_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
 
     // get_rows
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32",  get_rows_f32_len,  get_rows_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16",  get_rows_f16_len,  get_rows_f16_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32",  get_rows_f32_len,  get_rows_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16",  get_rows_f16_len,  get_rows_f16_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32",  get_rows_f32_f32_len,  get_rows_f32_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32",  get_rows_f16_f32_len,  get_rows_f16_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32",  get_rows_f32_f32_len,  get_rows_f32_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32",  get_rows_f16_f32_len,  get_rows_f16_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 }
 
+static vk_device ggml_vk_get_device(size_t idx) {
+    VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
+
+    if (vk_instance.devices[idx] == nullptr) {
+        VK_LOG_DEBUG("Initializing new vk_device");
+        vk_device device = std::make_shared<vk_device_struct>();
+        vk_instance.devices[idx] = device;
+
+#ifdef GGML_VULKAN_MEMORY_DEBUG
+        device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());
+#endif
+
+        size_t dev_num = vk_instance.device_indices[idx];
+
+        std::vector<vk::PhysicalDevice> physical_devices = vk_instance.instance.enumeratePhysicalDevices();
+
+        if (dev_num >= physical_devices.size()) {
+            std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
+            throw std::runtime_error("Device not found");
+        }
+
+        device->physical_device = physical_devices[dev_num];
+        const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
+
+        bool maintenance4_support = false;
+
+        // Check if maintenance4 is supported
+        for (const auto& properties : ext_props) {
+            if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
+                maintenance4_support = true;
+            }
+        }
+
+        vk::PhysicalDeviceProperties2 props2;
+        vk::PhysicalDeviceMaintenance3Properties props3;
+        vk::PhysicalDeviceMaintenance4Properties props4;
+        vk::PhysicalDeviceSubgroupProperties subgroup_props;
+        props2.pNext = &props3;
+        props3.pNext = &subgroup_props;
+        if (maintenance4_support) {
+            subgroup_props.pNext = &props4;
+        }
+        device->physical_device.getProperties2(&props2);
+        device->properties = props2.properties;
+
+        const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
+
+        if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
+            device->max_memory_allocation_size = std::stoi(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
+        } else if (maintenance4_support) {
+            device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
+        } else {
+            device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
+        }
+
+        device->vendor_id = device->properties.vendorID;
+        device->subgroup_size = subgroup_props.subgroupSize;
+        device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+
+        bool fp16_storage = false;
+        bool fp16_compute = false;
+
+        for (const auto& properties : ext_props) {
+            if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
+                fp16_storage = true;
+            } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
+                fp16_compute = true;
+            }
+        }
+
+        const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
+        const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
+
+        device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
+
+        std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
+
+        // Try to find a non-graphics compute queue and transfer-focused queues
+        const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
+        const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
+
+        const float priorities[] = { 1.0f, 1.0f };
+        device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
+
+        std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
+        if (compute_queue_family_index != transfer_queue_family_index) {
+            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
+            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
+        } else if(!device->single_queue) {
+            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
+        } else {
+            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
+        }
+        vk::DeviceCreateInfo device_create_info;
+        std::vector<const char *> device_extensions;
+        vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
+
+        VkPhysicalDeviceFeatures2 device_features2;
+        device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
+        device_features2.pNext = nullptr;
+        device_features2.features = (VkPhysicalDeviceFeatures)device_features;
+
+        VkPhysicalDeviceVulkan11Features vk11_features;
+        vk11_features.pNext = nullptr;
+        vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
+        device_features2.pNext = &vk11_features;
+
+        VkPhysicalDeviceVulkan12Features vk12_features;
+        vk12_features.pNext = nullptr;
+        vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
+        vk11_features.pNext = &vk12_features;
+
+        vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
+
+        device->fp16 = device->fp16 && vk12_features.shaderFloat16;
+
+        if (!vk11_features.storageBuffer16BitAccess) {
+            std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
+            throw std::runtime_error("Unsupported device");
+        }
+
+        device_extensions.push_back("VK_KHR_16bit_storage");
+
+#ifdef GGML_VULKAN_VALIDATE
+        device_extensions.push_back("VK_KHR_shader_non_semantic_info");
+#endif
+
+        if (device->fp16) {
+            device_extensions.push_back("VK_KHR_shader_float16_int8");
+        }
+        device->name = device->properties.deviceName.data();
+
+        device_create_info = {
+            vk::DeviceCreateFlags(),
+            device_queue_create_infos,
+            {},
+            device_extensions
+        };
+        device_create_info.setPNext(&device_features2);
+        device->device = device->physical_device.createDevice(device_create_info);
+
+        device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN;
+
+        // Queues
+        ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer });
+
+        // Shaders
+        ggml_vk_load_shaders(device);
+
+        if (!device->single_queue) {
+            const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
+            ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer });
+        } else {
+            // TODO: Use pointer or reference to avoid copy
+            device->transfer_queue = device->compute_queue;
+        }
+
+        device->buffer_type = {
+            /* .iface    = */ ggml_backend_vk_buffer_type_interface,
+            /* .context  = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
+        };
+
+        device->fence = device->device.createFence({});
+
+        device->idx = idx;
+
+        return device;
+    }
+
+    return vk_instance.devices[idx];
+}
+
+
 static void ggml_vk_print_gpu_info(size_t idx) {
     GGML_ASSERT(idx < vk_instance.device_indices.size());
     size_t dev_num = vk_instance.device_indices[idx];
     VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")");
-    GGML_ASSERT(vk_instance.initialized);
+    GGML_ASSERT(vk_instance_initialized);
 
     std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
 
@@ -1670,6 +1876,8 @@ void ggml_vk_instance_init() {
     }
     VK_LOG_DEBUG("ggml_vk_instance_init()");
 
+    vk_instance_initialized = true;
+
     vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
 
     const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
@@ -1715,8 +1923,6 @@ void ggml_vk_instance_init() {
     }
     vk_instance.instance = vk::createInstance(instance_create_info);
 
-    memset(vk_instance.initialized, 0, sizeof(bool) * GGML_VK_MAX_DEVICES);
-
     size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
 
     // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@@ -1745,31 +1951,37 @@ void ggml_vk_instance_init() {
 
         // Default to using all dedicated GPUs
         for (size_t i = 0; i < devices.size(); i++) {
-            vk::PhysicalDeviceProperties props = devices[i].getProperties();
-
-            if (props.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) {
+            vk::PhysicalDeviceProperties2 new_props;
+            vk::PhysicalDeviceDriverProperties new_driver;
+            vk::PhysicalDeviceIDProperties new_id;
+            new_props.pNext = &new_driver;
+            new_driver.pNext = &new_id;
+            devices[i].getProperties2(&new_props);
+
+            if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) {
                 // Check if there are two physical devices corresponding to the same GPU
                 auto old_device = std::find_if(
                     vk_instance.device_indices.begin(),
                     vk_instance.device_indices.end(),
-                    [&devices, &props](const size_t k){ return devices[k].getProperties().deviceID == props.deviceID; }
+                    [&devices, &new_id](const size_t k){
+                        vk::PhysicalDeviceProperties2 old_props;
+                        vk::PhysicalDeviceIDProperties old_id;
+                        old_props.pNext = &old_id;
+                        devices[k].getProperties2(&old_props);
+                        return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
+                    }
                 );
                 if (old_device == vk_instance.device_indices.end()) {
                     vk_instance.device_indices.push_back(i);
                 } else {
                     // There can be two physical devices corresponding to the same GPU if there are 2 different drivers
                     // This can cause error when splitting layers aross the devices, need to keep only 1
-                    VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same device id");
+                    VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID");
 
-                    vk::PhysicalDeviceProperties2 old_prop;
+                    vk::PhysicalDeviceProperties2 old_props;
                     vk::PhysicalDeviceDriverProperties old_driver;
-                    old_prop.pNext = &old_driver;
-                    devices[*old_device].getProperties2(&old_prop);
-
-                    vk::PhysicalDeviceProperties2 new_prop;
-                    vk::PhysicalDeviceDriverProperties new_driver;
-                    new_prop.pNext = &new_driver;
-                    devices[i].getProperties2(&new_prop);
+                    old_props.pNext = &old_driver;
+                    devices[*old_device].getProperties2(&old_props);
 
                     std::map<vk::DriverId, int> driver_priorities {};
                     int old_priority = std::numeric_limits<int>::max();
@@ -1777,7 +1989,7 @@ void ggml_vk_instance_init() {
 
                     // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id
                     // Smaller number -> higher priority
-                    switch (old_prop.properties.vendorID) {
+                    switch (old_props.properties.vendorID) {
                         case VK_VENDOR_ID_AMD:
                             driver_priorities[vk::DriverId::eMesaRadv] = 1;
                             driver_priorities[vk::DriverId::eAmdOpenSource] = 2;
@@ -1827,177 +2039,32 @@ void ggml_vk_instance_init() {
     for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
         ggml_vk_print_gpu_info(i);
     }
-
-    vk_instance_initialized = true;
 }
 
 static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
     GGML_ASSERT(idx < vk_instance.device_indices.size());
-    size_t dev_num = vk_instance.device_indices[idx];
-    VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << dev_num << ")");
+    VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")");
     ggml_vk_instance_init();
 
-    std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
-
-    if (dev_num >= devices.size()) {
-        std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
-        throw std::runtime_error("Device not found");
-    }
+    ctx->name = GGML_VK_NAME + std::to_string(idx);
 
     ctx->device = ggml_vk_get_device(idx);
-    if (!ctx->device->initialized) {
-        ctx->device->physical_device = devices[dev_num];
-        const std::vector<vk::ExtensionProperties> ext_props = ctx->device->physical_device.enumerateDeviceExtensionProperties();
-
-        bool maintenance4_support = false;
-
-        // Check if maintenance4 is supported
-        for (const auto& properties : ext_props) {
-            if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
-                maintenance4_support = true;
-            }
-        }
 
-        vk::PhysicalDeviceProperties2 props2;
-        vk::PhysicalDeviceMaintenance3Properties props3;
-        vk::PhysicalDeviceMaintenance4Properties props4;
-        vk::PhysicalDeviceSubgroupProperties subgroup_props;
-        props2.pNext = &props3;
-        props3.pNext = &subgroup_props;
-        if (maintenance4_support) {
-            subgroup_props.pNext = &props4;
-        }
-        ctx->device->physical_device.getProperties2(&props2);
-        ctx->device->properties = props2.properties;
-
-        const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
-
-        if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
-            ctx->device->max_memory_allocation_size = std::stoi(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
-        } else if (maintenance4_support) {
-            ctx->device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
-        } else {
-            ctx->device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
-        }
-
-        ctx->device->vendor_id = ctx->device->properties.vendorID;
-        ctx->device->subgroup_size = subgroup_props.subgroupSize;
-        ctx->device->uma = ctx->device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
-
-        bool fp16_storage = false;
-        bool fp16_compute = false;
-
-        for (const auto& properties : ext_props) {
-            if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
-                fp16_storage = true;
-            } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
-                fp16_compute = true;
-            }
-        }
-
-        const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
-        const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
-
-        ctx->device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
-
-        std::vector<vk::QueueFamilyProperties> queue_family_props = ctx->device->physical_device.getQueueFamilyProperties();
-
-        // Try to find a non-graphics compute queue and transfer-focused queues
-        const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
-        const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
-
-        const float priorities[] = { 1.0f, 1.0f };
-        ctx->device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
-
-        std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
-        if (compute_queue_family_index != transfer_queue_family_index) {
-            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
-            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
-        } else if(!ctx->device->single_queue) {
-            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
-        } else {
-            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
-        }
-        vk::DeviceCreateInfo device_create_info;
-        std::vector<const char *> device_extensions;
-        vk::PhysicalDeviceFeatures device_features = ctx->device->physical_device.getFeatures();
-
-        VkPhysicalDeviceFeatures2 device_features2;
-        device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
-        device_features2.pNext = nullptr;
-        device_features2.features = (VkPhysicalDeviceFeatures)device_features;
-
-        VkPhysicalDeviceVulkan11Features vk11_features;
-        vk11_features.pNext = nullptr;
-        vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
-        device_features2.pNext = &vk11_features;
-
-        VkPhysicalDeviceVulkan12Features vk12_features;
-        vk12_features.pNext = nullptr;
-        vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
-        vk11_features.pNext = &vk12_features;
-
-        vkGetPhysicalDeviceFeatures2(ctx->device->physical_device, &device_features2);
-
-        ctx->device->fp16 = ctx->device->fp16 && vk12_features.shaderFloat16;
-
-        if (!vk11_features.storageBuffer16BitAccess) {
-            std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
-            throw std::runtime_error("Unsupported device");
-        }
-
-        device_extensions.push_back("VK_KHR_16bit_storage");
-
-#ifdef GGML_VULKAN_VALIDATE
-        device_extensions.push_back("VK_KHR_shader_non_semantic_info");
-#endif
-
-        if (ctx->device->fp16) {
-            device_extensions.push_back("VK_KHR_shader_float16_int8");
-        }
-        ctx->device->name = ctx->device->properties.deviceName.data();
-
-        device_create_info = {
-            vk::DeviceCreateFlags(),
-            device_queue_create_infos,
-            {},
-            device_extensions
-        };
-        device_create_info.setPNext(&device_features2);
-        ctx->device->device = ctx->device->physical_device.createDevice(device_create_info);
-
-        ctx->device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN;
-
-        // Queues
-        ggml_vk_create_queue(ctx, ctx->device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer });
-
-        // Shaders
-        ggml_vk_load_shaders(ctx);
-
-        if (!ctx->device->single_queue) {
-            const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
-            ggml_vk_create_queue(ctx, ctx->device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer });
-        } else {
-            // TODO: Use pointer or reference to avoid copy
-            ctx->device->transfer_queue = ctx->device->compute_queue;
-        }
+    ctx->semaphore_idx = 0;
+    ctx->event_idx = 0;
 
-        ctx->device->idx = dev_num;
-        ctx->device->initialized = true;
-    } else if (ctx->device->idx != dev_num) {
-        std::cerr << "ggml_vulkan: Device " << ctx->device->name << " already initialized with index " << ctx->device->idx << ", but trying to reinitialize with index " << dev_num << std::endl;
-        throw std::runtime_error("Device already initialized");
-    }
+    ctx->prealloc_size_x = 0;
+    ctx->prealloc_size_y = 0;
+    ctx->prealloc_size_split_k = 0;
 
     ctx->fence = ctx->device->device.createFence({});
 
+    ctx->staging_size = 0;
+    ctx->staging_offset = 0;
+
     ctx->compute_ctx = nullptr;
     ctx->transfer_ctx = nullptr;
 
-    ctx->initialized = true;
-
-    ctx->idx = idx;
-
 #ifdef GGML_VULKAN_CHECK_RESULTS
     const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
     vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks));
@@ -2178,7 +2245,7 @@ static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size)
         ggml_vk_destroy_buffer(b);
     }
 
-    return ggml_vk_create_buffer_device(ctx, size);
+    return ggml_vk_create_buffer_device(ctx->device, size);
 }
 
 static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) {
@@ -2212,37 +2279,37 @@ static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_
     return buf;
 }
 
-static void * ggml_vk_host_malloc(ggml_backend_vk_context * ctx, size_t size) {
+static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
     VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
-    vk_buffer buf = ggml_vk_create_buffer(ctx, size,
+    vk_buffer buf = ggml_vk_create_buffer(device, size,
         vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
         vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
 
     if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
         fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
             size/1024.0/1024.0);
-        ctx->device->device.freeMemory(buf->device_memory);
-        ctx->device->device.destroyBuffer(buf->buffer);
+        device->device.freeMemory(buf->device_memory);
+        device->device.destroyBuffer(buf->buffer);
         return nullptr;
     }
 
-    ctx->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
+    device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
 
     return buf->ptr;
 }
 
-static void ggml_vk_host_free(ggml_backend_vk_context * ctx, void* ptr) {
+static void ggml_vk_host_free(vk_device& device, void* ptr) {
     if (ptr == nullptr) {
         return;
     }
     VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
     vk_buffer buf;
     size_t index;
-    for (size_t i = 0; i < ctx->pinned_memory.size(); i++) {
-        const uint8_t* addr = (const uint8_t*) std::get<0>(ctx->pinned_memory[i]);
-        const uint8_t* endr = addr + std::get<1>(ctx->pinned_memory[i]);
+    for (size_t i = 0; i < device->pinned_memory.size(); i++) {
+        const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
+        const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
         if (ptr >= addr && ptr < endr) {
-            buf = std::get<2>(ctx->pinned_memory[i]);
+            buf = std::get<2>(device->pinned_memory[i]);
             index = i;
             break;
         }
@@ -2254,26 +2321,26 @@ static void ggml_vk_host_free(ggml_backend_vk_context * ctx, void* ptr) {
 
     ggml_vk_destroy_buffer(buf);
 
-    ctx->pinned_memory.erase(ctx->pinned_memory.begin() + index);
+    device->pinned_memory.erase(device->pinned_memory.begin() + index);
 }
 
-static void ggml_vk_host_get(ggml_backend_vk_context * ctx, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
+static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
     buf = nullptr;
     buf_offset = 0;
-    for (size_t i = 0; i < ctx->pinned_memory.size(); i++) {
-        const uint8_t* addr = (const uint8_t*) std::get<0>(ctx->pinned_memory[i]);
-        const uint8_t* endr = addr + std::get<1>(ctx->pinned_memory[i]);
+    for (size_t i = 0; i < device->pinned_memory.size(); i++) {
+        const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
+        const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
         if (ptr >= addr && ptr < endr) {
-            buf = std::get<2>(ctx->pinned_memory[i]);
+            buf = std::get<2>(device->pinned_memory[i]);
             buf_offset = ((const uint8_t *)ptr) - addr;
             break;
         }
     }
 }
 
-static vk_submission ggml_vk_begin_submission(ggml_backend_vk_context * ctx, vk_queue& q, bool one_time = true) {
+static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bool one_time = true) {
     vk_submission s;
-    s.buffer = ggml_vk_create_cmd_buffer(ctx, q);
+    s.buffer = ggml_vk_create_cmd_buffer(device, q);
     if (one_time) {
         s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
     } else {
@@ -2333,13 +2400,13 @@ static void ggml_vk_ctx_end(vk_context * ctx) {
     ctx->s = nullptr;
 }
 
-static void ggml_vk_ctx_begin(ggml_backend_vk_context * ctx, vk_context * subctx) {
-    VK_LOG_DEBUG("ggml_vk_ctx_begin(" << ctx << ")");
+static void ggml_vk_ctx_begin(vk_device& device, vk_context * subctx) {
+    VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")");
     if (subctx->s != nullptr) {
         ggml_vk_ctx_end(subctx);
     }
 
-    subctx->seqs.push_back({ ggml_vk_begin_submission(ctx, *subctx->q) });
+    subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->q) });
     subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
 }
 
@@ -2356,11 +2423,11 @@ static void deferred_memcpy(void * dst, const void * src, size_t size, std::vect
     }
 }
 
-static void ggml_vk_ensure_sync_staging_buffer(ggml_backend_vk_context * ctx, size_t size) {
-    if (ctx->sync_staging == nullptr || ctx->sync_staging->size < size) {
+static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
+    if (device->sync_staging == nullptr || device->sync_staging->size < size) {
         VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
-        ggml_vk_destroy_buffer(ctx->sync_staging);
-        ctx->sync_staging = ggml_vk_create_buffer_check(ctx, size,
+        ggml_vk_destroy_buffer(device->sync_staging);
+        device->sync_staging = ggml_vk_create_buffer_check(device, size,
             vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
             vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
     }
@@ -2377,7 +2444,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
     // Check if src is pinned memory
     vk_buffer buf;
     size_t buf_offset;
-    ggml_vk_host_get(ctx, tensor->data, buf, buf_offset);
+    ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset);
 
     const uint64_t ne0 = tensor->ne[0];
     const uint64_t ne1 = tensor->ne[1];
@@ -2435,9 +2502,9 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
     if (ctx->staging->size < ctx->staging_offset + copy_size) {
         if (sync_staging) {
             // Create temporary larger buffer
-            ggml_vk_ensure_sync_staging_buffer(ctx, copy_size);
+            ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
 
-            staging = ctx->sync_staging;
+            staging = ctx->device->sync_staging;
             staging_offset = 0;
         } else {
             GGML_ASSERT(false);
@@ -2471,11 +2538,8 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
     }
 }
 
-static void ggml_vk_buffer_write_2d_async(ggml_backend_vk_context * ctx, vk_context * subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
+static void ggml_vk_buffer_write_2d_async(vk_context * subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
     VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
-    // Make sure ctx owns the buffer
-    GGML_ASSERT(dst->ctx == ctx);
-
     // Buffer is already mapped
     if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
         std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl;
@@ -2484,7 +2548,7 @@ static void ggml_vk_buffer_write_2d_async(ggml_backend_vk_context * ctx, vk_cont
     // Check if src is pinned memory
     vk_buffer buf = nullptr;
     size_t buf_offset;
-    ggml_vk_host_get(ctx, src, buf, buf_offset);
+    ggml_vk_host_get(dst->device, src, buf, buf_offset);
 
     if (buf != nullptr) {
         // Memory is pinned, use as staging buffer
@@ -2510,14 +2574,12 @@ static void ggml_vk_buffer_write_2d_async(ggml_backend_vk_context * ctx, vk_cont
     VK_LOG_DEBUG("STAGING");
 
     // Staging buffer required
-    vk_buffer staging = ctx->staging;
-    size_t staging_offset = ctx->staging_offset;
     const size_t copy_size = width*height;
-    if (ctx->staging == nullptr || ctx->staging->size < ctx->staging_offset + copy_size) {
+    if (staging_buffer == nullptr || staging_buffer->size < staging_offset + copy_size) {
         if (sync_staging) {
-            ggml_vk_ensure_sync_staging_buffer(ctx, copy_size);
+            ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size);
 
-            staging = ctx->sync_staging;
+            staging_buffer = dst->device->sync_staging;
             staging_offset = 0;
         } else {
             GGML_ASSERT(false);
@@ -2530,23 +2592,23 @@ static void ggml_vk_buffer_write_2d_async(ggml_backend_vk_context * ctx, vk_cont
         copy_size};
 
     ggml_vk_sync_buffers(subctx);
-    vkCmdCopyBuffer(subctx->s->buffer, staging->buffer, dst->buffer, 1, &buf_copy);
+    vkCmdCopyBuffer(subctx->s->buffer, staging_buffer->buffer, dst->buffer, 1, &buf_copy);
 
     if (width == spitch) {
-        deferred_memcpy((uint8_t *)staging->ptr + staging_offset, src, width * height, &subctx->in_memcpys);
+        deferred_memcpy((uint8_t *)staging_buffer->ptr + staging_offset, src, width * height, &subctx->in_memcpys);
     } else {
         for (size_t i = 0; i < height; i++) {
-            deferred_memcpy((uint8_t *)staging->ptr + staging_offset + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
+            deferred_memcpy((uint8_t *)staging_buffer->ptr + staging_offset + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
         }
     }
 }
 
-static void ggml_vk_buffer_write_async(ggml_backend_vk_context * ctx, vk_context * subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
+static void ggml_vk_buffer_write_async(vk_context * subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
     VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
-    return ggml_vk_buffer_write_2d_async(ctx, subctx, dst, offset, src, size, size, 1, sync_staging);
+    return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, staging_buffer, staging_offset, sync_staging);
 }
 
-static void ggml_vk_buffer_write_2d(ggml_backend_vk_context * ctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) {
+static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) {
     VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")");
     // Buffer is already mapped
     if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
@@ -2556,38 +2618,38 @@ static void ggml_vk_buffer_write_2d(ggml_backend_vk_context * ctx, vk_buffer& ds
             memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
         }
     } else {
-        vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
-        ggml_vk_ctx_begin(ctx, subctx);
-        ggml_vk_buffer_write_2d_async(ctx, subctx, dst, offset, src, spitch, width, height, true);
+        vk_context * subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
+        ggml_vk_ctx_begin(dst->device, subctx);
+        ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, nullptr, 0, true);
         ggml_vk_ctx_end(subctx);
 
         for (auto& cpy : subctx->in_memcpys) {
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
 
-        ggml_vk_submit(subctx, ctx->fence);
-        VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
-        ctx->device->device.resetFences({ ctx->fence });
+        ggml_vk_submit(subctx, dst->device->fence);
+        VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
+        dst->device->device.resetFences({ dst->device->fence });
+
+        delete subctx;
     }
 }
 
-static void ggml_vk_buffer_write(ggml_backend_vk_context * ctx, vk_buffer& dst, size_t offset, const void * src, size_t size) {
+static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) {
     VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")");
-    ggml_vk_buffer_write_2d(ctx, dst, offset, src, 0, size, 1);
+    ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1);
 }
 
-static void ggml_vk_buffer_read_2d_async(ggml_backend_vk_context * ctx, vk_context * subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {
+static void ggml_vk_buffer_read_2d_async(vk_context * subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
     VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")");
     GGML_ASSERT(width > 0);
     GGML_ASSERT(height > 0);
     GGML_ASSERT(src != nullptr);
-    // Make sure ctx owns the buffer
-    GGML_ASSERT(src->ctx == ctx);
 
     // Check if dst is pinned memory
     vk_buffer buf = nullptr;
     size_t buf_offset;
-    ggml_vk_host_get(ctx, dst, buf, buf_offset);
+    ggml_vk_host_get(src->device, dst, buf, buf_offset);
 
     std::vector<vk::BufferCopy> slices(1);
     if (width == spitch && width == dpitch) {
@@ -2614,55 +2676,56 @@ static void ggml_vk_buffer_read_2d_async(ggml_backend_vk_context * ctx, vk_conte
     VK_LOG_DEBUG("STAGING");
 
     // Fall back to staging buffer
-    vk_buffer staging = ctx->staging;
     const size_t copy_size = dpitch * height;
-    if (ctx->staging == nullptr || ctx->staging->size < ctx->staging_offset + copy_size) {
+    if (staging_buffer == nullptr || staging_buffer->size < staging_offset + copy_size) {
         if (sync_staging) {
             // Create temporary larger buffer
-            ggml_vk_ensure_sync_staging_buffer(ctx, copy_size);
+            ggml_vk_ensure_sync_staging_buffer(src->device, copy_size);
 
-            staging = ctx->sync_staging;
+            staging_buffer = src->device->sync_staging;
         } else {
             GGML_ASSERT(false);
         }
     }
 
     ggml_vk_sync_buffers(subctx);
-    subctx->s->buffer.copyBuffer(src->buffer, staging->buffer, slices);
+    subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
 
-    deferred_memcpy(dst, staging->ptr, copy_size, &subctx->out_memcpys);
+    deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
 }
 
-static void ggml_vk_buffer_read_async(ggml_backend_vk_context * ctx, vk_context * subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) {
-    return ggml_vk_buffer_read_2d_async(ctx, subctx, src, offset, dst, size, size, size, 1, sync_staging);
+static void ggml_vk_buffer_read_async(vk_context * subctx, vk_buffer& src, size_t offset, void * dst, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+    return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, staging_buffer, staging_offset, sync_staging);
 }
 
-static void ggml_vk_buffer_read(ggml_backend_vk_context * ctx, vk_buffer& src, size_t offset, void * dst, size_t size) {
+static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
     VK_LOG_DEBUG("ggml_vk_buffer_read(" << offset << ", " << size << ")");
     if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
         GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
 
         memcpy(dst, (uint8_t *) src->ptr + offset, size);
     } else {
-        vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
-        ggml_vk_ctx_begin(ctx, subctx);
-        ggml_vk_buffer_read_async(ctx, subctx, src, offset, dst, size, true);
+        vk_context * subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
+        ggml_vk_ctx_begin(src->device, subctx);
+        ggml_vk_buffer_read_async(subctx, src, offset, dst, size, nullptr, 0, true);
         ggml_vk_ctx_end(subctx);
 
-        ggml_vk_submit(subctx, ctx->fence);
-        VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
-        ctx->device->device.resetFences({ ctx->fence });
+        ggml_vk_submit(subctx, src->device->fence);
+        VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
+        src->device->device.resetFences({ src->device->fence });
 
         for (auto& cpy : subctx->out_memcpys) {
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
+
+        delete subctx;
     }
 }
 
 static void ggml_vk_buffer_copy_async(vk_context * ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
     VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")");
-    // Make sure both buffers are on same ctx
-    GGML_ASSERT(src->ctx == dst->ctx);
+    // Make sure both buffers are on same device
+    GGML_ASSERT(src->device == dst->device);
 
     VkBufferCopy bc{ src_offset, dst_offset, size };
 
@@ -2670,101 +2733,46 @@ static void ggml_vk_buffer_copy_async(vk_context * ctx, vk_buffer& dst, size_t d
 }
 
 static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
-    if (src->ctx == dst->ctx) {
-    VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
+    if (src->device == dst->device) {
+        VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
         // Copy within the device
-        ggml_backend_vk_context * ctx = src->ctx;
-
-        vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
-        ggml_vk_ctx_begin(ctx, subctx);
+        vk_context * subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
+        ggml_vk_ctx_begin(src->device, subctx);
         ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
         ggml_vk_ctx_end(subctx);
-        ggml_vk_submit(subctx, ctx->fence);
-        VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
-        ctx->device->device.resetFences({ ctx->fence });
+        ggml_vk_submit(subctx, src->device->fence);
+        VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
+        src->device->device.resetFences({ src->device->fence });
+
+        delete subctx;
     } else {
-    VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
+        VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
         // Copy device to device
-        ggml_backend_vk_context * src_ctx = src->ctx;
-        ggml_backend_vk_context * dst_ctx = dst->ctx;
-
-        ggml_vk_ensure_sync_staging_buffer(src_ctx, size);
-        ggml_vk_ensure_sync_staging_buffer(dst_ctx, size);
+        ggml_vk_ensure_sync_staging_buffer(src->device, size);
+        ggml_vk_ensure_sync_staging_buffer(dst->device, size);
 
         // Copy to src staging buffer
-        ggml_vk_buffer_copy(src_ctx->sync_staging, 0, src, src_offset, size);
+        ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
         // memcpy to dst staging buffer
-        memcpy(dst_ctx->sync_staging->ptr, src_ctx->sync_staging->ptr, size);
+        memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size);
         // Copy to dst buffer
-        ggml_vk_buffer_copy(dst, dst_offset, dst_ctx->sync_staging, 0, size);
+        ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size);
     }
 }
 
-static void ggml_vk_buffer_memset(ggml_backend_vk_context * ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
+static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
     VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
-    // Make sure ctx owns the buffer
-    GGML_ASSERT(dst->ctx == ctx);
 
-    vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
-    ggml_vk_ctx_begin(ctx, subctx);
+    vk_context * subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
+    ggml_vk_ctx_begin(dst->device, subctx);
     subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
     ggml_vk_ctx_end(subctx);
 
-    ggml_vk_submit(subctx, ctx->fence);
-    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_memset waitForFences");
-    ctx->device->device.resetFences({ ctx->fence });
-}
-
-static void ggml_vk_h2d_tensor_2d(ggml_backend_vk_context * ctx, vk_context * subctx, vk_buffer& dst, size_t offset, const ggml_tensor * src, uint64_t i3, uint64_t i2, uint64_t i1) {
-    VK_LOG_DEBUG("ggml_vk_h2d_tensor_2d(dst=" << dst << ", offset=" << offset << ", src=" << src << ", i3=" << i3 << ", i2=" << i2 << ", i1=" << i1 << ")");
-    const uint64_t ne0 = src->ne[0];
-    const uint64_t ne1 = src->ne[1];
-    const uint64_t nb0 = src->nb[0];
-    const uint64_t nb1 = src->nb[1];
-    const uint64_t nb2 = src->nb[2];
-    const uint64_t nb3 = src->nb[3];
-    const enum ggml_type type = src->type;
-    const size_t ts = ggml_type_size(type);
-    const size_t bs = ggml_blck_size(type);
-    const size_t row_length = ts*ne0/bs;
-
-    const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
-    if (nb0 == ts && nb1 == row_length) {
-        return ggml_vk_buffer_write_async(ctx, subctx, dst, offset, x, i1*nb1);
-    }
-    if (nb0 == ts && (i1 == ne1 || !ggml_is_permuted(src))) {
-        return ggml_vk_buffer_write_2d_async(ctx, subctx, dst, offset, x, nb1, row_length, i1);
-    }
-
-    GGML_ASSERT(i3 == 0);
-    GGML_ASSERT(i2 == 0);
-    GGML_ASSERT(i1 == (uint64_t) ggml_nrows(src));
-
-    return ggml_vk_buffer_write_nc_async(ctx, subctx, dst, offset, src);
-}
-
-static void ggml_vk_d2h_tensor_2d(ggml_backend_vk_context * ctx, vk_context * subctx, vk_buffer& src, size_t offset, const ggml_tensor * dst) {
-    VK_LOG_DEBUG("ggml_vk_d2h_tensor_2d()");
-    const uint64_t ne0 = dst->ne[0];
-    const uint64_t ne1 = dst->ne[1];
-    const uint64_t ne2 = dst->ne[2];
-    const uint64_t ne3 = dst->ne[3];
-    const uint64_t nb0 = dst->nb[0];
-    const uint64_t nb1 = dst->nb[1];
-    // const uint64_t nb2 = dst->nb[2];
-    // const uint64_t nb3 = dst->nb[3];
-    const enum ggml_type type = dst->type;
-    const size_t ts = ggml_type_size(type);
-    const size_t bs = ggml_blck_size(type);
-    const size_t row_length = ts*ne0/bs;
+    ggml_vk_submit(subctx, dst->device->fence);
+    VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences");
+    dst->device->device.resetFences({ dst->device->fence });
 
-    if (ggml_is_contiguous(dst)) {
-        return ggml_vk_buffer_read_async(ctx, subctx, src, offset, dst->data, ne1*nb1*ne2*ne3);
-    }
-    if (nb0 == ts) {
-        return ggml_vk_buffer_read_2d_async(ctx, subctx, src, offset, dst->data, nb1, nb1, row_length, ne1*ne2*ne3);
-    }
-    GGML_ASSERT(false);
+    delete subctx;
 }
 
 static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
@@ -2942,8 +2950,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
     bool src1_uma = false;
 
     if (ctx->device->uma) {
-        ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset);
-        ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
+        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
+        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
         src0_uma = d_Qx != nullptr;
         src1_uma = d_Qy != nullptr;
     }
@@ -3035,15 +3043,15 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
 
     // Allocate descriptor sets
-    ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, 1);
+    ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1);
     if (qx_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1);
+        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
     }
     if (qy_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, 1);
+        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
     }
     if (split_k > 1) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1);
+        ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
     }
 
     if (x_non_contig) {
@@ -3119,8 +3127,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
     bool src1_uma = false;
 
     if (ctx->device->uma) {
-        ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset);
-        ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
+        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
+        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
         src0_uma = d_Qx != nullptr;
         src1_uma = d_Qy != nullptr;
     }
@@ -3195,12 +3203,12 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
 
     // Allocate descriptor sets
     if (qx_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1);
+        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
     }
     if (qy_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
+        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
     }
-    ggml_pipeline_allocate_descriptor_sets(ctx, dmmv, ne12 * ne13);
+    ggml_pipeline_allocate_descriptor_sets(ctx->device, dmmv, ne12 * ne13);
 
     if (x_non_contig) {
         GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
@@ -3222,6 +3230,16 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
         stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
     }
 
+    const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
+
+    uint32_t groups_x = ne01;
+    uint32_t groups_z = 1;
+
+    if (ne01 > max_groups_x) {
+        groups_z = 64;
+        groups_x /= groups_z;
+    }
+
     // compute
     const vk_mat_vec_push_constants pc = {
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
@@ -3229,11 +3247,13 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
         (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
     };
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23} }, sizeof(vk_mat_vec_push_constants), &pc, { (uint32_t)ne01, (uint32_t)(ne12 * ne13), 1});
+    ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
+                              { { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23} },
+                              sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
 }
 
 static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+    VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
     GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
@@ -3264,7 +3284,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     bool src1_uma = false;
 
     if (ctx->device->uma) {
-        ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
+        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
         src1_uma = d_Qy != nullptr;
     }
 
@@ -3289,7 +3309,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     }
 
     // Allocate descriptor sets
-    ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
+    ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
 
     const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
     const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
@@ -3338,7 +3358,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     bool src1_uma = false;
 
     if (ctx->device->uma) {
-        ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
+        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
         src1_uma = d_Qy != nullptr;
     }
 
@@ -3364,7 +3384,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     }
 
     // Allocate descriptor sets
-    ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
+    ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
 
     const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
     const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
@@ -3442,9 +3462,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
     bool ids_uma = false;
 
     if (ctx->device->uma) {
-        ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset);
-        ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
-        ggml_vk_host_get(ctx, ids->data, d_ids, ids_buf_offset);
+        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
+        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
+        ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
         src0_uma = d_Qx != nullptr;
         src1_uma = d_Qy != nullptr;
         ids_uma = d_ids != nullptr;
@@ -3539,12 +3559,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
 
     // Allocate descriptor sets
-    ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, 1);
+    ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1);
     if (qx_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1);
+        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
     }
     if (qy_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, 1);
+        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
     }
 
     if (x_non_contig) {
@@ -3628,9 +3648,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     bool ids_uma = false;
 
     if (ctx->device->uma) {
-        ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset);
-        ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
-        ggml_vk_host_get(ctx, ids->data, d_ids, ids_buf_offset);
+        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
+        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
+        ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
         src0_uma = d_Qx != nullptr;
         src1_uma = d_Qy != nullptr;
         ids_uma = d_ids != nullptr;
@@ -3712,12 +3732,12 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
 
     // Allocate descriptor sets
     if (qx_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1);
+        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
     }
     if (qy_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
+        ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
     }
-    ggml_pipeline_allocate_descriptor_sets(ctx, dmmv, ne12 * ne13);
+    ggml_pipeline_allocate_descriptor_sets(ctx->device, dmmv, ne12 * ne13);
 
     if (x_non_contig) {
         GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
@@ -3734,6 +3754,16 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
     }
 
+    const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
+
+    uint32_t groups_x = ne01;
+    uint32_t groups_z = 1;
+
+    if (ne01 > max_groups_x) {
+        groups_z = 64;
+        groups_x /= groups_z;
+    }
+
     // compute
     const vk_mat_vec_id_push_constants pc = {
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
@@ -3743,7 +3773,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
         { { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23}, { d_ids, ids_buf_offset, ids_sz } },
-        sizeof(vk_mat_vec_id_push_constants), &pc, { (uint32_t)ne01, (uint32_t)nei0, 1 });
+        sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z });
 }
 
 static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
@@ -4056,14 +4086,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
     bool src2_uma = false;
 
     if (ctx->device->uma) {
-        ggml_vk_host_get(ctx, src0->data, d_X, x_buf_offset);
+        ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset);
         src0_uma = d_X != nullptr;
         if (use_src1) {
-            ggml_vk_host_get(ctx, src1->data, d_Y, y_buf_offset);
+            ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset);
             src1_uma = d_Y != nullptr;
         }
         if (use_src2) {
-            ggml_vk_host_get(ctx, src2->data, d_Z, z_buf_offset);
+            ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset);
             src2_uma = d_Z != nullptr;
         }
     }
@@ -4123,7 +4153,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
 
     // Single call if dimension 2 is contiguous
     if (op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, 1);
+        ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1);
 
         switch (dst->op) {
         case GGML_OP_NORM:
@@ -4199,7 +4229,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
         GGML_ASSERT(op != GGML_OP_ARGSORT);
         GGML_ASSERT(!use_src2);
 
-        ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, ne02 * ne03);
+        ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, ne02 * ne03);
 
         switch (dst->op) {
         case GGML_OP_NORM:
@@ -4804,101 +4834,6 @@ static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1
     }
 }
 
-static void ggml_vk_test_h2d_nc(ggml_backend_vk_context * ctx, size_t ne0, size_t ne1, size_t ne2, size_t ne3) {
-    const size_t ne = ne0 * ne1 * ne2 * ne3;
-
-    ggml_init_params iparams = {
-        /*.mem_size   =*/ 1024*1024*1024,
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true,
-    };
-
-    ggml_context * ggml_ctx = ggml_init(iparams);
-
-    ggml_tensor * tensor = ggml_new_tensor_4d(ggml_ctx, GGML_TYPE_F32, ne0, ne2, ne1, ne3);  // NOLINT
-    ggml_tensor * result_tensor = ggml_new_tensor_4d(ggml_ctx, GGML_TYPE_F32, ne0, ne1, ne2, ne3);
-
-    float * data = (float *) ggml_vk_host_malloc(ctx, ggml_nbytes(tensor));
-    tensor->data = data;
-
-    float * result_data = (float *) malloc(ggml_nbytes(tensor));
-    result_tensor->data = result_data;
-
-    // Permute
-    {
-        size_t tmp = tensor->nb[2];
-        tensor->nb[2] = tensor->nb[1];
-        tensor->nb[1] = tmp;
-
-        tensor->ne[2] = ne2;
-        tensor->ne[1] = ne1;
-    }
-
-    for (size_t i = 0; i < ne; i++) {
-        data[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
-    }
-
-    vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
-    ggml_vk_ctx_begin(ctx, subctx);
-
-    vk_buffer buffer = ggml_vk_create_buffer_check(ctx, ggml_nbytes(tensor), vk::MemoryPropertyFlagBits::eDeviceLocal);
-
-    ggml_vk_h2d_tensor_2d(ctx, subctx, buffer, 0, tensor, 0, 0, ggml_nrows(tensor));
-
-    ggml_vk_ctx_end(subctx);
-    ggml_vk_submit(subctx, ctx->fence);
-    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_h2d_nc waitForFences");
-    ctx->device->device.resetFences({ ctx->fence });
-
-    ggml_vk_buffer_read(ctx, buffer, 0, result_data, ggml_nbytes(tensor));
-
-    double avg_err = 0.0;
-    int first_err_i0 = -1;
-    int first_err_i1 = -1;
-    int first_err_i2 = -1;
-    int first_err_i3 = -1;
-
-    for (size_t i3 = 0; i3 < ne3; i3++) {
-        for (size_t i2 = 0; i2 < ne2; i2++) {
-            for (size_t i1 = 0; i1 < ne1; i1++) {
-                for (size_t i0 = 0; i0 < ne0; i0++) {
-                    float correct = *(float *) ((char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
-                    float result = *(float *) ((char *) result_data + i3*ne2*ne1*ne0*sizeof(float) + i2*ne1*ne0*sizeof(float) + i1*ne0*sizeof(float) + i0*sizeof(float));
-                    double err = std::fabs(result - correct);
-
-                    avg_err += err;
-
-                    if (err > 0.05f && first_err_i0 == -1) {
-                        first_err_i0 = i0;
-                        first_err_i1 = i1;
-                        first_err_i2 = i2;
-                        first_err_i3 = i3;
-                    }
-                }
-            }
-        }
-    }
-
-    avg_err /= ne;
-
-    std::cerr << "TEST nc copy ne0=" << ne0 << " ne1=" << ne1 << " ne2=" << ne2 << " ne3=" << ne3 << " avg_err=" << avg_err << std::endl;
-
-    if (avg_err > 0.1) {
-        std::cerr << "i0 = " << first_err_i0 << " i1 = " << first_err_i1 << " i2 = " << first_err_i2 << " i3 = " << first_err_i3 << std::endl;
-        std::cerr << "Actual result: " << std::endl << std::endl;
-        ggml_vk_print_tensor_area(result_tensor, first_err_i0, first_err_i1, first_err_i2, first_err_i3);
-        std::cerr << "Expected result: " << std::endl << std::endl;
-        ggml_vk_print_tensor_area(tensor, first_err_i0, first_err_i1, first_err_i2, first_err_i3);
-    }
-
-    ggml_free(ggml_ctx);
-
-    ggml_vk_destroy_buffer(buffer);
-
-    ggml_vk_host_free(ctx, data);
-    free(result_data);
-}
-
 static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool pinned) {
     VK_LOG_DEBUG("ggml_vk_test_transfer(" << ne << ")");
     // Check transfers are correct
@@ -5509,7 +5444,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         if (ctx->prealloc_x != nullptr) {
             ggml_vk_destroy_buffer(ctx->prealloc_x);
         }
-        ctx->prealloc_x = ggml_vk_create_buffer_device(ctx, ctx->prealloc_size_x);
+        ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x);
     }
     if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) {
         VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")");
@@ -5517,7 +5452,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         if (ctx->prealloc_y != nullptr) {
             ggml_vk_destroy_buffer(ctx->prealloc_y);
         }
-        ctx->prealloc_y = ggml_vk_create_buffer_device(ctx, ctx->prealloc_size_y);
+        ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
     }
     if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
         VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
@@ -5525,7 +5460,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         if (ctx->prealloc_split_k != nullptr) {
             ggml_vk_destroy_buffer(ctx->prealloc_split_k);
         }
-        ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx, ctx->prealloc_size_split_k);
+        ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
     }
     if (ctx->staging == nullptr || (ctx->staging_size > 0 && ctx->staging->size < ctx->staging_size)) {
         VK_LOG_MEMORY("ggml_vk_preallocate_buffers(staging_size: " << ctx->staging_size << ")");
@@ -5533,7 +5468,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         if (ctx->staging != nullptr) {
             ggml_vk_destroy_buffer(ctx->staging);
         }
-        ctx->staging = ggml_vk_create_buffer_check(ctx, ctx->staging_size,
+        ctx->staging = ggml_vk_create_buffer_check(ctx->device, ctx->staging_size,
             vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
             vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
     }
@@ -5601,7 +5536,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
     if (ctx->compute_ctx == nullptr) {
         ctx->compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
-        ggml_vk_ctx_begin(ctx, ctx->compute_ctx);
+        ggml_vk_ctx_begin(ctx->device, ctx->compute_ctx);
     }
 
     switch (node->op) {
@@ -5709,7 +5644,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     }
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor){
+static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor){
     ggml_tensor_extra_gpu * extra = nullptr;
 
     switch (tensor->op) {
@@ -5762,17 +5697,10 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
         return false;
     }
 
-    if (params->ith != 0) {
-        return true;
-    }
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return true;
-    }
-
     VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
 
 #ifdef GGML_VULKAN_CHECK_RESULTS
-    ggml_vk_check_results_0(ctx, params, tensor);
+    ggml_vk_check_results_0(ctx, tensor);
 #endif
 
     vk_context& subctx = ctx->gc.contexts[extra->ctx_idx];
@@ -5819,8 +5747,8 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
         ggml_pipeline_cleanup(pl);
     }
 
-    ggml_vk_queue_cleanup(ctx, ctx->device->compute_queue);
-    ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
+    ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue);
+    ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue);
 
     for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
         ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
@@ -5848,14 +5776,13 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
 
 // Clean up on backend free
 static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
-    VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->idx << ")");
+    VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")");
     ggml_vk_graph_cleanup(ctx);
 
     ggml_vk_destroy_buffer(ctx->prealloc_x);
     ggml_vk_destroy_buffer(ctx->prealloc_y);
     ggml_vk_destroy_buffer(ctx->prealloc_split_k);
     ggml_vk_destroy_buffer(ctx->staging);
-    ggml_vk_destroy_buffer(ctx->sync_staging);
 
     for (auto& buffer : ctx->buffer_pool) {
         ggml_vk_destroy_buffer(buffer);
@@ -5900,14 +5827,14 @@ GGML_CALL static void ggml_vk_get_device_description(int device, char * descript
 static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000;  // NOLINT
 
 struct ggml_backend_vk_buffer_context {
-    ggml_backend_vk_context * ctx;
+    vk_device_ref device;
     vk_buffer dev_buffer;
     ggml_tensor_extra_gpu * temp_tensor_extras = nullptr;
     size_t temp_tensor_extra_index = 0;
     std::string name;
 
-    ggml_backend_vk_buffer_context(ggml_backend_vk_context * ctx, vk_buffer&& dev_buffer, std::string& name) :
-        ctx(ctx),
+    ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) :
+        device(device),
         dev_buffer(dev_buffer),
         name(name) {
     }
@@ -5973,24 +5900,24 @@ GGML_CALL static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t b
 
 GGML_CALL static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
-    ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
-
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
     vk_buffer buf = extra->buffer_gpu.lock();
 
-    ggml_vk_buffer_write(ctx->ctx, buf, extra->offset + tensor->view_offs + offset, data, size);
+    ggml_vk_buffer_write(buf, extra->offset + tensor->view_offs + offset, data, size);
+
+    GGML_UNUSED(buffer);
 }
 
 GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
-    ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
-
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
     vk_buffer buf = extra->buffer_gpu.lock();
 
-    ggml_vk_buffer_read(ctx->ctx, buf, extra->offset + tensor->view_offs + offset, data, size);
+    ggml_vk_buffer_read(buf, extra->offset + tensor->view_offs + offset, data, size);
+
+    GGML_UNUSED(buffer);
 }
 
 GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@@ -6013,7 +5940,7 @@ GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t bu
 GGML_CALL static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
     ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
 
-    ggml_vk_buffer_memset(ctx->ctx, ctx->dev_buffer, 0, value, buffer->size);
+    ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size);
 }
 
 static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
@@ -6029,11 +5956,6 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
 };
 
 // vk buffer type
-struct ggml_backend_vk_buffer_type_context {
-    std::string name;
-    ggml_backend_vk_context * ctx;
-};
-
 GGML_CALL static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) {
     ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
 
@@ -6046,24 +5968,24 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(
 
     vk_buffer dev_buffer = nullptr;
     try {
-        dev_buffer = ggml_vk_create_buffer_device(ctx->ctx, size);
+        dev_buffer = ggml_vk_create_buffer_device(ctx->device, size);
     } catch (const vk::SystemError& e) {
         return nullptr;
     }
 
-    ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->ctx, std::move(dev_buffer), ctx->name);
+    ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name);
 
     return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size);
 }
 
 GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
-    return ctx->ctx->device->properties.limits.minStorageBufferOffsetAlignment;
+    return ctx->device->properties.limits.minStorageBufferOffsetAlignment;
 }
 
 GGML_CALL static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
     ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
-    return ctx->ctx->device->max_memory_allocation_size;
+    return ctx->device->max_memory_allocation_size;
 }
 
 GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
@@ -6072,25 +5994,14 @@ GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_
     UNUSED(buft);
 }
 
-static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
-    /* .get_name         = */ ggml_backend_vk_buffer_type_name,
-    /* .alloc_buffer     = */ ggml_backend_vk_buffer_type_alloc_buffer,
-    /* .get_alignment    = */ ggml_backend_vk_buffer_type_get_alignment,
-    /* .get_max_size     = */ ggml_backend_vk_buffer_type_get_max_size,
-    /* .get_alloc_size   = */ ggml_backend_vk_buffer_type_get_alloc_size,
-    /* .is_host          = */ NULL,
-};
-
 GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) {
     ggml_vk_instance_init();
 
     VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")");
 
-    GGML_ASSERT(dev_num < vk_instance.device_indices.size());
+    vk_device dev = ggml_vk_get_device(dev_num);
 
-    ggml_backend_vk_init(dev_num);
-
-    return &vk_instance.buffer_types[dev_num];
+    return &dev->buffer_type;
 }
 
 // host buffer type
@@ -6109,15 +6020,16 @@ GGML_CALL static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buff
 
 GGML_CALL static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
-    ggml_vk_host_free(&vk_instance.contexts[0], buffer->context);
+    ggml_vk_host_free(vk_instance.devices[0], buffer->context);
 }
 
 GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")");
+
     size += 32;  // Behave like the CPU buffer type
     void * ptr = nullptr;
     try {
-        ptr = ggml_vk_host_malloc(&vk_instance.contexts[0], size);
+        ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);
     } catch (vk::SystemError& e) {
         std::cerr << "ggml_vulkan: Failed to allocate pinned memory." << std::endl;
         std::cerr << "ggml_vulkan: " << e.what() << std::endl;
@@ -6131,14 +6043,18 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_bu
     buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer;
 
     return buffer;
+
+    UNUSED(buft);
 }
 
 GGML_CALL static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    return vk_instance.contexts[0].device->properties.limits.minMemoryMapAlignment;
+    return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment;
 
     UNUSED(buft);
 }
 
+// Should be changed to return device-specific host buffer type
+// but that probably requires changes in llama.cpp
 GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
     static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = {
         /* .iface    = */ {
@@ -6152,14 +6068,14 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
         /* .context  = */ nullptr,
     };
 
-    if (!vk_instance.contexts[0].initialized) {
-        // Fall back to CPU
-        return ggml_backend_cpu_buffer_type();
-    }
+    // Make sure device 0 is initialized
+    ggml_vk_instance_init();
+    ggml_vk_get_device(0);
 
     return &ggml_backend_vk_buffer_type_host;
 }
 
+
 // backend
 
 GGML_CALL static const char * ggml_backend_vk_name(ggml_backend_t backend) {
@@ -6172,74 +6088,65 @@ GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend) {
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
     VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")");
 
-    size_t idx = ctx->idx;
-
     ggml_vk_cleanup(ctx);
 
-    ctx->device.reset();
-    ctx->initialized = false;
-
-    vk_instance.initialized[idx] = false;
-    vk_instance.backends[idx] = nullptr;
-    memset(&vk_instance.buffer_types[idx], 0, sizeof(ggml_backend_buffer_type));
+    delete ctx;
     delete backend;
 }
 
 GGML_CALL static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) {
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
 
-    GGML_ASSERT(ctx->initialized);
-
-    return ggml_backend_vk_buffer_type(ctx->idx);
+    return &ctx->device->buffer_type;
 }
 
 GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-    GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_buffer_type(ctx->idx) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
+    GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
 
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
     if (ctx->transfer_ctx == nullptr) {
         // Initialize new transfer context
         ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
-        ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
+        ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx);
     }
 
     vk_buffer buf = extra->buffer_gpu.lock();
 
-    ggml_vk_buffer_write_async(ctx, ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size);
+    ggml_vk_buffer_write_async(ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset);
 }
 
 GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-    GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_buffer_type(ctx->idx) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
+    GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
 
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
     if (ctx->transfer_ctx == nullptr) {
         // Initialize new transfer context
         ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
-        ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
+        ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx);
     }
 
     vk_buffer buf = extra->buffer_gpu.lock();
 
-    ggml_vk_buffer_read_async(ctx, ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size);
+    ggml_vk_buffer_read_async(ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset);
 }
 
 GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
     VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-    if ((dst->buffer->buft == ggml_backend_vk_buffer_type(ctx->idx) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
+    if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
         ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra;
         ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
 
         if (ctx->transfer_ctx == nullptr) {
             // Initialize new transfer context
             ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
-            ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
+            ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx);
         }
 
         vk_buffer src_buf = src_extra->buffer_gpu.lock();
@@ -6300,9 +6207,6 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen
         ggml_vk_build_graph(ctx,cgraph->nodes[i], i == last_node);
     }
 
-    ggml_compute_params params = {};
-    params.type = GGML_TASK_TYPE_COMPUTE;
-    params.ith = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -6310,13 +6214,13 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen
             continue;
         }
 
-        bool ok = ggml_vk_compute_forward(ctx, &params, node);
+        bool ok = ggml_vk_compute_forward(ctx, node);
         if (!ok) {
             fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
         }
 #ifdef GGML_VULKAN_CHECK_RESULTS
         else {
-            ggml_vk_check_results_1(ctx, &params, node);
+            ggml_vk_check_results_1(ctx, node);
         }
 #endif
         GGML_ASSERT(ok);
@@ -6458,7 +6362,7 @@ GGML_CALL static bool ggml_backend_vk_supports_buft(ggml_backend_t backend, ggml
     ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
 
-    return buft_ctx->ctx->idx == ctx->idx;
+    return buft_ctx->device == ctx->device;
 }
 
 // TODO: enable async and synchronize
@@ -6491,28 +6395,17 @@ static ggml_guid_t ggml_backend_vk_guid() {
 }
 
 GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
-    if (vk_instance.initialized[dev_num]) {
-        return vk_instance.backends[dev_num];
-    }
     VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
 
-    ggml_backend_vk_context * ctx = &vk_instance.contexts[dev_num];
+    ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
     ggml_vk_init(ctx, dev_num);
-    ctx->name = GGML_VK_NAME + std::to_string(dev_num);
-    vk_instance.buffer_types[dev_num] = {
-        /* .iface    = */ ggml_backend_vk_buffer_type_interface,
-        /* .context  = */ new ggml_backend_vk_buffer_type_context{ ctx->name, ctx },
-    };
-    vk_instance.initialized[dev_num] = true;
 
     ggml_backend_t vk_backend = new ggml_backend {
         /* .guid      = */ ggml_backend_vk_guid(),
         /* .interface = */ ggml_backend_vk_interface,
-        /* .context   = */ &vk_instance.contexts[ctx->idx],
+        /* .context   = */ ctx,
     };
 
-    vk_instance.backends[dev_num] = vk_backend;
-
     return vk_backend;
 }
 
@@ -6697,11 +6590,8 @@ void * comp_result;
 size_t comp_size;
 size_t comp_nb[GGML_MAX_DIMS];
 size_t check_counter = 0;
-static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor) {
-    if (params->ith != 0) {
-        return;
-    }
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor) {
+        if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
 
@@ -7005,11 +6895,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
     ggml_free(ggml_ctx);
 }
 
-static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_params * params, ggml_tensor * tensor) {
-    if (params->ith != 0) {
-        return;
-    }
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE || tensor->op == GGML_OP_TRANSPOSE) {
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor) {
+    if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
     if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
diff --git a/src/ggml-vulkan.h b/src/ggml-vulkan.h
deleted file mode 100644 (file)
index af661c2..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-#pragma once
-
-#include "ggml.h"
-#include "ggml-backend.h"
-
-#ifdef  __cplusplus
-extern "C" {
-#endif
-
-#define GGML_VK_NAME "Vulkan"
-#define GGML_VK_MAX_DEVICES 16
-
-GGML_API void ggml_vk_instance_init(void);
-
-// backend API
-GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num);
-
-GGML_API GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend);
-GGML_API GGML_CALL int  ggml_backend_vk_get_device_count(void);
-GGML_API GGML_CALL void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);
-GGML_API GGML_CALL void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);
-
-GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
-// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
-GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
-
-#ifdef  __cplusplus
-}
-#endif
index d5d33c2ba10299253ec9b731a24ccb0d32952f32..f5502afbe98b36dc0c75b08b1986105057395b81 100644 (file)
@@ -175,7 +175,6 @@ void ggml_print_backtrace(void) {
 }
 #endif
 
-/*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
 #define GGML_GELU_QUICK_FP16
@@ -293,7 +292,7 @@ inline static void * ggml_calloc(size_t num, size_t size) {
 #define GGML_FREE(ptr) free(ptr)
 
 #define UNUSED GGML_UNUSED
-#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
+#define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0)
 
 #if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
@@ -474,18 +473,6 @@ int64_t ggml_cycles_per_ms(void) {
     return CLOCKS_PER_SEC/1000;
 }
 
-#ifdef GGML_PERF
-#define ggml_perf_time_ms()       ggml_time_ms()
-#define ggml_perf_time_us()       ggml_time_us()
-#define ggml_perf_cycles()        ggml_cycles()
-#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms()
-#else
-#define ggml_perf_time_ms()       0
-#define ggml_perf_time_us()       0
-#define ggml_perf_cycles()        0
-#define ggml_perf_cycles_per_ms() 0
-#endif
-
 //
 // cross-platform UTF-8 file paths
 //
@@ -1730,8 +1717,8 @@ struct ggml_context {
 
     int    n_objects;
 
-    struct ggml_object* objects_begin;
-    struct ggml_object* objects_end;
+    struct ggml_object * objects_begin;
+    struct ggml_object * objects_end;
 
     struct ggml_scratch scratch;
     struct ggml_scratch scratch_save;
@@ -1744,30 +1731,38 @@ struct ggml_context_container {
 };
 
 struct ggml_compute_state_shared {
-    const struct ggml_cgraph* cgraph;
-    const struct ggml_cplan* cplan;
-
-    int64_t perf_node_start_cycles;
-    int64_t perf_node_start_time_us;
+    const struct ggml_cgraph * cgraph;
+    const struct ggml_cplan * cplan;
 
     int n_threads;
 
     // synchronization primitives
-    atomic_int n_active;  // num active threads
-    atomic_int node_n;    // active graph node
-    atomic_int node_task; // active graph node task phase
+    atomic_int n_barrier;
+    atomic_int n_barrier_passed;
 
     ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
-    void* abort_callback_data;
+    void * abort_callback_data;
 
-    atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
+    atomic_int current_chunk; // currently processing chunk during mul_mat, shared between all the threads
+
+    enum ggml_status ec;
 };
 
 struct ggml_compute_state {
     ggml_thread_t thrd;
     int ith;
-    struct ggml_compute_state_shared* shared;
-    enum ggml_status ec;
+    struct ggml_compute_state_shared * shared;
+};
+
+struct ggml_compute_params {
+    // ith = thread index, nth = number of threads
+    int ith, nth;
+
+    // work buffer for all threads
+    size_t wsize;
+    void * wdata;
+
+    struct ggml_compute_state_shared * shared;
 };
 
 //
@@ -2815,42 +2810,6 @@ static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
 
-// WARN:
-// Mis-configuration can lead to problem that's hard to reason about:
-// * At best  it crash or talks nosense.
-// * At worst it talks slightly difference but hard to perceive.
-//
-// An op has to enable INIT or FINALIZE when any of it's branch needs that pass.
-// Take care about compile options (e.g., GGML_USE_xxx).
-static bool GGML_OP_HAS_INIT    [GGML_OP_COUNT] = { 0 };
-static bool GGML_OP_HAS_FINALIZE[GGML_OP_COUNT] = { 0 };
-
-static void ggml_setup_op_has_task_pass(void) {
-    {   // INIT
-        bool * p = GGML_OP_HAS_INIT;
-
-        p[GGML_OP_ACC                    ] = true;
-        p[GGML_OP_MUL_MAT                ] = true;
-        p[GGML_OP_MUL_MAT_ID             ] = true;
-        p[GGML_OP_OUT_PROD               ] = true;
-        p[GGML_OP_SET                    ] = true;
-        p[GGML_OP_GET_ROWS_BACK          ] = true;
-        p[GGML_OP_DIAG_MASK_INF          ] = true;
-        p[GGML_OP_DIAG_MASK_ZERO         ] = true;
-        p[GGML_OP_CONV_TRANSPOSE_1D      ] = true;
-        p[GGML_OP_CONV_TRANSPOSE_2D      ] = true;
-        p[GGML_OP_FLASH_ATTN_BACK        ] = true;
-        p[GGML_OP_CROSS_ENTROPY_LOSS     ] = true;
-        p[GGML_OP_ADD_REL_POS            ] = true;
-    }
-
-    {   // FINALIZE
-        bool * p = GGML_OP_HAS_FINALIZE;
-
-        p[GGML_OP_CROSS_ENTROPY_LOSS     ] = true;
-    }
-}
-
 //
 // NUMA support
 //
@@ -2889,7 +2848,7 @@ struct ggml_state {
 static struct ggml_state g_state;
 static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
 
-// barrier via spin lock
+// critical section via spin lock
 inline static void ggml_critical_section_start(void) {
     while (atomic_flag_test_and_set(&g_state_critical)) {
         // spin
@@ -2897,6 +2856,48 @@ inline static void ggml_critical_section_start(void) {
     }
 }
 
+#ifdef GGML_USE_OPENMP
+static void ggml_barrier(struct ggml_compute_state_shared * shared) {
+    if (shared->n_threads == 1) {
+        return;
+    }
+
+    #pragma omp barrier
+}
+#else
+static void ggml_barrier(struct ggml_compute_state_shared * shared) {
+    if (shared->n_threads == 1) {
+        return;
+    }
+
+    atomic_int * n_barrier = &shared->n_barrier;
+    atomic_int * n_barrier_passed = &shared->n_barrier_passed;
+
+    int n_threads = shared->n_threads;
+    int passed_old = atomic_load(n_barrier_passed);
+
+    if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) {
+        // last thread
+        atomic_store(n_barrier, 0);
+        atomic_fetch_add(n_barrier_passed, 1);
+    } else {
+        // wait for other threads
+        const int n_spin_before_sleep = 100000;
+        while (true) {
+            for (int i = 0; i < n_spin_before_sleep; i++) {
+                if (atomic_load(n_barrier_passed) != passed_old) {
+                    return;
+                }
+            #if defined(__SSE3__)
+                _mm_pause();
+            #endif
+            }
+            sched_yield();
+        }
+    }
+}
+#endif
+
 // TODO: make this somehow automatically executed
 //       some sort of "sentry" mechanism
 inline static void ggml_critical_section_end(void) {
@@ -3001,7 +3002,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
         }
     }
 #else
-    GGML_UNUSED(numa_flag);
+    UNUSED(numa_flag);
     // TODO
 #endif
 }
@@ -3107,9 +3108,7 @@ GGML_CALL const char * ggml_op_desc(const struct ggml_tensor * t) {
         enum ggml_unary_op uop = ggml_get_unary_op(t);
         return ggml_unary_op_name(uop);
     }
-    else {
-        return ggml_op_name(t->op);
-    }
+    return ggml_op_name(t->op);
 }
 
 GGML_CALL size_t ggml_element_size(const struct ggml_tensor * tensor) {
@@ -3376,8 +3375,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
         }
 
-        ggml_setup_op_has_task_pass();
-
         is_first_call = false;
     }
 
@@ -3644,15 +3641,12 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.flags        =*/ 0,
         /*.grad         =*/ NULL,
         /*.src          =*/ { NULL },
-        /*.perf_runs    =*/ 0,
-        /*.perf_cycles  =*/ 0,
-        /*.perf_time_us =*/ 0,
         /*.view_src     =*/ view_src,
         /*.view_offs    =*/ view_offs,
         /*.data         =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
         /*.name         =*/ { 0 },
         /*.extra        =*/ NULL,
-        /*.padding      =*/ { 0 },
+        ///*.padding      =*/ { 0 },
     };
 
 #ifdef __clang__
@@ -7830,10 +7824,6 @@ static void ggml_compute_forward_dup_same_cont(
     GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
     GGML_ASSERT(src0->type == dst->type);
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const size_t nb00 = src0->nb[0];
     const size_t nb0 = dst->nb[0];
 
@@ -7862,10 +7852,6 @@ static void ggml_compute_forward_dup_f16(
 
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_UNARY_OP_LOCALS
 
     const int ith = params->ith; // thread index
@@ -8135,10 +8121,6 @@ static void ggml_compute_forward_dup_bf16(
 
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_UNARY_OP_LOCALS
 
     const int ith = params->ith; // thread index
@@ -8495,10 +8477,6 @@ static void ggml_compute_forward_dup_f32(
 
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_UNARY_OP_LOCALS
 
     const int ith = params->ith; // thread index
@@ -8818,10 +8796,6 @@ static void ggml_compute_forward_dup_bytes(
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
     GGML_ASSERT(src0->type == dst->type);
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
         ggml_compute_forward_dup_same_cont(params, dst);
         return;
@@ -9002,10 +8976,6 @@ static void ggml_compute_forward_add_f32(
 
     GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -9081,10 +9051,6 @@ static void ggml_compute_forward_add_f16_f32(
 
     GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -9160,10 +9126,6 @@ static void ggml_compute_forward_add_bf16_f32(
 
     GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -9239,10 +9201,6 @@ static void ggml_compute_forward_add_f16_f16(
 
     GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -9295,10 +9253,6 @@ static void ggml_compute_forward_add_bf16_bf16(
 
     GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -9351,10 +9305,6 @@ static void ggml_compute_forward_add_q_f32(
 
     GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int nr  = ggml_nrows(src0);
 
     GGML_TENSOR_BINARY_OP_LOCALS
@@ -9504,10 +9454,6 @@ static void ggml_compute_forward_add1_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_scalar(src1));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -9558,10 +9504,6 @@ static void ggml_compute_forward_add1_f16_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_scalar(src1));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // scalar to add
     const float v = *(float *) src1->data;
 
@@ -9610,10 +9552,6 @@ static void ggml_compute_forward_add1_f16_f16(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_scalar(src1));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // scalar to add
     const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
 
@@ -9662,10 +9600,6 @@ static void ggml_compute_forward_add1_q_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_scalar(src1));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // scalar to add
     const float v = *(float *) src1->data;
 
@@ -9731,10 +9665,6 @@ static void ggml_compute_forward_add1_bf16_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_scalar(src1));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // scalar to add
     const float v = *(float *) src1->data;
 
@@ -9783,10 +9713,6 @@ static void ggml_compute_forward_add1_bf16_bf16(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_scalar(src1));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // scalar to add
     const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
 
@@ -9911,20 +9837,16 @@ static void ggml_compute_forward_acc_f32(
     size_t offset  = ((int32_t *) dst->op_params)[3];
     bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
 
-    if (!inplace && (params->type == GGML_TASK_TYPE_INIT)) {
-        if (params->ith != 0) {
-            return;
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
         }
-        // memcpy needs to be synchronized across threads to avoid race conditions.
-        // => do it in INIT phase
-        memcpy(
-            ((char *)  dst->data),
-            ((char *) src0->data),
-            ggml_nbytes(dst));
-    }
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
+        ggml_barrier(params->shared);
     }
 
     const int ith = params->ith;
@@ -10026,13 +9948,12 @@ static void ggml_compute_forward_sub_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    assert(params->ith == 0);
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
     const int nr  = ggml_nrows(src0);
 
     GGML_TENSOR_BINARY_OP_LOCALS
@@ -10110,9 +10031,6 @@ static void ggml_compute_forward_mul_f32(
 
     GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -10207,10 +10125,6 @@ static void ggml_compute_forward_div_f32(
 
     GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -10299,13 +10213,12 @@ static void ggml_compute_forward_sqr_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-    assert(ggml_are_same_shape(src0, dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    assert(ggml_are_same_shape(src0, dst));
+
     const int n     = ggml_nrows(src0);
     const int nc    = src0->ne[0];
 
@@ -10345,13 +10258,12 @@ static void ggml_compute_forward_sqrt_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-    assert(ggml_are_same_shape(src0, dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    assert(ggml_are_same_shape(src0, dst));
+
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -10391,13 +10303,12 @@ static void ggml_compute_forward_log_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -10437,13 +10348,13 @@ static void ggml_compute_forward_sum_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-    assert(ggml_is_scalar(dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    assert(ggml_is_scalar(dst));
+
+
     assert(ggml_is_scalar(dst));
     assert(src0->nb[0] == sizeof(float));
 
@@ -10472,13 +10383,12 @@ static void ggml_compute_forward_sum_f16(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-    assert(ggml_is_scalar(dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    assert(ggml_is_scalar(dst));
+
     assert(src0->nb[0] == sizeof(ggml_fp16_t));
 
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
@@ -10506,13 +10416,12 @@ static void ggml_compute_forward_sum_bf16(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-    assert(ggml_is_scalar(dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    assert(ggml_is_scalar(dst));
+
     assert(src0->nb[0] == sizeof(ggml_bf16_t));
 
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
@@ -10568,9 +10477,7 @@ static void ggml_compute_forward_sum_rows_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -10623,9 +10530,7 @@ static void ggml_compute_forward_mean_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -10682,9 +10587,7 @@ static void ggml_compute_forward_argmax_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -10732,13 +10635,12 @@ static void ggml_compute_forward_repeat_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_can_repeat(src0, dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
     GGML_TENSOR_UNARY_OP_LOCALS
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
@@ -10777,13 +10679,12 @@ static void ggml_compute_forward_repeat_f16(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_can_repeat(src0, dst));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
     GGML_TENSOR_UNARY_OP_LOCALS
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
@@ -10852,13 +10753,12 @@ static void ggml_compute_forward_repeat_back_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_can_repeat(dst, src0));
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
     GGML_TENSOR_UNARY_OP_LOCALS
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
@@ -10932,10 +10832,6 @@ static void ggml_compute_forward_concat_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
@@ -11004,15 +10900,14 @@ static void ggml_compute_forward_abs_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11049,15 +10944,14 @@ static void ggml_compute_forward_sgn_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11094,15 +10988,14 @@ static void ggml_compute_forward_neg_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11139,15 +11032,14 @@ static void ggml_compute_forward_step_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11184,15 +11076,14 @@ static void ggml_compute_forward_tanh_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11229,15 +11120,14 @@ static void ggml_compute_forward_elu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11274,15 +11164,14 @@ static void ggml_compute_forward_relu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11319,15 +11208,14 @@ static void ggml_compute_forward_sigmoid_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11368,10 +11256,6 @@ static void ggml_compute_forward_gelu_f32(
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -11431,10 +11315,6 @@ static void ggml_compute_forward_gelu_quick_f32(
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -11494,10 +11374,6 @@ static void ggml_compute_forward_silu_f32(
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -11552,15 +11428,14 @@ static void ggml_compute_forward_leaky_relu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11610,10 +11485,6 @@ static void ggml_compute_forward_silu_back_f32(
     assert(ggml_are_same_shape(src0, dst));
     assert(ggml_are_same_shape(src0, grad));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -11669,15 +11540,14 @@ static void ggml_compute_forward_hardswish_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11711,15 +11581,14 @@ static void ggml_compute_forward_hardsigmoid_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -11759,10 +11628,6 @@ static void ggml_compute_forward_norm_f32(
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
@@ -11834,10 +11699,6 @@ static void ggml_compute_forward_rms_norm_f32(
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
@@ -11905,10 +11766,6 @@ static void ggml_compute_forward_rms_norm_back_f32(
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
@@ -12083,10 +11940,6 @@ static void ggml_compute_forward_group_norm_f32(
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
@@ -12191,8 +12044,8 @@ static void ggml_compute_forward_mul_mat_one_chunk(
 
     const bool src1_cont = ggml_is_contiguous(src1);
 
-    ggml_vec_dot_t    const vec_dot = type_traits[type].vec_dot;
-    enum ggml_type    const vec_dot_type = type_traits[type].vec_dot_type;
+    ggml_vec_dot_t const vec_dot      = type_traits[type].vec_dot;
+    enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
 
     // broadcast factors
     const int64_t r2 = ne12 / ne02;
@@ -12266,15 +12119,11 @@ static void ggml_compute_forward_mul_mat_one_chunk(
 
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
-              struct ggml_tensor * dst,
-              struct ggml_compute_state * state) {
+              struct ggml_tensor * dst) {
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
@@ -12301,16 +12150,14 @@ static void ggml_compute_forward_mul_mat(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    // broadcast factors
-    const int64_t r2 = ne12 / ne02;
-    const int64_t r3 = ne13 / ne03;
-    UNUSED(r2);
-    UNUSED(r3);
-
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
 #if GGML_USE_LLAMAFILE
+    // broadcast factors
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
     const bool src1_cont = ggml_is_contiguous(src1);
 
     if (src1_cont) {
@@ -12324,7 +12171,6 @@ static void ggml_compute_forward_mul_mat(
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
                                      ith, nth,
-                                     params->type,
                                      src0->type,
                                      src1->type,
                                      dst->type))
@@ -12334,36 +12180,34 @@ static void ggml_compute_forward_mul_mat(
 UseGgmlGemm1:;
 #endif
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
-        // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
-        atomic_store(&state->shared->current_chunk, nth);
-        if (src1->type != vec_dot_type) {
-            char * wdata = params->wdata;
-            const size_t row_size = ggml_row_size(vec_dot_type, ne10);
-
-            assert(params->wsize >= ne11*ne12*ne13*row_size);
-            GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-            for (int64_t i13 = 0; i13 < ne13; ++i13) {
-                for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                    for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                        from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-                        wdata += row_size;
-                    }
-                }
+    if (src1->type != vec_dot_type) {
+        char * wdata = params->wdata;
+
+        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
+        const size_t nbw2 = nbw1*ne11;
+        const size_t nbw3 = nbw2*ne12;
+
+        assert(params->wsize >= ne13*nbw3);
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+                    from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                                          (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+                                           ne10);
+                }
             }
         }
-
-        return;
     }
 
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
+    if (ith == 0) {
+        // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
+        atomic_store(&params->shared->current_chunk, nth);
     }
 
+    ggml_barrier(params->shared);
+
 #if GGML_USE_LLAMAFILE
     if (src1->type != vec_dot_type) {
         const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
@@ -12379,7 +12223,6 @@ UseGgmlGemm1:;
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
                                      ith, nth,
-                                     params->type,
                                      src0->type,
                                      vec_dot_type,
                                      dst->type))
@@ -12389,11 +12232,6 @@ UseGgmlGemm1:;
 UseGgmlGemm2:;
 #endif
 
-#ifdef GGML_PERF
-    int chunks_executed = 0;
-    UNUSED(chunks_executed);
-#endif
-
     // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
     const int64_t nr0 = ne0;
 
@@ -12435,9 +12273,6 @@ UseGgmlGemm2:;
     const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
     const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
 
-    //if (ith == 0)
-    //    printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d.  Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
-
     // The first chunk comes from our thread_id, the rest will get auto-assigned.
     int current_chunk = ith;
 
@@ -12453,23 +12288,12 @@ UseGgmlGemm2:;
 
         ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
 
-#ifdef GGML_PERF
-        chunks_executed++;
-#endif
-
         if (nth >= nchunk0 * nchunk1) {
             break;
         }
 
-        current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1);
+        current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
     }
-
-#ifdef GGML_PERF
-    // These numbers are useful when trying to measure how well the threading scheduling works.
-    //int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1;
-    //float time = (ggml_perf_time_us() - t0);
-    //printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed);
-#endif
 }
 
 // ggml_compute_forward_mul_mat_id
@@ -12521,32 +12345,33 @@ static void ggml_compute_forward_mul_mat_id(
     int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
     struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
 
-   if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
+    if (src1->type != vec_dot_type) {
         char * wdata = params->wdata;
-        if (src1->type != vec_dot_type) {
-            const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
-            assert(params->wsize >= ne11*ne12*ne13*row_size);
-            assert(src1->type == GGML_TYPE_F32);
+        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
+        const size_t nbw2 = nbw1*ne11;
+        const size_t nbw3 = nbw2*ne12;
 
-            for (int64_t i13 = 0; i13 < ne13; ++i13) {
-                for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                    for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                        from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-                        wdata += row_size;
-                    }
+        assert(params->wsize >= ne13*nbw3);
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+                    from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                                          (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+                                           ne10);
                 }
             }
         }
+    }
+
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
 
+    if (ith == 0) {
         // initialize matrix_row_counts
         memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
 
-#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
-
         // group rows by src0 matrix
         for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
             for (int id = 0; id < n_ids; ++id) {
@@ -12558,13 +12383,9 @@ static void ggml_compute_forward_mul_mat_id(
                 matrix_row_counts[i02] += 1;
             }
         }
-
-        return;
     }
 
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
+    ggml_barrier(params->shared);
 
     // compute each matrix multiplication in sequence
     for (int cur_a = 0; cur_a < n_as; ++cur_a) {
@@ -12663,9 +12484,6 @@ static void ggml_compute_forward_out_prod_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    // int64_t t0 = ggml_perf_time_us();
-    // UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
@@ -12690,17 +12508,10 @@ static void ggml_compute_forward_out_prod_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
+    if (ith == 0) {
         ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
     }
+    ggml_barrier(params->shared);
 
     // dst[:,:,:,:] = 0
     // for i2,i3:
@@ -12776,19 +12587,6 @@ static void ggml_compute_forward_out_prod_f32(
             }
         }
     }
-
-    //int64_t t1 = ggml_perf_time_us();
-    //static int64_t acc = 0;
-    //acc += t1 - t0;
-    //if (t1 - t0 > 10) {
-    //    printf("\n");
-    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
-    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
-    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
-    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
-
-    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
-    //}
 }
 
 static void ggml_compute_forward_out_prod_q_f32(
@@ -12798,9 +12596,6 @@ static void ggml_compute_forward_out_prod_q_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    // int64_t t0 = ggml_perf_time_us();
-    // UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int ith = params->ith;
@@ -12831,17 +12626,10 @@ static void ggml_compute_forward_out_prod_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
+    if (ith == 0) {
         ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
     }
+    ggml_barrier(params->shared);
 
     // parallelize by last three dimensions
 
@@ -12888,19 +12676,6 @@ static void ggml_compute_forward_out_prod_q_f32(
             ggml_vec_mad_f32(ne0, d, wdata, *s1);
         }
     }
-
-    //int64_t t1 = ggml_perf_time_us();
-    //static int64_t acc = 0;
-    //acc += t1 - t0;
-    //if (t1 - t0 > 10) {
-    //    printf("\n");
-    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
-    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
-    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
-    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
-
-    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
-    //}
 }
 
 static void ggml_compute_forward_out_prod(
@@ -12960,10 +12735,6 @@ static void ggml_compute_forward_scale_f32(
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // scale factor
     float v;
     memcpy(&v, dst->op_params, sizeof(float));
@@ -13032,20 +12803,16 @@ static void ggml_compute_forward_set_f32(
     size_t offset  = ((int32_t *) dst->op_params)[3];
     bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
 
-    if (!inplace && (params->type == GGML_TASK_TYPE_INIT)) {
-        if (params->ith != 0) {
-            return;
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
         }
-        // memcpy needs to be synchronized across threads to avoid race conditions.
-        // => do it in INIT phase
-        memcpy(
-            ((char *)  dst->data),
-            ((char *) src0->data),
-            ggml_nbytes(dst));
-    }
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
+        ggml_barrier(params->shared);
     }
 
     const int ith = params->ith;
@@ -13194,10 +12961,6 @@ static void ggml_compute_forward_get_rows_q(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t nc = ne00;
@@ -13242,10 +13005,6 @@ static void ggml_compute_forward_get_rows_f16(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t nc = ne00;
@@ -13287,10 +13046,6 @@ static void ggml_compute_forward_get_rows_bf16(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t nc = ne00;
@@ -13332,10 +13087,6 @@ static void ggml_compute_forward_get_rows_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t nc = ne00;
@@ -13447,21 +13198,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     GGML_ASSERT(ggml_is_contiguous(dst));
 
     // ggml_compute_forward_dup_same_cont(params, opt0, dst);
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (params->ith != 0) {
-            return;
-        }
-        memset(dst->data, 0, ggml_nbytes(dst));
-    }
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
+    memset(dst->data, 0, ggml_nbytes(dst));
 
     const int nc = src0->ne[0];
     const int nr = ggml_nelements(src1);
@@ -13486,21 +13231,15 @@ static void ggml_compute_forward_get_rows_back_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     GGML_ASSERT(ggml_is_contiguous(dst));
 
     // ggml_compute_forward_dup_same_cont(params, opt0, dst);
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (params->ith != 0) {
-            return;
-        }
-        memset(dst->data, 0, ggml_nbytes(dst));
-    }
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
+    memset(dst->data, 0, ggml_nbytes(dst));
 
     const int nc = src0->ne[0];
     const int nr = ggml_nelements(src1);
@@ -13566,9 +13305,7 @@ static void ggml_compute_forward_diag_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -13637,22 +13374,18 @@ static void ggml_compute_forward_diag_mask_f32(
 
     GGML_ASSERT(n_past >= 0);
 
-    if (!inplace && (params->type == GGML_TASK_TYPE_INIT)) {
-        if (ith != 0) {
-            return;
+    if (!inplace) {
+        if (ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+            GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
         }
-        // memcpy needs to be synchronized across threads to avoid race conditions.
-        // => do it in INIT phase
-        GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-        GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-        memcpy(
-            ((char *)  dst->data),
-            ((char *) src0->data),
-            ggml_nbytes(dst));
-    }
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
+        ggml_barrier(params->shared);
     }
 
     // TODO: handle transposed/permuted matrices
@@ -13724,10 +13457,6 @@ static void ggml_compute_forward_soft_max_f32(
     assert(ggml_is_contiguous(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
@@ -13849,10 +13578,6 @@ static void ggml_compute_forward_soft_max_back_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src1, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
@@ -13941,9 +13666,7 @@ static void ggml_compute_forward_clamp_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -14090,10 +13813,6 @@ static void ggml_compute_forward_rope_f32(
     const struct ggml_tensor * src1 = dst->src[1];
     const struct ggml_tensor * src2 = dst->src[2];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
     //const int n_past     = ((int32_t *) dst->op_params)[0];
@@ -14220,10 +13939,6 @@ static void ggml_compute_forward_rope_f16(
     const struct ggml_tensor * src1 = dst->src[1];
     const struct ggml_tensor * src2 = dst->src[2];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
     //const int n_past     = ((int32_t *) dst->op_params)[0];
@@ -14398,9 +14113,6 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
@@ -14411,10 +14123,7 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
+    if (ith == 0) {
         memset(params->wdata, 0, params->wsize);
 
         // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
@@ -14447,13 +14156,8 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
 
         // need to zero dst since we are accumulating into it
         memset(dst->data, 0, ggml_nbytes(dst));
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
     }
+    ggml_barrier(params->shared);
 
     const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
 
@@ -14497,9 +14201,6 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
@@ -14510,10 +14211,7 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
+    if (ith == 0) {
         memset(params->wdata, 0, params->wsize);
 
         // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
@@ -14546,13 +14244,8 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
 
         // need to zero dst since we are accumulating into it
         memset(dst->data, 0, ggml_nbytes(dst));
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
     }
+    ggml_barrier(params->shared);
 
     const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
 
@@ -14621,9 +14314,6 @@ static void ggml_compute_forward_im2col_f32(
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
@@ -14654,14 +14344,6 @@ static void ggml_compute_forward_im2col_f32(
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
     {
         float * const wdata = (float *) dst->data;
@@ -14709,9 +14391,6 @@ static void ggml_compute_forward_im2col_f16(
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F16);
 
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
@@ -14742,14 +14421,6 @@ static void ggml_compute_forward_im2col_f16(
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
     {
         ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
@@ -14815,9 +14486,6 @@ static void ggml_compute_forward_conv_transpose_2d(
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
@@ -14828,10 +14496,7 @@ static void ggml_compute_forward_conv_transpose_2d(
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith != 0) {
-            return;
-        }
+    if (ith == 0) {
         memset(params->wdata, 0, params->wsize);
 
         // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
@@ -14866,13 +14531,8 @@ static void ggml_compute_forward_conv_transpose_2d(
         }
 
         memset(dst->data, 0, ggml_nbytes(dst));
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
     }
+    ggml_barrier(params->shared);
 
     const int32_t stride = ggml_get_op_params_i32(dst, 0);
 
@@ -14920,9 +14580,8 @@ static void ggml_compute_forward_pool_1d_sk_p0(
     const struct ggml_tensor * src = dst->src[0];
 
     assert(src->type == GGML_TYPE_F32);
-    assert(params->ith == 0);
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -14989,9 +14648,8 @@ static void ggml_compute_forward_pool_2d(
     const struct ggml_tensor * src = dst->src[0];
 
     GGML_ASSERT(src->type == GGML_TYPE_F32);
-    GGML_ASSERT(params->ith == 0);
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -15064,10 +14722,6 @@ static void ggml_compute_forward_upscale_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
     const int ith = params->ith;
@@ -15128,10 +14782,6 @@ static void ggml_compute_forward_pad_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT( dst->nb[0] == sizeof(float));
 
@@ -15188,10 +14838,6 @@ static void ggml_compute_forward_arange_f32(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst) {
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_ASSERT(dst->nb[0] == sizeof(float));
 
     const int ith = params->ith;
@@ -15230,10 +14876,6 @@ static void ggml_compute_forward_timestep_embedding_f32(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst) {
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const struct ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -15289,10 +14931,6 @@ static void ggml_compute_forward_argsort_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_UNARY_OP_LOCALS
 
     GGML_ASSERT(nb0 == sizeof(float));
@@ -15353,8 +14991,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
 
     GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
     GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
@@ -15399,14 +15035,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int64_t rv2 = neq2/nev2;
     const int64_t rv3 = neq3/nev3;
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // parallelize by q rows using ggml_vec_dot_f32
 
     // total rows in q
@@ -15589,9 +15217,6 @@ static void ggml_compute_forward_flash_attn_back_f32(
     const struct ggml_tensor * v = dst->src[2];
     const struct ggml_tensor * d = dst->src[3];
 
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
     GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
     GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
@@ -15638,16 +15263,10 @@ static void ggml_compute_forward_flash_attn_back_f32(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith == 0) {
-            memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
-        }
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
+    if (ith == 0) {
+        memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
     }
+    ggml_barrier(params->shared);
 
     const int64_t elem_q = ggml_nelements(q);
     const int64_t elem_k = ggml_nelements(k);
@@ -15927,10 +15546,6 @@ static void ggml_compute_forward_flash_attn_back(
 static void ggml_compute_forward_ssm_conv_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const struct ggml_tensor * src0 = dst->src[0]; // conv_state
     const struct ggml_tensor * src1 = dst->src[1]; // x
     const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
@@ -16053,10 +15668,6 @@ static void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const struct ggml_tensor * src0 = dst->src[0]; // s
     const struct ggml_tensor * src1 = dst->src[1]; // x
     const struct ggml_tensor * src2 = dst->src[2]; // dt
@@ -16178,13 +15789,10 @@ static void ggml_compute_forward_ssm_scan(
 static void ggml_compute_forward_win_part_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
+    UNUSED(params);
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
     GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 
@@ -16244,13 +15852,10 @@ static void ggml_compute_forward_win_part(
 static void ggml_compute_forward_win_unpart_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
+    UNUSED(params);
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
     GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 
@@ -16376,13 +15981,10 @@ static void ggml_compute_forward_unary(
 static void ggml_compute_forward_get_rel_pos_f16(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
+    UNUSED(params);
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
 
     GGML_TENSOR_UNARY_OP_LOCALS
@@ -16432,20 +16034,12 @@ static void ggml_compute_forward_add_rel_pos_f32(
     const struct ggml_tensor * src2 = dst->src[2];
 
     const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
-    if (!inplace && params->type == GGML_TASK_TYPE_INIT) {
-        if (params->ith != 0) {
-            return;
+    if (!inplace) {
+        if (params->ith == 0) {
+            memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
         }
-        memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
-        return;
-    }
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
+        ggml_barrier(params->shared);
     }
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
     // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
 
     float * src1_data = (float *) src1->data;
@@ -16519,15 +16113,14 @@ static void ggml_compute_forward_map_unary_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -16567,16 +16160,15 @@ static void ggml_compute_forward_map_binary_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    assert(params->ith == 0);
+    if (params->ith != 0) {
+        return;
+    }
+
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(src1));
     assert(ggml_is_contiguous_1(dst));
     assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
@@ -16616,9 +16208,7 @@ static void ggml_compute_forward_map_custom1_f32(
 
     const struct ggml_tensor * a = dst->src[0];
 
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -16635,9 +16225,7 @@ static void ggml_compute_forward_map_custom2_f32(
     const struct ggml_tensor * a = dst->src[0];
     const struct ggml_tensor * b = dst->src[1];
 
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -16655,9 +16243,7 @@ static void ggml_compute_forward_map_custom3_f32(
     const struct ggml_tensor * b = dst->src[1];
     const struct ggml_tensor * c = dst->src[1];
 
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+    if (params->ith != 0) {
         return;
     }
 
@@ -16672,10 +16258,6 @@ static void ggml_compute_forward_map_custom1(
 
     const struct ggml_tensor * a = dst->src[0];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     struct ggml_map_custom1_op_params p;
     memcpy(&p, dst->op_params, sizeof(p));
 
@@ -16691,10 +16273,6 @@ static void ggml_compute_forward_map_custom2(
     const struct ggml_tensor * a = dst->src[0];
     const struct ggml_tensor * b = dst->src[1];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     struct ggml_map_custom2_op_params p;
     memcpy(&p, dst->op_params, sizeof(p));
 
@@ -16711,10 +16289,6 @@ static void ggml_compute_forward_map_custom3(
     const struct ggml_tensor * b = dst->src[1];
     const struct ggml_tensor * c = dst->src[2];
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     struct ggml_map_custom3_op_params p;
     memcpy(&p, dst->op_params, sizeof(p));
 
@@ -16746,21 +16320,10 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
 
     GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
 
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        if (ith == 0) {
-            memset(sums, 0, sizeof(float) * (nth + nth * nc));
-        }
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        if (ith == 0) {
-            float * dp = (float *) dst->data;
-            ggml_vec_sum_f32(nth, dp, sums);
-            dp[0] *= -1.0f / (float) nr;
-        }
-        return;
+    if (ith == 0) {
+        memset(sums, 0, sizeof(float) * (nth + nth * nc));
     }
+    ggml_barrier(params->shared);
 
     const double eps = 1e-9;
 
@@ -16808,7 +16371,13 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
         }
 #endif
     }
+    ggml_barrier(params->shared);
 
+    if (ith == 0) {
+        float * dp = (float *) dst->data;
+        ggml_vec_sum_f32(nth, dp, sums);
+        dp[0] *= -1.0f / (float) nr;
+    }
 }
 
 static void ggml_compute_forward_cross_entropy_loss(
@@ -16848,10 +16417,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ith = params->ith;
     const int64_t nth = params->nth;
 
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
     const double eps = 1e-9;
 
     // TODO: handle transposed/permuted matrices
@@ -16922,7 +16487,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
 
 /////////////////////////////////
 
-static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_compute_state * state) {
+static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
     GGML_ASSERT(params);
 
     if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
@@ -17020,7 +16585,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_compute_forward_mul_mat(params, tensor, state);
+                ggml_compute_forward_mul_mat(params, tensor);
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
@@ -18498,9 +18063,6 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
         /*.leafs        =*/ leafs_ptr,
         /*.hash_table   =*/ { hash_size, hash_keys_ptr },
         /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
-        /*.perf_runs    =*/ 0,
-        /*.perf_cycles  =*/ 0,
-        /*.perf_time_us =*/ 0,
     };
 
     return cgraph;
@@ -18520,9 +18082,6 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
         /*.leafs        =*/ NULL,
         /*.hash_table   =*/ { 0, NULL },
         /*.order        =*/ cgraph0->order,
-        /*.perf_runs    =*/ 0,
-        /*.perf_cycles  =*/ 0,
-        /*.perf_time_us =*/ 0,
     };
 
     return cgraph;
@@ -18716,16 +18275,7 @@ static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n);  }
 static void clear_numa_thread_affinity(void) {}
 #endif
 
-static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) {
-    int64_t cycles_cur  = ggml_perf_cycles()  - st->perf_node_start_cycles;
-    int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us;
-
-    node->perf_runs++;
-    node->perf_cycles  += cycles_cur;
-    node->perf_time_us += time_us_cur;
-}
-
-static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_threads) {
+static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
     int n_tasks = 0;
 
     if (ggml_is_empty(node)) {
@@ -18768,8 +18318,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_SIGMOID:
-                case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
-                case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
+                case GGML_UNARY_OP_HARDSWISH:
+                case GGML_UNARY_OP_HARDSIGMOID:
                     {
                         n_tasks = 1;
                     } break;
@@ -18792,33 +18342,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_CONCAT:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_MUL_MAT:
-            {
-                n_tasks = n_threads;
-
-                // TODO: use different scheduling for different matrix sizes
-                //const int nr0 = ggml_nrows(node->src[0]);
-                //const int nr1 = ggml_nrows(node->src[1]);
-
-                //n_tasks = MIN(n_threads, MAX(1, nr0/128));
-                //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
-            } break;
         case GGML_OP_MUL_MAT_ID:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_OUT_PROD:
             {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_GET_ROWS:
             {
-                // FIXME: the cost of launching additional threads decreases performance with GPU offloading
-                //n_tasks = MIN(n_threads, ggml_nelements(node->src[1]));
-                n_tasks = MIN(n_cur_threads, ggml_nelements(node->src[1]));
+                // FIXME: get_rows can use additional threads, but the cost of launching additional threads
+                // decreases performance with GPU offloading
+                //n_tasks = n_threads;
+                n_tasks = 1;
             } break;
         case GGML_OP_SCALE:
         case GGML_OP_SET:
@@ -18848,14 +18383,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
             {
                 n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
             } break;
-        case GGML_OP_CONV_TRANSPOSE_1D:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_IM2COL:
-            {
-                n_tasks = n_threads;
-            } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
                 n_tasks = n_threads;
@@ -18866,33 +18395,12 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
                 n_tasks = 1;
             } break;
         case GGML_OP_UPSCALE:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_PAD:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_ARANGE:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_TIMESTEP_EMBEDDING:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_ARGSORT:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_FLASH_ATTN_EXT:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
             {
@@ -18940,9 +18448,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
                 }
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             {
                 n_tasks = n_threads;
@@ -18972,190 +18477,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
     return n_tasks;
 }
 
-static void ggml_graph_compute_thread_sync_node(int * node_n, struct ggml_compute_state * state, const bool do_yield) {
-    // wait for other threads to finish
-    const int last_node_n = * node_n;
-
-    while (true) {
-        if (do_yield) {
-            sched_yield();
-        }
-
-        *node_n = atomic_load(&state->shared->node_n);
-        if (*node_n != last_node_n) {
-            break;
-        }
-
-#if defined(__SSE3__)
-        // Tell the processor we're spinning.  It's a processor hint for spinlocks.
-        _mm_pause();
-#endif
-    }
-}
-
-static void ggml_graph_compute_thread_sync_task(int * task_phase, struct ggml_compute_state * state, const bool do_yield) {
-    // wait for other threads to finish
-    const int last_task_phase = *task_phase;
-
-    while (true) {
-        if (do_yield) {
-            sched_yield();
-        }
-
-        *task_phase = atomic_load(&state->shared->node_task);
-        if (*task_phase != last_task_phase) {
-            break;
-        }
-
-#if defined(__SSE3__)
-        // Tell the processor we're spinning.  It's a processor hint for spinlocks.
-        _mm_pause();
-#endif
-    }
-}
-
-static thread_ret_t ggml_graph_compute_thread(void * data) {
-    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
-
-    const struct ggml_cgraph * cgraph = state->shared->cgraph;
-    const struct ggml_cplan  * cplan  = state->shared->cplan;
-
-    const int   n_threads   = state->shared->n_threads;
-
-    set_numa_thread_affinity(state->ith);
-
-    int node_n     = -1;
-    int task_phase = GGML_TASK_TYPE_FINALIZE;
-
-    while (true) {
-        if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
-            state->shared->node_n += 1;
-            state->ec = GGML_STATUS_ABORTED;
-            return 0;
-        }
-
-        if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
-            // all other threads are finished and spinning
-            // do finalize and init here so we don't have synchronize again
-            struct ggml_compute_params params = {
-                /*.type  =*/ GGML_TASK_TYPE_FINALIZE,
-                /*.ith   =*/ 0,
-                /*.nth   =*/ 0,
-                /*.wsize =*/ cplan->work_size,
-                /*.wdata =*/ cplan->work_data,
-            };
-
-            if (node_n != -1) {
-                /* FINALIZE */
-                struct ggml_tensor * node = cgraph->nodes[node_n];
-                if (GGML_OP_HAS_FINALIZE[node->op]) {
-                    params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
-                    ggml_compute_forward(&params, node, state);
-                }
-                ggml_graph_compute_perf_stats_node(node, state->shared);
-            }
-
-            // distribute new work or execute it direct if 1T
-            while (++node_n < cgraph->n_nodes) {
-                GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
-                struct ggml_tensor * node = cgraph->nodes[node_n];
-                const int n_tasks = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
-
-                state->shared->perf_node_start_cycles  = ggml_perf_cycles();
-                state->shared->perf_node_start_time_us = ggml_perf_time_us();
-
-                params.nth = n_tasks;
-
-                if (n_tasks == 1) {
-                    /* INIT */
-                    if (GGML_OP_HAS_INIT[node->op]) {
-                        params.type = GGML_TASK_TYPE_INIT;
-                        ggml_compute_forward(&params, node, state);
-                    }
-
-                    // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
-                    // they do something more efficient than spinning (?)
-                    params.type = GGML_TASK_TYPE_COMPUTE;
-                    ggml_compute_forward(&params, node, state);
-
-                    if (GGML_OP_HAS_FINALIZE[node->op]) {
-                        params.type = GGML_TASK_TYPE_FINALIZE;
-                        ggml_compute_forward(&params, node, state);
-                    }
-
-                    ggml_graph_compute_perf_stats_node(node, state->shared);
-                } else {
-                    break;
-                }
-
-                if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
-                    break;
-                }
-            }
-
-            task_phase = GGML_TASK_TYPE_INIT;
-            atomic_store(&state->shared->n_active,  n_threads);
-            atomic_store(&state->shared->node_n,    node_n);
-            atomic_store(&state->shared->node_task, task_phase);
-        } else {
-            ggml_graph_compute_thread_sync_node(&node_n,     state, false);
-            ggml_graph_compute_thread_sync_task(&task_phase, state, false);
-        }
-
-        // check if we should stop
-        if (node_n >= cgraph->n_nodes) break;
-
-        /* INIT & COMPUTE */
-        struct ggml_tensor * node = cgraph->nodes[node_n];
-        const int n_tasks = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
-
-        struct ggml_compute_params params = {
-            /*.type  =*/ GGML_TASK_TYPE_INIT,
-            /*.ith   =*/ state->ith,
-            /*.nth   =*/ n_tasks,
-            /*.wsize =*/ cplan->work_size,
-            /*.wdata =*/ cplan->work_data,
-        };
-
-        if (state->ith < n_tasks) {
-            if (GGML_OP_HAS_INIT[node->op]) {
-                ggml_compute_forward(&params, node, state);
-            }
-        }
-
-        if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
-            task_phase = GGML_TASK_TYPE_COMPUTE;
-            atomic_store(&state->shared->n_active,  n_threads);
-            atomic_store(&state->shared->node_task, task_phase);
-        }
-        else {
-            // TODO: this sched_yield can have significant impact on the performance - either positive or negative
-            //       depending on the workload and the operating system.
-            //       since it is not clear what is the best approach, it should potentially become user-configurable
-            //       ref: https://github.com/ggerganov/ggml/issues/291
-            // UPD:  adding the do_yield flag seems to resolve the issue universally
-            const bool do_yield = node_n < 0 || cgraph->nodes[node_n]->op == GGML_OP_MUL_MAT;
-            ggml_graph_compute_thread_sync_task(&task_phase, state, do_yield);
-        }
-
-        if (state->ith < n_tasks) {
-            params.type = GGML_TASK_TYPE_COMPUTE;
-            ggml_compute_forward(&params, node, state);
-        }
-
-        if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
-            task_phase = GGML_TASK_TYPE_FINALIZE;
-            atomic_store(&state->shared->n_active,  n_threads);
-            atomic_store(&state->shared->node_task, task_phase);
-        }
-        else {
-            ggml_graph_compute_thread_sync_task(&task_phase, state, false);
-        }
-    }
-
-    return 0;
-}
-
 struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
     if (n_threads <= 0) {
         n_threads = GGML_DEFAULT_N_THREADS;
@@ -19172,7 +18493,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
 
-        const int n_tasks = ggml_get_n_tasks(node, n_threads, 1);
+        const int n_tasks = ggml_get_n_tasks(node, n_threads);
 
         max_tasks = MAX(max_tasks, n_tasks);
 
@@ -19324,121 +18645,121 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
     return cplan;
 }
 
-static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state * workers, int n_threads) {
-    enum ggml_status compute_status = GGML_STATUS_SUCCESS;
+static thread_ret_t ggml_graph_compute_thread(void * data) {
+    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
 
-#ifdef GGML_USE_OPENMP
-    if (n_threads > 1) {
-        #pragma omp parallel num_threads(n_threads)
-        {
-            #pragma omp single
-            {
-                // update the number of threads from the actual number of threads that we got from OpenMP
-                n_threads = omp_get_num_threads();
-                workers[0].shared->n_threads = n_threads;
-                workers[0].shared->n_active  = n_threads;
-            }
-            ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
-        }
-    } else {
-        ggml_graph_compute_thread(&workers[0]);
-    }
-#else
-    // create thread pool
-    if (n_threads > 1) {
-        for (int j = 1; j < n_threads; ++j) {
-            const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
-            GGML_ASSERT(rc == 0);
-            UNUSED(rc);
-        }
-    }
+    const struct ggml_cgraph * cgraph = state->shared->cgraph;
+    const struct ggml_cplan  * cplan  = state->shared->cplan;
 
-    // this is a work thread too
-    ggml_graph_compute_thread(&workers[0]);
+    set_numa_thread_affinity(state->ith);
 
-    // join or kill thread pool
-    if (n_threads > 1) {
-        for (int j = 1; j < n_threads; j++) {
-            const int rc = ggml_thread_join(workers[j].thrd, NULL);
-            GGML_ASSERT(rc == 0);
-            UNUSED(rc);
+    struct ggml_compute_params params = {
+        /*.ith   =*/ state->ith,
+        /*.nth   =*/ state->shared->n_threads,
+        /*.wsize =*/ cplan->work_size,
+        /*.wdata =*/ cplan->work_data,
+        /*.shared=*/ state->shared,
+    };
+
+    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
+        struct ggml_tensor * node = cgraph->nodes[node_n];
+
+        ggml_compute_forward(&params, node);
+
+        if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
+            state->shared->ec = GGML_STATUS_ABORTED;
         }
-    }
-#endif
-    // don't leave affinity set on the main thread
-    clear_numa_thread_affinity();
 
-    for (int j = 0; j < n_threads; j++) {
-        if (workers[j].ec != GGML_STATUS_SUCCESS) {
-            compute_status = workers[j].ec;
+        ggml_barrier(state->shared);
+
+        if (state->shared->ec != GGML_STATUS_SUCCESS) {
             break;
         }
     }
-    return compute_status;
+
+    return 0;
 }
 
 enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
-    {
-        GGML_ASSERT(cplan);
-        GGML_ASSERT(cplan->n_threads > 0);
-
-        if (cplan->work_size > 0) {
-            GGML_ASSERT(cplan->work_data);
-        }
-    }
+    GGML_ASSERT(cplan);
+    GGML_ASSERT(cplan->n_threads > 0);
+    GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);
 
     int n_threads = cplan->n_threads;
 
-#if defined(GGML_USE_OPENMP)
-    n_threads = MIN(n_threads, omp_get_max_threads());
-#endif
-
     struct ggml_compute_state_shared state_shared = {
         /*.cgraph                  =*/ cgraph,
         /*.cgraph_plan             =*/ cplan,
-        /*.perf_node_start_cycles  =*/ 0,
-        /*.perf_node_start_time_us =*/ 0,
         /*.n_threads               =*/ n_threads,
-        /*.n_active                =*/ n_threads,
-        /*.node_n                  =*/ -1,
-        /*.node_task               =*/ GGML_TASK_TYPE_FINALIZE,
+        /*.n_barrier               =*/ 0,
+        /*.n_barrier_passed        =*/ 0,
         /*.abort_callback          =*/ NULL,
         /*.abort_callback_data     =*/ NULL,
-        /*.current_chunk;          =*/ 0,
+        /*.current_chunk           =*/ 0,
+        /*.ec                      =*/ GGML_STATUS_SUCCESS,
     };
+
+#ifdef GGML_USE_OPENMP
+    if (n_threads > 1) {
+        #pragma omp parallel num_threads(n_threads)
+        {
+            #pragma omp single
+            {
+                // update the number of threads from the actual number of threads that we got from OpenMP
+                n_threads = omp_get_num_threads();
+                state_shared.n_threads = n_threads;
+            }
+
+            struct ggml_compute_state worker = {
+                .thrd   = 0,
+                .ith    = omp_get_thread_num(),
+                .shared = &state_shared,
+            };
+            ggml_graph_compute_thread(&worker);
+        }
+    } else {
+        struct ggml_compute_state worker = {
+            .thrd   = 0,
+            .ith    = 0,
+            .shared = &state_shared,
+        };
+        ggml_graph_compute_thread(&worker);
+    }
+#else
     struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
-    const int64_t perf_start_cycles  = ggml_perf_cycles();
-    const int64_t perf_start_time_us = ggml_perf_time_us();
 
     for (int j = 0; j < n_threads; ++j) {
         workers[j] = (struct ggml_compute_state) {
             .thrd   = 0,
             .ith    = j,
             .shared = &state_shared,
-            .ec     = GGML_STATUS_SUCCESS,
         };
     }
 
-    enum ggml_status compute_status = ggml_graph_compute_parallel(workers, n_threads);
-
-    // performance stats (graph)
-    {
-        int64_t perf_cycles_cur  = ggml_perf_cycles()  - perf_start_cycles;
-        int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us;
+    // create thread pool
+    for (int j = 1; j < n_threads; ++j) {
+        const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+        GGML_ASSERT(rc == 0);
+        UNUSED(rc);
+    }
 
-        cgraph->perf_runs++;
-        cgraph->perf_cycles  += perf_cycles_cur;
-        cgraph->perf_time_us += perf_time_us_cur;
+    // this is a work thread too
+    ggml_graph_compute_thread(&workers[0]);
 
-        GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n",
-                __func__, cgraph->perf_runs,
-                (double) perf_cycles_cur      / (double) ggml_cycles_per_ms(),
-                (double) cgraph->perf_cycles  / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs,
-                (double) perf_time_us_cur     / 1000.0,
-                (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
+    // join or kill thread pool
+    if (n_threads > 1) {
+        for (int j = 1; j < n_threads; j++) {
+            const int rc = ggml_thread_join(workers[j].thrd, NULL);
+            GGML_ASSERT(rc == 0);
+            UNUSED(rc);
+        }
     }
+#endif
+
+    // don't leave affinity set on the main thread
+    clear_numa_thread_affinity();
 
-    return compute_status;
+    return state_shared.ec;
 }
 
 enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
@@ -19937,24 +19258,16 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context *
 }
 
 void ggml_graph_print(const struct ggml_cgraph * cgraph) {
-    int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0};
-
     GGML_PRINT("=== GRAPH ===\n");
 
     GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
 
-        perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);
-
-        GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
+        GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n",
                 i,
                 node->ne[0], node->ne[1], node->ne[2],
-                ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ", node->perf_runs,
-                (double) node->perf_cycles  / (double) ggml_cycles_per_ms(),
-                (double) node->perf_cycles  / (double) ggml_cycles_per_ms() / (double) node->perf_runs,
-                (double) node->perf_time_us / 1000.0,
-                (double) node->perf_time_us / 1000.0 / node->perf_runs);
+                ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ");
     }
 
     GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs);
@@ -19968,14 +19281,6 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
                 ggml_get_name(node));
     }
 
-    for (int i = 0; i < GGML_OP_COUNT; i++) {
-        if (perf_total_per_op_us[i] == 0) {
-            continue;
-        }
-
-        GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", ggml_op_name(i), (double) perf_total_per_op_us[i] / 1000.0);
-    }
-
     GGML_PRINT("========================================\n");
 }
 
index 4ddb6664be31de3436db31f8dc390179197143fd..a64f554e586baea42a7a7a2b4dfdb06b7f453c53 100644 (file)
@@ -1,3 +1,5 @@
+find_library(MATH_LIBRARY m)
+
 # check systems
 if (NOT UNAME_S)
     execute_process(COMMAND uname -s OUTPUT_VARIABLE UNAME_S)
@@ -253,6 +255,10 @@ target_link_libraries(${TEST_TARGET} PRIVATE ggml)
 add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
 set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
 
+if (MATH_LIBRARY)
+    target_link_libraries(test-mul-mat2 PRIVATE ${MATH_LIBRARY})
+endif()
+
 #
 # test0
 
index 1d490148365bcfc81fc4d56f793e7062b2277737..db63b6a8fb33bef51053518730e5de6eb63f8c97 100644 (file)
@@ -1,6 +1,6 @@
-#include "ggml/ggml.h"
-#include "ggml/ggml-alloc.h"
-#include "ggml/ggml-backend.h"
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
 
 #ifdef GGML_USE_CUBLAS
 #include "ggml-cuda.h"
index 71181d92ac0f0310605f60f355c4f7ce8246f08e..e8b62f28250f094de6a691fe0a0c31e9373444a9 100644 (file)
@@ -1,11 +1,10 @@
-#include <cstring>
 #include <ggml.h>
 #include <ggml-alloc.h>
 #include <ggml-backend.h>
-#include <ggml-backend-impl.h>
-#include <stdio.h>
-#include <stdlib.h>
 
+#include <cstring>
+#include <cstdio>
+#include <cstdlib>
 
 static bool is_pow2(size_t x) {
     return (x & (x - 1)) == 0;
index 2b48e623e34764611bad5bde8c58d59b04ef8821..f74c0db475e2e2512fe1e41bab961fde4187a82c 100644 (file)
@@ -1,7 +1,6 @@
 #include <ggml.h>
 #include <ggml-alloc.h>
 #include <ggml-backend.h>
-#include <ggml-backend-impl.h>
 
 #include <algorithm>
 #include <array>
@@ -785,6 +784,10 @@ struct test_cpy : public test_case {
         return VARS_TO_STR3(type_src, type_dst, ne);
     }
 
+    double max_nmse_err() override {
+        return 1e-6;
+    }
+
     size_t op_size(ggml_tensor * t) override {
         return ggml_nbytes(t) + ggml_nbytes(t->src[0]);
     }
@@ -1063,6 +1066,33 @@ struct test_sqr : public test_case {
     }
 };
 
+// GGML_OP_SQRT
+struct test_sqrt : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sqrt(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_sqrt(ctx, a);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        // fill with positive values
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, 0.0f, 100.0f);
+        }
+    }
+};
+
 // GGML_OP_CLAMP
 struct test_clamp : public test_case {
     const ggml_type type;
@@ -2200,6 +2230,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }
 
     test_cases.emplace_back(new test_sqr());
+    test_cases.emplace_back(new test_sqrt());
     test_cases.emplace_back(new test_clamp());
 
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10,  1,  1}, 5));
index 116266e9fd547f57bc7f03cb7281830ccdfbfbb6..da930d3a56d9428db7f47ba332780d65025ff772 100644 (file)
@@ -1,4 +1,4 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include <string.h>
 #include <stdio.h>
index b6daa7b6963fa333929950b5da278111e5b6c570..f2ea01b1ceb24ce6f90a570ee5cec869b690264b 100644 (file)
@@ -1,6 +1,6 @@
 #include "ggml.h"
-#include "ggml/ggml-alloc.h"
-#include "ggml/ggml-backend.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
 
 // #define GGML_USE_CUBLAS
 
index 35d565444adddc245cd29bf3c752f01578a7f036..98005cd2d663ef4b46bfe51af53b406101452572 100644 (file)
@@ -1,6 +1,6 @@
 #include "ggml.h"
-#include "ggml/ggml-alloc.h"
-#include "ggml/ggml-backend.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
 
 // #define GGML_USE_CUBLAS
 
index e96aa67c9bf6629336369998ead380fd7d05cd93..c03841fb316d09e1b1659a6bcfd790017f61c6d6 100644 (file)
@@ -1,4 +1,5 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
+
 #include <string.h>
 #include <stdio.h>
 #include <stdlib.h>
index afc887b7b143c3464a285fb860404fd0de6630f5..1ff133fabad1c083180bb41d74b49dafdbf65e5a 100644 (file)
@@ -1,4 +1,4 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include <stdio.h>
 #include <stdlib.h>
index 07a6ffeba7bcf56f1b9c5e4e77d4021ace5d0c23..bf194ce1566a5be079a0e46c79da7f46f5f23396 100644 (file)
@@ -1,6 +1,6 @@
 #include "ggml.h"
-#include "ggml/ggml-alloc.h"
-#include "ggml/ggml-backend.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
 
 //#define GGML_USE_CUBLAS // uncomment this to use cuda backend, make sure build ggml lib with GGML_CUBLAS=ON
 
index ee52b7a6cd3b7799f15d11f5286b82f60ba2d470..ebd004b0473baa2b16813ece95b8f4a22bb8224d 100644 (file)
@@ -1,5 +1,5 @@
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include <math.h>
 #include <stdio.h>
index e4626ec63dd963ec8f1d6a532957909fc70fe3fa..d1252927be07e9f45a9742e9c21603d20c394e69 100644 (file)
@@ -1,4 +1,4 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include <string.h>
 #include <stdio.h>
index 47c8438555a82b021bddb40f55515460c3f7df5b..e525e5ca1e6fc398018ea9ab8bc1c0260309c7bf 100644 (file)
@@ -1,4 +1,4 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include <string.h>
 #include <stdio.h>
index 29ddd0bd4d3280bb0f4e4ff89e538d20a7f5b185..58bcd2f0d7a644ec6d9895fdcd02ea318befe140 100644 (file)
@@ -1,6 +1,6 @@
-#include "ggml/ggml.h"
-#include "ggml/ggml-alloc.h"
-#include "ggml/ggml-backend.h"
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
 
 #ifdef GGML_USE_CUBLAS
 #include "ggml-cuda.h"
index 2d2fa85bb0023a6c2d80180f4fee8cc7dc36d134..09154b46919859af2fa2c9a356ca81dd1b066e4c 100644 (file)
@@ -1,4 +1,4 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include <stdio.h>
 #include <stdlib.h>
index 26994d8952df3bd49527717a5daf2443d5b18dd4..5e6259421f6f056089e1a39194c208ba09b537c1 100644 (file)
@@ -1,41 +1,41 @@
-const std = @import("std");\r
-const c = @cImport({\r
-    @cInclude("ggml/ggml.h");\r
-});\r
-\r
-pub fn main() !void {\r
-    const params = .{\r
-        .mem_size = 128 * 1024 * 1024,\r
-        .mem_buffer = null,\r
-        .no_alloc = false,\r
-    };\r
-\r
-    const ctx0 = c.ggml_init(params);\r
-    defer c.ggml_free(ctx0);\r
-\r
-    const t1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 10);\r
-    const t2 = c.ggml_new_tensor_2d(ctx0, c.GGML_TYPE_I16, 10, 20);\r
-    const t3 = c.ggml_new_tensor_3d(ctx0, c.GGML_TYPE_I32, 10, 20, 30);\r
-\r
-    try std.testing.expect(c.ggml_n_dims(t1) == 1);\r
-    try std.testing.expect(t1.*.ne[0] == 10);\r
-    try std.testing.expect(t1.*.nb[1] == 10 * @sizeOf(f32));\r
-\r
-    try std.testing.expect(c.ggml_n_dims(t2) == 2);\r
-    try std.testing.expect(t2.*.ne[0] == 10);\r
-    try std.testing.expect(t2.*.ne[1] == 20);\r
-    try std.testing.expect(t2.*.nb[1] == 10 * @sizeOf(i16));\r
-    try std.testing.expect(t2.*.nb[2] == 10 * 20 * @sizeOf(i16));\r
-\r
-    try std.testing.expect(c.ggml_n_dims(t3) == 3);\r
-    try std.testing.expect(t3.*.ne[0] == 10);\r
-    try std.testing.expect(t3.*.ne[1] == 20);\r
-    try std.testing.expect(t3.*.ne[2] == 30);\r
-    try std.testing.expect(t3.*.nb[1] == 10 * @sizeOf(i32));\r
-    try std.testing.expect(t3.*.nb[2] == 10 * 20 * @sizeOf(i32));\r
-    try std.testing.expect(t3.*.nb[3] == 10 * 20 * 30 * @sizeOf(i32));\r
-\r
-    c.ggml_print_objects(ctx0);\r
-\r
-    _ = try std.io.getStdIn().reader().readByte();\r
-}\r
+const std = @import("std");
+const c = @cImport({
+    @cInclude("ggml.h");
+});
+
+pub fn main() !void {
+    const params = .{
+        .mem_size = 128 * 1024 * 1024,
+        .mem_buffer = null,
+        .no_alloc = false,
+    };
+
+    const ctx0 = c.ggml_init(params);
+    defer c.ggml_free(ctx0);
+
+    const t1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 10);
+    const t2 = c.ggml_new_tensor_2d(ctx0, c.GGML_TYPE_I16, 10, 20);
+    const t3 = c.ggml_new_tensor_3d(ctx0, c.GGML_TYPE_I32, 10, 20, 30);
+
+    try std.testing.expect(c.ggml_n_dims(t1) == 1);
+    try std.testing.expect(t1.*.ne[0] == 10);
+    try std.testing.expect(t1.*.nb[1] == 10 * @sizeOf(f32));
+
+    try std.testing.expect(c.ggml_n_dims(t2) == 2);
+    try std.testing.expect(t2.*.ne[0] == 10);
+    try std.testing.expect(t2.*.ne[1] == 20);
+    try std.testing.expect(t2.*.nb[1] == 10 * @sizeOf(i16));
+    try std.testing.expect(t2.*.nb[2] == 10 * 20 * @sizeOf(i16));
+
+    try std.testing.expect(c.ggml_n_dims(t3) == 3);
+    try std.testing.expect(t3.*.ne[0] == 10);
+    try std.testing.expect(t3.*.ne[1] == 20);
+    try std.testing.expect(t3.*.ne[2] == 30);
+    try std.testing.expect(t3.*.nb[1] == 10 * @sizeOf(i32));
+    try std.testing.expect(t3.*.nb[2] == 10 * 20 * @sizeOf(i32));
+    try std.testing.expect(t3.*.nb[3] == 10 * 20 * 30 * @sizeOf(i32));
+
+    c.ggml_print_objects(ctx0);
+
+    _ = try std.io.getStdIn().reader().readByte();
+}
index 230aaed856d16b697c1cb40152c7a7620036a43a..7b5a546a8e1900d2c110fa6258072fcf07cb545b 100644 (file)
@@ -1,4 +1,4 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include <stdio.h>
 #include <stdlib.h>
index 507562c4196d377957c392b60b44a5aa69cccf23..815d81438bbfc2b06e9998d8a545628cf129bf48 100644 (file)
-const std = @import("std");\r
-const c = @cImport({\r
-    @cInclude("ggml/ggml.h");\r
-});\r
-\r
-pub fn main() !void {\r
-    const n_threads = 2;\r
-\r
-    const params = .{\r
-        .mem_size = 128 * 1024 * 1024,\r
-        .mem_buffer = null,\r
-        .no_alloc = false,\r
-    };\r
-\r
-    const ctx0 = c.ggml_init(params);\r
-    defer c.ggml_free(ctx0);\r
-\r
-    {\r
-        const x = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);\r
-\r
-        c.ggml_set_param(ctx0, x);\r
-\r
-        const a = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);\r
-        const b = c.ggml_mul(ctx0, x, x);\r
-        const f = c.ggml_mul(ctx0, b, a);\r
-\r
-        // a*x^2\r
-        // 2*a*x\r
-\r
-        c.ggml_print_objects(ctx0);\r
-\r
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);\r
-        c.ggml_build_forward_expand(gf, f);\r
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));\r
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);\r
-\r
-        _ = c.ggml_set_f32(x, 2.0);\r
-        _ = c.ggml_set_f32(a, 3.0);\r
-\r
-        c.ggml_graph_reset(@constCast(gf));\r
-        _ = c.ggml_set_f32(f.*.grad, 1.0);\r
-\r
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);\r
-\r
-        std.debug.print("f     = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)});\r
-        std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)});\r
-\r
-        try std.testing.expect(c.ggml_get_f32_1d(f, 0) == 12.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x.*.grad, 0) == 12.0);\r
-\r
-        _ = c.ggml_set_f32(x, 3.0);\r
-\r
-        c.ggml_graph_reset(@constCast(gf));\r
-        _ = c.ggml_set_f32(f.*.grad, 1.0);\r
-\r
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);\r
-\r
-        std.debug.print("f     = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)});\r
-        std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)});\r
-\r
-        try std.testing.expect(c.ggml_get_f32_1d(f, 0) == 27.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x.*.grad, 0) == 18.0);\r
-\r
-        c.ggml_graph_dump_dot(gf, null, "test1-1-forward.dot");\r
-        c.ggml_graph_dump_dot(gb, gf, "test1-1-backward.dot");\r
-    }\r
-\r
-    /////////////////////////////////////////////////////////////\r
-\r
-    {\r
-        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);\r
-        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);\r
-        const x3 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);\r
-\r
-        _ = c.ggml_set_f32(x1, 3.0);\r
-        _ = c.ggml_set_f32(x2, 1.0);\r
-        _ = c.ggml_set_f32(x3, 0.0);\r
-\r
-        c.ggml_set_param(ctx0, x1);\r
-        c.ggml_set_param(ctx0, x2);\r
-\r
-        const y = c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x1, x2));\r
-\r
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);\r
-        c.ggml_build_forward_expand(gf, y);\r
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));\r
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);\r
-\r
-        c.ggml_graph_reset(@constCast(gf));\r
-        _ = c.ggml_set_f32(y.*.grad, 1.0);\r
-\r
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);\r
-\r
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
-        std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});\r
-        std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)});\r
-\r
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 12.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 7.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0);\r
-\r
-        const g1 = x1.*.grad;\r
-        const g2 = x2.*.grad;\r
-\r
-        const gbb = c.ggml_graph_dup(ctx0, @constCast(gb));\r
-\r
-        c.ggml_build_backward_expand(ctx0, @constCast(gb), @constCast(gbb), true);\r
-\r
-        c.ggml_graph_reset(@constCast(gb));\r
-        _ = c.ggml_set_f32(g1.*.grad, 1.0);\r
-        _ = c.ggml_set_f32(g2.*.grad, 1.0);\r
-\r
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gbb), n_threads);\r
-\r
-        std.debug.print("H * [1, 1] = [ {d:.6} {d:.6} ]\n", .{ c.ggml_get_f32_1d(x1.*.grad, 0), c.ggml_get_f32_1d(x2.*.grad, 0) });\r
-\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 3.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 1.0);\r
-\r
-        c.ggml_graph_dump_dot(gf, null, "test1-2-forward.dot");\r
-        c.ggml_graph_dump_dot(gb, gf, "test1-2-backward.dot");\r
-    }\r
-\r
-    ///////////////////////////////////////////////////////////////\r
-\r
-    {\r
-        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);\r
-        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);\r
-\r
-        c.ggml_set_param(ctx0, x1);\r
-        c.ggml_set_param(ctx0, x2);\r
-\r
-        const y = c.ggml_mul(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x1, x2)), x1);\r
-\r
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);\r
-        c.ggml_build_forward_expand(gf, y);\r
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));\r
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);\r
-\r
-        _ = c.ggml_set_f32(x1, 3.0);\r
-        _ = c.ggml_set_f32(x2, 4.0);\r
-\r
-        c.ggml_graph_reset(@constCast(gf));\r
-        _ = c.ggml_set_f32(y.*.grad, 1.0);\r
-\r
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);\r
-\r
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
-        std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});\r
-        std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)});\r
-\r
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 63.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 51.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 9.0);\r
-\r
-        c.ggml_graph_dump_dot(gf, null, "test1-3-forward.dot");\r
-        c.ggml_graph_dump_dot(gb, gf, "test1-3-backward.dot");\r
-    }\r
-\r
-    ///////////////////////////////////////////////////////////////\r
-\r
-    {\r
-        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);\r
-        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);\r
-        const x3 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);\r
-\r
-        c.ggml_set_param(ctx0, x1);\r
-        c.ggml_set_param(ctx0, x2);\r
-        c.ggml_set_param(ctx0, x3);\r
-\r
-        const y = c.ggml_mul(ctx0, c.ggml_mul(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x2, x2)), x3);\r
-\r
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);\r
-        c.ggml_build_forward_expand(gf, y);\r
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));\r
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);\r
-\r
-        _ = c.ggml_set_f32(x1, 1.0);\r
-        _ = c.ggml_set_f32(x2, 2.0);\r
-        _ = c.ggml_set_f32(x3, 3.0);\r
-\r
-        c.ggml_graph_reset(@constCast(gf));\r
-        _ = c.ggml_set_f32(y.*.grad, 1.0);\r
-\r
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);\r
-\r
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
-        std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});\r
-        std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)});\r
-        std.debug.print("df/dx3 = {d:.6}\n", .{c.ggml_get_f32_1d(x3.*.grad, 0)});\r
-\r
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 12.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 24.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 12.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x3.*.grad, 0) == 4.0);\r
-\r
-        const g1 = x1.*.grad;\r
-        const g2 = x2.*.grad;\r
-        const g3 = x3.*.grad;\r
-\r
-        const gbb = c.ggml_graph_dup(ctx0, @constCast(gb));\r
-\r
-        c.ggml_build_backward_expand(ctx0, @constCast(gb), @constCast(gbb), true);\r
-\r
-        c.ggml_graph_reset(@constCast(gb));\r
-        _ = c.ggml_set_f32(g1.*.grad, 1.0);\r
-        _ = c.ggml_set_f32(g2.*.grad, 1.0);\r
-        _ = c.ggml_set_f32(g3.*.grad, 1.0);\r
-\r
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gbb), n_threads);\r
-\r
-        std.debug.print("H * [1, 1, 1] = [ {d:.6} {d:.6} {d:.6}]\n", .{\r
-            c.ggml_get_f32_1d(x1.*.grad, 0),\r
-            c.ggml_get_f32_1d(x2.*.grad, 0),\r
-            c.ggml_get_f32_1d(x3.*.grad, 0),\r
-        });\r
-\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 56.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 34.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x3.*.grad, 0) == 12.0);\r
-\r
-        c.ggml_graph_dump_dot(gf, null, "test1-4-forward.dot");\r
-        c.ggml_graph_dump_dot(gb, gf, "test1-4-backward.dot");\r
-    }\r
-\r
-    ///////////////////////////////////////////////////////////////\r
-\r
-    {\r
-        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);\r
-        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);\r
-\r
-        c.ggml_set_param(ctx0, x1);\r
-        c.ggml_set_param(ctx0, x2);\r
-\r
-        const y = c.ggml_sum(ctx0, c.ggml_mul(ctx0, x1, x2));\r
-\r
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);\r
-        c.ggml_build_forward_expand(gf, y);\r
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));\r
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);\r
-\r
-        _ = c.ggml_set_f32(x1, 3.0);\r
-        _ = c.ggml_set_f32(x2, 5.0);\r
-\r
-        c.ggml_graph_reset(@constCast(gf));\r
-        _ = c.ggml_set_f32(y.*.grad, 1.0);\r
-\r
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);\r
-\r
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
-        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{\r
-            c.ggml_get_f32_1d(x1.*.grad, 0),\r
-            c.ggml_get_f32_1d(x1.*.grad, 1),\r
-            c.ggml_get_f32_1d(x1.*.grad, 2),\r
-        });\r
-        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{\r
-            c.ggml_get_f32_1d(x2.*.grad, 0),\r
-            c.ggml_get_f32_1d(x2.*.grad, 1),\r
-            c.ggml_get_f32_1d(x2.*.grad, 2),\r
-        });\r
-\r
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 45.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 5.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 5.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 5.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0);\r
-\r
-        c.ggml_graph_dump_dot(gf, null, "test1-5-forward.dot");\r
-        c.ggml_graph_dump_dot(gb, gf, "test1-5-backward.dot");\r
-    }\r
-\r
-    ///////////////////////////////////////////////////////////////\r
-\r
-    {\r
-        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);\r
-        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);\r
-\r
-        c.ggml_set_param(ctx0, x1);\r
-        c.ggml_set_param(ctx0, x2);\r
-\r
-        const y =\r
-            c.ggml_sum(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x2), c.ggml_mul(ctx0, c.ggml_repeat(ctx0, c.ggml_new_f32(ctx0, -2.0), x1), c.ggml_mul(ctx0, x1, x1))));\r
-\r
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);\r
-        c.ggml_build_forward_expand(gf, y);\r
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));\r
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);\r
-\r
-        _ = c.ggml_set_f32(x1, 3.0);\r
-        _ = c.ggml_set_f32(x2, 5.0);\r
-\r
-        c.ggml_graph_reset(@constCast(gf));\r
-        _ = c.ggml_set_f32(y.*.grad, 1.0);\r
-\r
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);\r
-\r
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
-        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{\r
-            c.ggml_get_f32_1d(x1.*.grad, 0),\r
-            c.ggml_get_f32_1d(x1.*.grad, 1),\r
-            c.ggml_get_f32_1d(x1.*.grad, 2),\r
-        });\r
-        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{\r
-            c.ggml_get_f32_1d(x2.*.grad, 0),\r
-            c.ggml_get_f32_1d(x2.*.grad, 1),\r
-            c.ggml_get_f32_1d(x2.*.grad, 2),\r
-        });\r
-\r
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == -9.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == -7.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == -7.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == -7.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0);\r
-\r
-        c.ggml_graph_dump_dot(gf, null, "test1-6-forward.dot");\r
-        c.ggml_graph_dump_dot(gb, gf, "test1-6-backward.dot");\r
-    }\r
-\r
-    ///////////////////////////////////////////////////////////////\r
-\r
-    {\r
-        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);\r
-        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);\r
-\r
-        c.ggml_set_param(ctx0, x1);\r
-        c.ggml_set_param(ctx0, x2);\r
-\r
-        const y =\r
-            c.ggml_sum(ctx0, c.ggml_sub(ctx0, c.ggml_mul(ctx0, x1, x2), c.ggml_mul(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_repeat(ctx0, c.ggml_new_f32(ctx0, -2.0), x1))));\r
-\r
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);\r
-        c.ggml_build_forward_expand(gf, y);\r
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));\r
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);\r
-\r
-        _ = c.ggml_set_f32(x1, 3.0);\r
-        _ = c.ggml_set_f32(x2, 5.0);\r
-\r
-        c.ggml_graph_reset(@constCast(gf));\r
-        _ = c.ggml_set_f32(y.*.grad, 1.0);\r
-\r
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);\r
-\r
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
-        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{\r
-            c.ggml_get_f32_1d(x1.*.grad, 0),\r
-            c.ggml_get_f32_1d(x1.*.grad, 1),\r
-            c.ggml_get_f32_1d(x1.*.grad, 2),\r
-        });\r
-        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{\r
-            c.ggml_get_f32_1d(x2.*.grad, 0),\r
-            c.ggml_get_f32_1d(x2.*.grad, 1),\r
-            c.ggml_get_f32_1d(x2.*.grad, 2),\r
-        });\r
-\r
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 99.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 17.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 17.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 17.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0);\r
-\r
-        c.ggml_graph_dump_dot(gf, null, "test1-7-forward.dot");\r
-        c.ggml_graph_dump_dot(gb, gf, "test1-7-backward.dot");\r
-    }\r
-\r
-    ///////////////////////////////////////////////////////////////\r
-\r
-    {\r
-        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);\r
-        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);\r
-\r
-        c.ggml_set_param(ctx0, x1);\r
-        c.ggml_set_param(ctx0, x2);\r
-\r
-        const y =\r
-            c.ggml_abs(ctx0, c.ggml_sub(ctx0, x1, x2));\r
-\r
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);\r
-        c.ggml_build_forward_expand(gf, y);\r
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));\r
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);\r
-\r
-        _ = c.ggml_set_f32(x1, 3.0);\r
-        _ = c.ggml_set_f32(x2, 5.0);\r
-\r
-        c.ggml_graph_reset(@constCast(gf));\r
-        _ = c.ggml_set_f32(y.*.grad, 1.0);\r
-\r
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);\r
-\r
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
-        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{\r
-            c.ggml_get_f32_1d(x1.*.grad, 0),\r
-            c.ggml_get_f32_1d(x1.*.grad, 1),\r
-            c.ggml_get_f32_1d(x1.*.grad, 2),\r
-        });\r
-        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{\r
-            c.ggml_get_f32_1d(x2.*.grad, 0),\r
-            c.ggml_get_f32_1d(x2.*.grad, 1),\r
-            c.ggml_get_f32_1d(x2.*.grad, 2),\r
-        });\r
-\r
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 2.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == -1.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == -1.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == -1.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 1.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 1.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 1.0);\r
-\r
-        _ = c.ggml_set_f32(x1, 7.0);\r
-        _ = c.ggml_set_f32(x2, 5.0);\r
-\r
-        c.ggml_graph_reset(@constCast(gf));\r
-        _ = c.ggml_set_f32(y.*.grad, 1.0);\r
-\r
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);\r
-\r
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
-        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{\r
-            c.ggml_get_f32_1d(x1.*.grad, 0),\r
-            c.ggml_get_f32_1d(x1.*.grad, 1),\r
-            c.ggml_get_f32_1d(x1.*.grad, 2),\r
-        });\r
-        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{\r
-            c.ggml_get_f32_1d(x2.*.grad, 0),\r
-            c.ggml_get_f32_1d(x2.*.grad, 1),\r
-            c.ggml_get_f32_1d(x2.*.grad, 2),\r
-        });\r
-\r
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 2.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 1.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 1.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 1.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == -1.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == -1.0);\r
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == -1.0);\r
-\r
-        c.ggml_graph_dump_dot(gf, null, "test1-8-forward.dot");\r
-        c.ggml_graph_dump_dot(gb, gf, "test1-8-backward.dot");\r
-    }\r
-\r
-    _ = try std.io.getStdIn().reader().readByte();\r
-}\r
+const std = @import("std");
+const c = @cImport({
+    @cInclude("ggml.h");
+});
+
+pub fn main() !void {
+    const n_threads = 2;
+
+    const params = .{
+        .mem_size = 128 * 1024 * 1024,
+        .mem_buffer = null,
+        .no_alloc = false,
+    };
+
+    const ctx0 = c.ggml_init(params);
+    defer c.ggml_free(ctx0);
+
+    {
+        const x = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
+
+        c.ggml_set_param(ctx0, x);
+
+        const a = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
+        const b = c.ggml_mul(ctx0, x, x);
+        const f = c.ggml_mul(ctx0, b, a);
+
+        // a*x^2
+        // 2*a*x
+
+        c.ggml_print_objects(ctx0);
+
+        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
+        c.ggml_build_forward_expand(gf, f);
+        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
+        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
+
+        _ = c.ggml_set_f32(x, 2.0);
+        _ = c.ggml_set_f32(a, 3.0);
+
+        c.ggml_graph_reset(@constCast(gf));
+        _ = c.ggml_set_f32(f.*.grad, 1.0);
+
+        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
+
+        std.debug.print("f     = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)});
+        std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)});
+
+        try std.testing.expect(c.ggml_get_f32_1d(f, 0) == 12.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x.*.grad, 0) == 12.0);
+
+        _ = c.ggml_set_f32(x, 3.0);
+
+        c.ggml_graph_reset(@constCast(gf));
+        _ = c.ggml_set_f32(f.*.grad, 1.0);
+
+        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
+
+        std.debug.print("f     = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)});
+        std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)});
+
+        try std.testing.expect(c.ggml_get_f32_1d(f, 0) == 27.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x.*.grad, 0) == 18.0);
+
+        c.ggml_graph_dump_dot(gf, null, "test1-1-forward.dot");
+        c.ggml_graph_dump_dot(gb, gf, "test1-1-backward.dot");
+    }
+
+    /////////////////////////////////////////////////////////////
+
+    {
+        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
+        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
+        const x3 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
+
+        _ = c.ggml_set_f32(x1, 3.0);
+        _ = c.ggml_set_f32(x2, 1.0);
+        _ = c.ggml_set_f32(x3, 0.0);
+
+        c.ggml_set_param(ctx0, x1);
+        c.ggml_set_param(ctx0, x2);
+
+        const y = c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x1, x2));
+
+        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
+        c.ggml_build_forward_expand(gf, y);
+        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
+        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
+
+        c.ggml_graph_reset(@constCast(gf));
+        _ = c.ggml_set_f32(y.*.grad, 1.0);
+
+        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
+
+        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
+        std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});
+        std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)});
+
+        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 12.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 7.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0);
+
+        const g1 = x1.*.grad;
+        const g2 = x2.*.grad;
+
+        const gbb = c.ggml_graph_dup(ctx0, @constCast(gb));
+
+        c.ggml_build_backward_expand(ctx0, @constCast(gb), @constCast(gbb), true);
+
+        c.ggml_graph_reset(@constCast(gb));
+        _ = c.ggml_set_f32(g1.*.grad, 1.0);
+        _ = c.ggml_set_f32(g2.*.grad, 1.0);
+
+        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gbb), n_threads);
+
+        std.debug.print("H * [1, 1] = [ {d:.6} {d:.6} ]\n", .{ c.ggml_get_f32_1d(x1.*.grad, 0), c.ggml_get_f32_1d(x2.*.grad, 0) });
+
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 3.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 1.0);
+
+        c.ggml_graph_dump_dot(gf, null, "test1-2-forward.dot");
+        c.ggml_graph_dump_dot(gb, gf, "test1-2-backward.dot");
+    }
+
+    ///////////////////////////////////////////////////////////////
+
+    {
+        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
+        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
+
+        c.ggml_set_param(ctx0, x1);
+        c.ggml_set_param(ctx0, x2);
+
+        const y = c.ggml_mul(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x1, x2)), x1);
+
+        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
+        c.ggml_build_forward_expand(gf, y);
+        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
+        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
+
+        _ = c.ggml_set_f32(x1, 3.0);
+        _ = c.ggml_set_f32(x2, 4.0);
+
+        c.ggml_graph_reset(@constCast(gf));
+        _ = c.ggml_set_f32(y.*.grad, 1.0);
+
+        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
+
+        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
+        std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});
+        std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)});
+
+        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 63.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 51.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 9.0);
+
+        c.ggml_graph_dump_dot(gf, null, "test1-3-forward.dot");
+        c.ggml_graph_dump_dot(gb, gf, "test1-3-backward.dot");
+    }
+
+    ///////////////////////////////////////////////////////////////
+
+    {
+        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
+        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
+        const x3 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
+
+        c.ggml_set_param(ctx0, x1);
+        c.ggml_set_param(ctx0, x2);
+        c.ggml_set_param(ctx0, x3);
+
+        const y = c.ggml_mul(ctx0, c.ggml_mul(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x2, x2)), x3);
+
+        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
+        c.ggml_build_forward_expand(gf, y);
+        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
+        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
+
+        _ = c.ggml_set_f32(x1, 1.0);
+        _ = c.ggml_set_f32(x2, 2.0);
+        _ = c.ggml_set_f32(x3, 3.0);
+
+        c.ggml_graph_reset(@constCast(gf));
+        _ = c.ggml_set_f32(y.*.grad, 1.0);
+
+        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
+
+        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
+        std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});
+        std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)});
+        std.debug.print("df/dx3 = {d:.6}\n", .{c.ggml_get_f32_1d(x3.*.grad, 0)});
+
+        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 12.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 24.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 12.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x3.*.grad, 0) == 4.0);
+
+        const g1 = x1.*.grad;
+        const g2 = x2.*.grad;
+        const g3 = x3.*.grad;
+
+        const gbb = c.ggml_graph_dup(ctx0, @constCast(gb));
+
+        c.ggml_build_backward_expand(ctx0, @constCast(gb), @constCast(gbb), true);
+
+        c.ggml_graph_reset(@constCast(gb));
+        _ = c.ggml_set_f32(g1.*.grad, 1.0);
+        _ = c.ggml_set_f32(g2.*.grad, 1.0);
+        _ = c.ggml_set_f32(g3.*.grad, 1.0);
+
+        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gbb), n_threads);
+
+        std.debug.print("H * [1, 1, 1] = [ {d:.6} {d:.6} {d:.6}]\n", .{
+            c.ggml_get_f32_1d(x1.*.grad, 0),
+            c.ggml_get_f32_1d(x2.*.grad, 0),
+            c.ggml_get_f32_1d(x3.*.grad, 0),
+        });
+
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 56.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 34.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x3.*.grad, 0) == 12.0);
+
+        c.ggml_graph_dump_dot(gf, null, "test1-4-forward.dot");
+        c.ggml_graph_dump_dot(gb, gf, "test1-4-backward.dot");
+    }
+
+    ///////////////////////////////////////////////////////////////
+
+    {
+        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
+        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
+
+        c.ggml_set_param(ctx0, x1);
+        c.ggml_set_param(ctx0, x2);
+
+        const y = c.ggml_sum(ctx0, c.ggml_mul(ctx0, x1, x2));
+
+        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
+        c.ggml_build_forward_expand(gf, y);
+        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
+        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
+
+        _ = c.ggml_set_f32(x1, 3.0);
+        _ = c.ggml_set_f32(x2, 5.0);
+
+        c.ggml_graph_reset(@constCast(gf));
+        _ = c.ggml_set_f32(y.*.grad, 1.0);
+
+        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
+
+        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
+        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{
+            c.ggml_get_f32_1d(x1.*.grad, 0),
+            c.ggml_get_f32_1d(x1.*.grad, 1),
+            c.ggml_get_f32_1d(x1.*.grad, 2),
+        });
+        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{
+            c.ggml_get_f32_1d(x2.*.grad, 0),
+            c.ggml_get_f32_1d(x2.*.grad, 1),
+            c.ggml_get_f32_1d(x2.*.grad, 2),
+        });
+
+        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 45.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 5.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 5.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 5.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0);
+
+        c.ggml_graph_dump_dot(gf, null, "test1-5-forward.dot");
+        c.ggml_graph_dump_dot(gb, gf, "test1-5-backward.dot");
+    }
+
+    ///////////////////////////////////////////////////////////////
+
+    {
+        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
+        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
+
+        c.ggml_set_param(ctx0, x1);
+        c.ggml_set_param(ctx0, x2);
+
+        const y =
+            c.ggml_sum(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x2), c.ggml_mul(ctx0, c.ggml_repeat(ctx0, c.ggml_new_f32(ctx0, -2.0), x1), c.ggml_mul(ctx0, x1, x1))));
+
+        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
+        c.ggml_build_forward_expand(gf, y);
+        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
+        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
+
+        _ = c.ggml_set_f32(x1, 3.0);
+        _ = c.ggml_set_f32(x2, 5.0);
+
+        c.ggml_graph_reset(@constCast(gf));
+        _ = c.ggml_set_f32(y.*.grad, 1.0);
+
+        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
+
+        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
+        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{
+            c.ggml_get_f32_1d(x1.*.grad, 0),
+            c.ggml_get_f32_1d(x1.*.grad, 1),
+            c.ggml_get_f32_1d(x1.*.grad, 2),
+        });
+        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{
+            c.ggml_get_f32_1d(x2.*.grad, 0),
+            c.ggml_get_f32_1d(x2.*.grad, 1),
+            c.ggml_get_f32_1d(x2.*.grad, 2),
+        });
+
+        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == -9.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == -7.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == -7.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == -7.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0);
+
+        c.ggml_graph_dump_dot(gf, null, "test1-6-forward.dot");
+        c.ggml_graph_dump_dot(gb, gf, "test1-6-backward.dot");
+    }
+
+    ///////////////////////////////////////////////////////////////
+
+    {
+        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
+        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
+
+        c.ggml_set_param(ctx0, x1);
+        c.ggml_set_param(ctx0, x2);
+
+        const y =
+            c.ggml_sum(ctx0, c.ggml_sub(ctx0, c.ggml_mul(ctx0, x1, x2), c.ggml_mul(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_repeat(ctx0, c.ggml_new_f32(ctx0, -2.0), x1))));
+
+        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
+        c.ggml_build_forward_expand(gf, y);
+        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
+        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
+
+        _ = c.ggml_set_f32(x1, 3.0);
+        _ = c.ggml_set_f32(x2, 5.0);
+
+        c.ggml_graph_reset(@constCast(gf));
+        _ = c.ggml_set_f32(y.*.grad, 1.0);
+
+        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
+
+        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
+        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{
+            c.ggml_get_f32_1d(x1.*.grad, 0),
+            c.ggml_get_f32_1d(x1.*.grad, 1),
+            c.ggml_get_f32_1d(x1.*.grad, 2),
+        });
+        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{
+            c.ggml_get_f32_1d(x2.*.grad, 0),
+            c.ggml_get_f32_1d(x2.*.grad, 1),
+            c.ggml_get_f32_1d(x2.*.grad, 2),
+        });
+
+        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 99.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 17.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 17.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 17.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0);
+
+        c.ggml_graph_dump_dot(gf, null, "test1-7-forward.dot");
+        c.ggml_graph_dump_dot(gb, gf, "test1-7-backward.dot");
+    }
+
+    ///////////////////////////////////////////////////////////////
+
+    {
+        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
+        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
+
+        c.ggml_set_param(ctx0, x1);
+        c.ggml_set_param(ctx0, x2);
+
+        const y =
+            c.ggml_abs(ctx0, c.ggml_sub(ctx0, x1, x2));
+
+        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
+        c.ggml_build_forward_expand(gf, y);
+        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
+        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
+
+        _ = c.ggml_set_f32(x1, 3.0);
+        _ = c.ggml_set_f32(x2, 5.0);
+
+        c.ggml_graph_reset(@constCast(gf));
+        _ = c.ggml_set_f32(y.*.grad, 1.0);
+
+        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
+
+        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
+        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{
+            c.ggml_get_f32_1d(x1.*.grad, 0),
+            c.ggml_get_f32_1d(x1.*.grad, 1),
+            c.ggml_get_f32_1d(x1.*.grad, 2),
+        });
+        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{
+            c.ggml_get_f32_1d(x2.*.grad, 0),
+            c.ggml_get_f32_1d(x2.*.grad, 1),
+            c.ggml_get_f32_1d(x2.*.grad, 2),
+        });
+
+        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 2.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == -1.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == -1.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == -1.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 1.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 1.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 1.0);
+
+        _ = c.ggml_set_f32(x1, 7.0);
+        _ = c.ggml_set_f32(x2, 5.0);
+
+        c.ggml_graph_reset(@constCast(gf));
+        _ = c.ggml_set_f32(y.*.grad, 1.0);
+
+        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
+
+        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
+        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{
+            c.ggml_get_f32_1d(x1.*.grad, 0),
+            c.ggml_get_f32_1d(x1.*.grad, 1),
+            c.ggml_get_f32_1d(x1.*.grad, 2),
+        });
+        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{
+            c.ggml_get_f32_1d(x2.*.grad, 0),
+            c.ggml_get_f32_1d(x2.*.grad, 1),
+            c.ggml_get_f32_1d(x2.*.grad, 2),
+        });
+
+        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 2.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 1.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 1.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 1.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == -1.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == -1.0);
+        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == -1.0);
+
+        c.ggml_graph_dump_dot(gf, null, "test1-8-forward.dot");
+        c.ggml_graph_dump_dot(gb, gf, "test1-8-backward.dot");
+    }
+
+    _ = try std.io.getStdIn().reader().readByte();
+}
index fb29a9fa846c7e927b153a5febbb3cfe2db76fcb..0b3adb8f35f2b10f8065b599b1433b890aa528a5 100644 (file)
@@ -1,5 +1,5 @@
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include <math.h>
 #include <stdio.h>
index 7c68d6d1591801577dcc4eb5cd7851aca073b50d..783eba6e636644a6479d2077887c6845a3cce60d 100644 (file)
-const std = @import("std");\r
-const Thread = std.Thread;\r
-const c = @cImport({\r
-    @cInclude("ggml/ggml.h");\r
-});\r
-\r
-fn is_close(a: f32, b: f32, epsilon: f32) bool {\r
-    return @abs(a - b) < epsilon;\r
-}\r
-\r
-pub fn main() !void {\r
-    const params = .{\r
-        .mem_size = 128 * 1024 * 1024,\r
-        .mem_buffer = null,\r
-        .no_alloc = false,\r
-    };\r
-\r
-    var opt_params = c.ggml_opt_default_params(c.GGML_OPT_TYPE_LBFGS);\r
-\r
-    const nthreads = try Thread.getCpuCount();\r
-    opt_params.n_threads = @intCast(nthreads);\r
-    std.debug.print("test2: n_threads:{}\n", .{opt_params.n_threads});\r
-\r
-    const xi = [_]f32{ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 };\r
-    const yi = [_]f32{ 15.0, 25.0, 35.0, 45.0, 55.0, 65.0, 75.0, 85.0, 95.0, 105.0 };\r
-\r
-    const n = xi.len;\r
-\r
-    const ctx0 = c.ggml_init(params);\r
-    defer c.ggml_free(ctx0);\r
-\r
-    const x = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, n);\r
-    const y = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, n);\r
-\r
-    for (0..n) |i| {\r
-        const x_data_pointer: [*]f32 = @ptrCast(@alignCast(x.*.data));\r
-        x_data_pointer[i] = xi[i];\r
-        const y_data_pointer: [*]f32 = @ptrCast(@alignCast(y.*.data));\r
-        y_data_pointer[i] = yi[i];\r
-    }\r
-\r
-    {\r
-        const t0 = c.ggml_new_f32(ctx0, 0.0);\r
-        const t1 = c.ggml_new_f32(ctx0, 0.0);\r
-\r
-        // initialize auto-diff parameters:\r
-        _ = c.ggml_set_param(ctx0, t0);\r
-        _ = c.ggml_set_param(ctx0, t1);\r
-\r
-        // f = sum_i[(t0 + t1*x_i - y_i)^2]/(2n)\r
-        const f =\r
-            c.ggml_div(ctx0, c.ggml_sum(ctx0, c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x, c.ggml_repeat(ctx0, t1, x)), c.ggml_repeat(ctx0, t0, x)), y))), c.ggml_new_f32(ctx0, @as(f32, 2.0) * n));\r
-\r
-        const res = c.ggml_opt(null, opt_params, f);\r
-\r
-        std.debug.print("t0 = {d:.6}\n", .{c.ggml_get_f32_1d(t0, 0)});\r
-        std.debug.print("t1 = {d:.6}\n", .{c.ggml_get_f32_1d(t1, 0)});\r
-\r
-        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);\r
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 5.0, 1e-3));\r
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 10.0, 1e-3));\r
-    }\r
-\r
-    {\r
-        const t0 = c.ggml_new_f32(ctx0, -1.0);\r
-        const t1 = c.ggml_new_f32(ctx0, 9.0);\r
-\r
-        _ = c.ggml_set_param(ctx0, t0);\r
-        _ = c.ggml_set_param(ctx0, t1);\r
-\r
-        // f = 0.5*sum_i[abs(t0 + t1*x_i - y_i)]/n\r
-        const f =\r
-            c.ggml_mul(ctx0, c.ggml_new_f32(ctx0, @as(f32, 1.0) / (2 * n)), c.ggml_sum(ctx0, c.ggml_abs(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x, c.ggml_repeat(ctx0, t1, x)), c.ggml_repeat(ctx0, t0, x)), y))));\r
-\r
-        const res = c.ggml_opt(null, opt_params, f);\r
-\r
-        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);\r
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 5.0, 1e-2));\r
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 10.0, 1e-2));\r
-    }\r
-\r
-    {\r
-        const t0 = c.ggml_new_f32(ctx0, 5.0);\r
-        const t1 = c.ggml_new_f32(ctx0, -4.0);\r
-\r
-        _ = c.ggml_set_param(ctx0, t0);\r
-        _ = c.ggml_set_param(ctx0, t1);\r
-\r
-        // f = t0^2 + t1^2\r
-        const f =\r
-            c.ggml_add(ctx0, c.ggml_sqr(ctx0, t0), c.ggml_sqr(ctx0, t1));\r
-\r
-        const res = c.ggml_opt(null, opt_params, f);\r
-\r
-        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);\r
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(f, 0), 0.0, 1e-3));\r
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 0.0, 1e-3));\r
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 0.0, 1e-3));\r
-    }\r
-\r
-    /////////////////////////////////////////\r
-\r
-    {\r
-        const t0 = c.ggml_new_f32(ctx0, -7.0);\r
-        const t1 = c.ggml_new_f32(ctx0, 8.0);\r
-\r
-        _ = c.ggml_set_param(ctx0, t0);\r
-        _ = c.ggml_set_param(ctx0, t1);\r
-\r
-        // f = (t0 + 2*t1 - 7)^2 + (2*t0 + t1 - 5)^2\r
-        const f =\r
-            c.ggml_add(ctx0, c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, t0, c.ggml_mul(ctx0, t1, c.ggml_new_f32(ctx0, 2.0))), c.ggml_new_f32(ctx0, 7.0))), c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, t0, c.ggml_new_f32(ctx0, 2.0)), t1), c.ggml_new_f32(ctx0, 5.0))));\r
-\r
-        const res = c.ggml_opt(null, opt_params, f);\r
-\r
-        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);\r
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(f, 0), 0.0, 1e-3));\r
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 1.0, 1e-3));\r
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 3.0, 1e-3));\r
-    }\r
-\r
-    _ = try std.io.getStdIn().reader().readByte();\r
-}\r
+const std = @import("std");
+const Thread = std.Thread;
+const c = @cImport({
+    @cInclude("ggml.h");
+});
+
+fn is_close(a: f32, b: f32, epsilon: f32) bool {
+    return @abs(a - b) < epsilon;
+}
+
+pub fn main() !void {
+    const params = .{
+        .mem_size = 128 * 1024 * 1024,
+        .mem_buffer = null,
+        .no_alloc = false,
+    };
+
+    var opt_params = c.ggml_opt_default_params(c.GGML_OPT_TYPE_LBFGS);
+
+    const nthreads = try Thread.getCpuCount();
+    opt_params.n_threads = @intCast(nthreads);
+    std.debug.print("test2: n_threads:{}\n", .{opt_params.n_threads});
+
+    const xi = [_]f32{ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 };
+    const yi = [_]f32{ 15.0, 25.0, 35.0, 45.0, 55.0, 65.0, 75.0, 85.0, 95.0, 105.0 };
+
+    const n = xi.len;
+
+    const ctx0 = c.ggml_init(params);
+    defer c.ggml_free(ctx0);
+
+    const x = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, n);
+    const y = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, n);
+
+    for (0..n) |i| {
+        const x_data_pointer: [*]f32 = @ptrCast(@alignCast(x.*.data));
+        x_data_pointer[i] = xi[i];
+        const y_data_pointer: [*]f32 = @ptrCast(@alignCast(y.*.data));
+        y_data_pointer[i] = yi[i];
+    }
+
+    {
+        const t0 = c.ggml_new_f32(ctx0, 0.0);
+        const t1 = c.ggml_new_f32(ctx0, 0.0);
+
+        // initialize auto-diff parameters:
+        _ = c.ggml_set_param(ctx0, t0);
+        _ = c.ggml_set_param(ctx0, t1);
+
+        // f = sum_i[(t0 + t1*x_i - y_i)^2]/(2n)
+        const f =
+            c.ggml_div(ctx0, c.ggml_sum(ctx0, c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x, c.ggml_repeat(ctx0, t1, x)), c.ggml_repeat(ctx0, t0, x)), y))), c.ggml_new_f32(ctx0, @as(f32, 2.0) * n));
+
+        const res = c.ggml_opt(null, opt_params, f);
+
+        std.debug.print("t0 = {d:.6}\n", .{c.ggml_get_f32_1d(t0, 0)});
+        std.debug.print("t1 = {d:.6}\n", .{c.ggml_get_f32_1d(t1, 0)});
+
+        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);
+        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 5.0, 1e-3));
+        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 10.0, 1e-3));
+    }
+
+    {
+        const t0 = c.ggml_new_f32(ctx0, -1.0);
+        const t1 = c.ggml_new_f32(ctx0, 9.0);
+
+        _ = c.ggml_set_param(ctx0, t0);
+        _ = c.ggml_set_param(ctx0, t1);
+
+        // f = 0.5*sum_i[abs(t0 + t1*x_i - y_i)]/n
+        const f =
+            c.ggml_mul(ctx0, c.ggml_new_f32(ctx0, @as(f32, 1.0) / (2 * n)), c.ggml_sum(ctx0, c.ggml_abs(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x, c.ggml_repeat(ctx0, t1, x)), c.ggml_repeat(ctx0, t0, x)), y))));
+
+        const res = c.ggml_opt(null, opt_params, f);
+
+        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);
+        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 5.0, 1e-2));
+        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 10.0, 1e-2));
+    }
+
+    {
+        const t0 = c.ggml_new_f32(ctx0, 5.0);
+        const t1 = c.ggml_new_f32(ctx0, -4.0);
+
+        _ = c.ggml_set_param(ctx0, t0);
+        _ = c.ggml_set_param(ctx0, t1);
+
+        // f = t0^2 + t1^2
+        const f =
+            c.ggml_add(ctx0, c.ggml_sqr(ctx0, t0), c.ggml_sqr(ctx0, t1));
+
+        const res = c.ggml_opt(null, opt_params, f);
+
+        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);
+        try std.testing.expect(is_close(c.ggml_get_f32_1d(f, 0), 0.0, 1e-3));
+        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 0.0, 1e-3));
+        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 0.0, 1e-3));
+    }
+
+    /////////////////////////////////////////
+
+    {
+        const t0 = c.ggml_new_f32(ctx0, -7.0);
+        const t1 = c.ggml_new_f32(ctx0, 8.0);
+
+        _ = c.ggml_set_param(ctx0, t0);
+        _ = c.ggml_set_param(ctx0, t1);
+
+        // f = (t0 + 2*t1 - 7)^2 + (2*t0 + t1 - 5)^2
+        const f =
+            c.ggml_add(ctx0, c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, t0, c.ggml_mul(ctx0, t1, c.ggml_new_f32(ctx0, 2.0))), c.ggml_new_f32(ctx0, 7.0))), c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, t0, c.ggml_new_f32(ctx0, 2.0)), t1), c.ggml_new_f32(ctx0, 5.0))));
+
+        const res = c.ggml_opt(null, opt_params, f);
+
+        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);
+        try std.testing.expect(is_close(c.ggml_get_f32_1d(f, 0), 0.0, 1e-3));
+        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 1.0, 1e-3));
+        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 3.0, 1e-3));
+    }
+
+    _ = try std.io.getStdIn().reader().readByte();
+}
index 12aac708416f6c47bff6371d2732140088dc4ee3..d1e9fcc619b6cde6d7c0572a48a8a2d497c48bf3 100644 (file)
@@ -1,4 +1,4 @@
-#include "ggml/ggml.h"
+#include "ggml.h"
 
 #include <math.h>
 #include <stdio.h>
index fe87df80c78a59259c20e4d2ae4d90a586550712..ecaf1b014a04d2034dee3d1b062b31f40ca68a48 100644 (file)
@@ -1,87 +1,87 @@
-const std = @import("std");\r
-const Thread = std.Thread;\r
-const c = @cImport({\r
-    @cInclude("stdlib.h");\r
-    @cInclude("ggml/ggml.h");\r
-});\r
-\r
-fn is_close(a: f32, b: f32, epsilon: f32) bool {\r
-    return @abs(a - b) < epsilon;\r
-}\r
-\r
-pub fn main() !void {\r
-    const params = .{\r
-        .mem_size = 128 * 1024 * 1024,\r
-        .mem_buffer = null,\r
-        .no_alloc = false,\r
-    };\r
-\r
-    var opt_params = c.ggml_opt_default_params(c.GGML_OPT_TYPE_LBFGS);\r
-\r
-    const nthreads = try Thread.getCpuCount();\r
-    opt_params.n_threads = @intCast(nthreads);\r
-\r
-    const NP = 1 << 12;\r
-    const NF = 1 << 8;\r
-\r
-    const ctx0 = c.ggml_init(params);\r
-    defer c.ggml_free(ctx0);\r
-\r
-    const F = c.ggml_new_tensor_2d(ctx0, c.GGML_TYPE_F32, NF, NP);\r
-    const l = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, NP);\r
-\r
-    // regularization weight\r
-    const lambda = c.ggml_new_f32(ctx0, 1e-5);\r
-\r
-    c.srand(0);\r
-\r
-    const l_data_pointer: [*]f32 = @ptrCast(@alignCast(l.*.data));\r
-    const f_data_pointer: [*]f32 = @ptrCast(@alignCast(F.*.data));\r
-    for (0..NP) |j| {\r
-        const ll = if (j < NP / 2) @as(f32, 1.0) else @as(f32, -1.0);\r
-        l_data_pointer[j] = ll;\r
-\r
-        for (0..NF) |i| {\r
-            const c_rand: f32 = @floatFromInt(c.rand());\r
-            f_data_pointer[j * NF + i] =\r
-                ((if (ll > 0 and i < NF / 2) @as(f32, 1.0) else if (ll < 0 and i >= NF / 2) @as(f32, 1.0) else @as(f32, 0.0)) +\r
-                (c_rand / c.RAND_MAX - 0.5) * 0.1) / (0.5 * NF);\r
-        }\r
-    }\r
-\r
-    {\r
-        // initial guess\r
-        const x = c.ggml_set_f32(c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, NF), 0.0);\r
-\r
-        c.ggml_set_param(ctx0, x);\r
-\r
-        // f = sum[(fj*x - l)^2]/n + lambda*|x^2|\r
-        const f =\r
-            c.ggml_add(ctx0, c.ggml_div(ctx0, c.ggml_sum(ctx0, c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_mul_mat(ctx0, F, x), l))), c.ggml_new_f32(ctx0, @as(f32, NP))), c.ggml_mul(ctx0, c.ggml_sum(ctx0, c.ggml_sqr(ctx0, x)), lambda));\r
-\r
-        const res = c.ggml_opt(null, opt_params, f);\r
-\r
-        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);\r
-\r
-        const x_data_pointer: [*]f32 = @ptrCast(@alignCast(x.*.data));\r
-        // print results\r
-        for (0..16) |i| {\r
-            std.debug.print("x[{d:3}] = {d:.6}\n", .{ i, x_data_pointer[i] });\r
-        }\r
-        std.debug.print("...\n", .{});\r
-        for (NF - 16..NF) |i| {\r
-            std.debug.print("x[{d:3}] = {d:.6}\n", .{ i, x_data_pointer[i] });\r
-        }\r
-        std.debug.print("\n", .{});\r
-\r
-        for (0..NF) |i| {\r
-            if (i < NF / 2) {\r
-                try std.testing.expect(is_close(x_data_pointer[i], 1.0, 1e-2));\r
-            } else {\r
-                try std.testing.expect(is_close(x_data_pointer[i], -1.0, 1e-2));\r
-            }\r
-        }\r
-    }\r
-\r
-    _ = try std.io.getStdIn().reader().readByte();\r
-}\r
+const std = @import("std");
+const Thread = std.Thread;
+const c = @cImport({
+    @cInclude("stdlib.h");
+    @cInclude("ggml.h");
+});
+
+fn is_close(a: f32, b: f32, epsilon: f32) bool {
+    return @abs(a - b) < epsilon;
+}
+
+pub fn main() !void {
+    const params = .{
+        .mem_size = 128 * 1024 * 1024,
+        .mem_buffer = null,
+        .no_alloc = false,
+    };
+
+    var opt_params = c.ggml_opt_default_params(c.GGML_OPT_TYPE_LBFGS);
+
+    const nthreads = try Thread.getCpuCount();
+    opt_params.n_threads = @intCast(nthreads);
+
+    const NP = 1 << 12;
+    const NF = 1 << 8;
+
+    const ctx0 = c.ggml_init(params);
+    defer c.ggml_free(ctx0);
+
+    const F = c.ggml_new_tensor_2d(ctx0, c.GGML_TYPE_F32, NF, NP);
+    const l = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, NP);
+
+    // regularization weight
+    const lambda = c.ggml_new_f32(ctx0, 1e-5);
+
+    c.srand(0);
+
+    const l_data_pointer: [*]f32 = @ptrCast(@alignCast(l.*.data));
+    const f_data_pointer: [*]f32 = @ptrCast(@alignCast(F.*.data));
+    for (0..NP) |j| {
+        const ll = if (j < NP / 2) @as(f32, 1.0) else @as(f32, -1.0);
+        l_data_pointer[j] = ll;
+
+        for (0..NF) |i| {
+            const c_rand: f32 = @floatFromInt(c.rand());
+            f_data_pointer[j * NF + i] =
+                ((if (ll > 0 and i < NF / 2) @as(f32, 1.0) else if (ll < 0 and i >= NF / 2) @as(f32, 1.0) else @as(f32, 0.0)) +
+                (c_rand / c.RAND_MAX - 0.5) * 0.1) / (0.5 * NF);
+        }
+    }
+
+    {
+        // initial guess
+        const x = c.ggml_set_f32(c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, NF), 0.0);
+
+        c.ggml_set_param(ctx0, x);
+
+        // f = sum[(fj*x - l)^2]/n + lambda*|x^2|
+        const f =
+            c.ggml_add(ctx0, c.ggml_div(ctx0, c.ggml_sum(ctx0, c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_mul_mat(ctx0, F, x), l))), c.ggml_new_f32(ctx0, @as(f32, NP))), c.ggml_mul(ctx0, c.ggml_sum(ctx0, c.ggml_sqr(ctx0, x)), lambda));
+
+        const res = c.ggml_opt(null, opt_params, f);
+
+        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);
+
+        const x_data_pointer: [*]f32 = @ptrCast(@alignCast(x.*.data));
+        // print results
+        for (0..16) |i| {
+            std.debug.print("x[{d:3}] = {d:.6}\n", .{ i, x_data_pointer[i] });
+        }
+        std.debug.print("...\n", .{});
+        for (NF - 16..NF) |i| {
+            std.debug.print("x[{d:3}] = {d:.6}\n", .{ i, x_data_pointer[i] });
+        }
+        std.debug.print("\n", .{});
+
+        for (0..NF) |i| {
+            if (i < NF / 2) {
+                try std.testing.expect(is_close(x_data_pointer[i], 1.0, 1e-2));
+            } else {
+                try std.testing.expect(is_close(x_data_pointer[i], -1.0, 1e-2));
+            }
+        }
+    }
+
+    _ = try std.io.getStdIn().reader().readByte();
+}