]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Nomic Vulkan backend (#4456)
authorJared Van Bortel <redacted>
Mon, 29 Jan 2024 20:50:50 +0000 (15:50 -0500)
committerGitHub <redacted>
Mon, 29 Jan 2024 20:50:50 +0000 (15:50 -0500)
Signed-off-by: Jared Van Bortel <redacted>
Co-authored-by: niansa <redacted>
Co-authored-by: Adam Treat <redacted>
Co-authored-by: Aaron Miller <redacted>
Co-authored-by: ToKiNoBug <redacted>
Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: slaren <redacted>
45 files changed:
.ecrc
.github/workflows/build.yml
.gitmodules [new file with mode: 0644]
CMakeLists.txt
ggml-backend.c
ggml-kompute.cpp [new file with mode: 0644]
ggml-kompute.h [new file with mode: 0644]
kompute [new submodule]
kompute-shaders/common.comp [new file with mode: 0644]
kompute-shaders/op_add.comp [new file with mode: 0644]
kompute-shaders/op_addrow.comp [new file with mode: 0644]
kompute-shaders/op_cpy_f16_f16.comp [new file with mode: 0644]
kompute-shaders/op_cpy_f16_f32.comp [new file with mode: 0644]
kompute-shaders/op_cpy_f32_f16.comp [new file with mode: 0644]
kompute-shaders/op_cpy_f32_f32.comp [new file with mode: 0644]
kompute-shaders/op_diagmask.comp [new file with mode: 0644]
kompute-shaders/op_gelu.comp [new file with mode: 0644]
kompute-shaders/op_getrows.comp [new file with mode: 0644]
kompute-shaders/op_getrows_f16.comp [new file with mode: 0644]
kompute-shaders/op_getrows_q4_0.comp [new file with mode: 0644]
kompute-shaders/op_getrows_q4_1.comp [new file with mode: 0644]
kompute-shaders/op_getrows_q6_k.comp [new file with mode: 0644]
kompute-shaders/op_mul.comp [new file with mode: 0644]
kompute-shaders/op_mul_mat_f16.comp [new file with mode: 0644]
kompute-shaders/op_mul_mat_mat_f32.comp [new file with mode: 0644]
kompute-shaders/op_mul_mat_q4_0.comp [new file with mode: 0644]
kompute-shaders/op_mul_mat_q4_1.comp [new file with mode: 0644]
kompute-shaders/op_mul_mat_q6_k.comp [new file with mode: 0644]
kompute-shaders/op_mul_mat_q8_0.comp [new file with mode: 0644]
kompute-shaders/op_mul_mv_q_n.comp [new file with mode: 0644]
kompute-shaders/op_mul_mv_q_n_pre.comp [new file with mode: 0644]
kompute-shaders/op_norm.comp [new file with mode: 0644]
kompute-shaders/op_relu.comp [new file with mode: 0644]
kompute-shaders/op_rmsnorm.comp [new file with mode: 0644]
kompute-shaders/op_rope_f16.comp [new file with mode: 0644]
kompute-shaders/op_rope_f32.comp [new file with mode: 0644]
kompute-shaders/op_scale.comp [new file with mode: 0644]
kompute-shaders/op_scale_8.comp [new file with mode: 0644]
kompute-shaders/op_silu.comp [new file with mode: 0644]
kompute-shaders/op_softmax.comp [new file with mode: 0644]
kompute-shaders/rope_common.comp [new file with mode: 0644]
llama.cpp
llama.h
tests/test-backend-ops.cpp
tests/test-c.c

diff --git a/.ecrc b/.ecrc
index b682057dd6891683a089bc17ce86ecfe6d94be0d..a3351f4e6442dfdf9f97280722730e26721fedb2 100644 (file)
--- a/.ecrc
+++ b/.ecrc
@@ -1,4 +1,5 @@
 {
+  "Exclude": ["^\\.gitmodules$"],
   "Disable": {
     "IndentSize": true
   }
index e5e435a70db133e33317da4bf7031cff2acfcf8e..fb719a5506c6619bc0be036fb841bba0553d73ed 100644 (file)
@@ -337,6 +337,7 @@ jobs:
       OPENCL_VERSION: 2023.04.17
       CLBLAST_VERSION: 1.6.0
       SDE_VERSION: 9.33.0-2024-01-07
+      VULKAN_VERSION: 1.3.261.1
 
     strategy:
       matrix:
@@ -353,6 +354,8 @@ jobs:
             defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_CLBLAST=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/clblast"'
           - build: 'openblas'
             defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_BLAS=ON -DBUILD_SHARED_LIBS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"'
+          - build: 'kompute'
+            defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON -DBUILD_SHARED_LIBS=ON'
 
     steps:
       - name: Clone
@@ -361,6 +364,12 @@ jobs:
         with:
           fetch-depth: 0
 
+      - name: Clone Kompute submodule
+        id: clone_kompute
+        if: ${{ matrix.build == 'kompute' }}
+        run: |
+          git submodule update --init kompute
+
       - name: Download OpenCL SDK
         id: get_opencl
         if: ${{ matrix.build == 'clblast' }}
@@ -395,6 +404,15 @@ jobs:
           $lib =  $(join-path $msvc 'bin\Hostx64\x64\lib.exe')
           & $lib /machine:x64 "/def:${env:RUNNER_TEMP}/openblas/lib/libopenblas.def" "/out:${env:RUNNER_TEMP}/openblas/lib/openblas.lib" /name:openblas.dll
 
+      - name: Install Vulkan SDK
+        id: get_vulkan
+        if: ${{ matrix.build == 'kompute' }}
+        run: |
+          curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe"
+          & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
+          Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}"
+          Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin"
+
       - name: Build
         id: cmake_build
         run: |
@@ -432,7 +450,8 @@ jobs:
 
       - name: Test
         id: cmake_test
-        if: ${{ matrix.build != 'clblast' && (matrix.build != 'avx512' || env.HAS_AVX512F == '1') }} # not all machines have native AVX-512
+        # not all machines have native AVX-512
+        if: ${{ matrix.build != 'clblast' && matrix.build != 'kompute' && (matrix.build != 'avx512' || env.HAS_AVX512F == '1') }}
         run: |
           cd build
           ctest -L main -C Release --verbose --timeout 900
diff --git a/.gitmodules b/.gitmodules
new file mode 100644 (file)
index 0000000..b7e8b8f
--- /dev/null
@@ -0,0 +1,3 @@
+[submodule "kompute"]
+       path = kompute
+       url = https://github.com/nomic-ai/kompute.git
index ed8f39c625b13a94b3a52065daef9049629bf4c2..65a6f3971a939cab150fa4e5a56b445af13ab7ce 100644 (file)
@@ -103,6 +103,7 @@ option(LLAMA_VULKAN                          "llama: use Vulkan"
 option(LLAMA_METAL                           "llama: use Metal"                                 ${LLAMA_METAL_DEFAULT})
 option(LLAMA_METAL_NDEBUG                    "llama: disable Metal debugging"                   OFF)
 option(LLAMA_METAL_SHADER_DEBUG              "llama: compile Metal with -fno-fast-math"         OFF)
+option(LLAMA_KOMPUTE                         "llama: use Kompute"                               OFF)
 option(LLAMA_MPI                             "llama: use MPI"                                   OFF)
 option(LLAMA_QKK_64                          "llama: use super-block size of 64 for k-quants"   OFF)
 option(LLAMA_SYCL                            "llama: use SYCL"                                  OFF)
@@ -484,7 +485,6 @@ if (LLAMA_HIPBLAS)
     endif()
 endif()
 
-
 if (LLAMA_SYCL)
     if ( NOT DEFINED ENV{ONEAPI_ROOT})
         message(FATAL_ERROR "Not detect ENV {ONEAPI_ROOT}, please install oneAPI & source it, like: source /opt/intel/oneapi/setvars.sh")
@@ -510,6 +510,160 @@ if (LLAMA_SYCL)
     set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
 endif()
 
+if (LLAMA_KOMPUTE)
+    add_compile_definitions(VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1)
+    find_package(Vulkan COMPONENTS glslc REQUIRED)
+    find_program(glslc_executable NAMES glslc HINTS Vulkan::glslc)
+    if (NOT glslc_executable)
+        message(FATAL_ERROR "glslc not found")
+    endif()
+
+    function(compile_shader)
+      set(options)
+      set(oneValueArgs)
+      set(multiValueArgs SOURCES)
+      cmake_parse_arguments(compile_shader "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+      foreach(source ${compile_shader_SOURCES})
+        get_filename_component(filename ${source} NAME)
+        set(spv_file ${filename}.spv)
+        add_custom_command(
+            OUTPUT ${spv_file}
+            DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${source}
+              ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/common.comp
+              ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp
+              ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp
+              ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp
+              COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source}
+            COMMENT "Compiling ${source} to ${spv_file}"
+        )
+
+        get_filename_component(RAW_FILE_NAME ${spv_file} NAME)
+        set(FILE_NAME "shader${RAW_FILE_NAME}")
+        string(REPLACE ".comp.spv" ".h" HEADER_FILE ${FILE_NAME})
+        string(TOUPPER ${HEADER_FILE} HEADER_FILE_DEFINE)
+        string(REPLACE "." "_" HEADER_FILE_DEFINE "${HEADER_FILE_DEFINE}")
+        set(OUTPUT_HEADER_FILE "${HEADER_FILE}")
+        message(STATUS "${HEADER_FILE} generating ${HEADER_FILE_DEFINE}")
+        if(CMAKE_GENERATOR MATCHES "Visual Studio")
+            add_custom_command(
+              OUTPUT ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+              DEPENDS ${spv_file} xxd
+              COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd"
+            )
+        else()
+            add_custom_command(
+              OUTPUT ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_BINARY_DIR}/bin/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
+              COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+              DEPENDS ${spv_file} xxd
+              COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/xxd"
+            )
+        endif()
+      endforeach()
+    endfunction()
+
+    if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
+        message(STATUS "Kompute found")
+        set(KOMPUTE_OPT_LOG_LEVEL Error CACHE STRING "Kompute log level")
+        add_subdirectory(kompute)
+
+        # Compile our shaders
+        compile_shader(SOURCES
+          kompute-shaders/op_scale.comp
+          kompute-shaders/op_scale_8.comp
+          kompute-shaders/op_add.comp
+          kompute-shaders/op_addrow.comp
+          kompute-shaders/op_mul.comp
+          kompute-shaders/op_silu.comp
+          kompute-shaders/op_relu.comp
+          kompute-shaders/op_gelu.comp
+          kompute-shaders/op_softmax.comp
+          kompute-shaders/op_norm.comp
+          kompute-shaders/op_rmsnorm.comp
+          kompute-shaders/op_diagmask.comp
+          kompute-shaders/op_mul_mat_mat_f32.comp
+          kompute-shaders/op_mul_mat_f16.comp
+          kompute-shaders/op_mul_mat_q8_0.comp
+          kompute-shaders/op_mul_mat_q4_0.comp
+          kompute-shaders/op_mul_mat_q4_1.comp
+          kompute-shaders/op_mul_mat_q6_k.comp
+          kompute-shaders/op_getrows_f16.comp
+          kompute-shaders/op_getrows_q4_0.comp
+          kompute-shaders/op_getrows_q4_1.comp
+          kompute-shaders/op_getrows_q6_k.comp
+          kompute-shaders/op_rope_f16.comp
+          kompute-shaders/op_rope_f32.comp
+          kompute-shaders/op_cpy_f16_f16.comp
+          kompute-shaders/op_cpy_f16_f32.comp
+          kompute-shaders/op_cpy_f32_f16.comp
+          kompute-shaders/op_cpy_f32_f32.comp
+        )
+
+        # Create a custom target for our generated shaders
+        add_custom_target(generated_shaders DEPENDS
+          shaderop_scale.h
+          shaderop_scale_8.h
+          shaderop_add.h
+          shaderop_addrow.h
+          shaderop_mul.h
+          shaderop_silu.h
+          shaderop_relu.h
+          shaderop_gelu.h
+          shaderop_softmax.h
+          shaderop_norm.h
+          shaderop_rmsnorm.h
+          shaderop_diagmask.h
+          shaderop_mul_mat_mat_f32.h
+          shaderop_mul_mat_f16.h
+          shaderop_mul_mat_q8_0.h
+          shaderop_mul_mat_q4_0.h
+          shaderop_mul_mat_q4_1.h
+          shaderop_mul_mat_q6_k.h
+          shaderop_getrows_f16.h
+          shaderop_getrows_q4_0.h
+          shaderop_getrows_q4_1.h
+          shaderop_getrows_q6_k.h
+          shaderop_rope_f16.h
+          shaderop_rope_f32.h
+          shaderop_cpy_f16_f16.h
+          shaderop_cpy_f16_f32.h
+          shaderop_cpy_f32_f16.h
+          shaderop_cpy_f32_f32.h
+        )
+
+        # Create a custom command that depends on the generated_shaders
+        add_custom_command(
+            OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
+            COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
+            DEPENDS generated_shaders
+            COMMENT "Ensuring shaders are generated before compiling ggml-kompute.cpp"
+        )
+
+        # Add the stamp to the main sources to ensure dependency tracking
+        set(GGML_SOURCES_KOMPUTE ggml-kompute.cpp ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp)
+        set(GGML_HEADERS_KOMPUTE ggml-kompute.h ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp)
+        add_compile_definitions(GGML_USE_KOMPUTE)
+        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} kompute)
+        set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${CMAKE_BINARY_DIR})
+    else()
+        message(WARNING "Kompute not found")
+    endif()
+endif()
+
 function(get_flags CCID CCVER)
     set(C_FLAGS "")
     set(CXX_FLAGS "")
@@ -852,13 +1006,14 @@ add_library(ggml OBJECT
             ggml-backend.h
             ggml-quants.c
             ggml-quants.h
-            ${GGML_SOURCES_CUDA}   ${GGML_HEADERS_CUDA}
-            ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL}
-            ${GGML_SOURCES_VULKAN} ${GGML_HEADERS_VULKAN}
-            ${GGML_SOURCES_METAL}  ${GGML_HEADERS_METAL}
-            ${GGML_SOURCES_MPI}    ${GGML_HEADERS_MPI}
-            ${GGML_SOURCES_EXTRA}  ${GGML_HEADERS_EXTRA}
-            ${GGML_SOURCES_SYCL}   ${GGML_HEADERS_SYCL}
+            ${GGML_SOURCES_CUDA}    ${GGML_HEADERS_CUDA}
+            ${GGML_SOURCES_OPENCL}  ${GGML_HEADERS_OPENCL}
+            ${GGML_SOURCES_VULKAN}  ${GGML_HEADERS_VULKAN}
+            ${GGML_SOURCES_METAL}   ${GGML_HEADERS_METAL}
+            ${GGML_SOURCES_MPI}     ${GGML_HEADERS_MPI}
+            ${GGML_SOURCES_EXTRA}   ${GGML_HEADERS_EXTRA}
+            ${GGML_SOURCES_SYCL}    ${GGML_HEADERS_SYCL}
+            ${GGML_SOURCES_KOMPUTE} ${GGML_HEADERS_KOMPUTE}
             )
 
 target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})
index 8b6cf7c9f1e48ce9da3d068c6350ba2c74b50949..0764dfebca673647babc92e9f58abae085b28be9 100644 (file)
@@ -373,6 +373,11 @@ GGML_CALL static void ggml_backend_registry_init(void) {
     extern GGML_CALL int ggml_backend_vk_reg_devices(void);
     ggml_backend_vk_reg_devices();
 #endif
+
+#ifdef GGML_USE_KOMPUTE
+    extern GGML_CALL void ggml_backend_kompute_reg_devices(void);
+    ggml_backend_kompute_reg_devices();
+#endif
 }
 
 GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp
new file mode 100644 (file)
index 0000000..51c5af8
--- /dev/null
@@ -0,0 +1,1990 @@
+#include "ggml.h"
+#include "ggml-backend.h"
+#include "ggml-backend-impl.h"
+#include "ggml-kompute.h"
+
+// These are generated at build time by cmake custom command
+#include "shaderop_scale.h"
+#include "shaderop_scale_8.h"
+#include "shaderop_add.h"
+#include "shaderop_addrow.h"
+#include "shaderop_mul.h"
+#include "shaderop_silu.h"
+#include "shaderop_relu.h"
+#include "shaderop_gelu.h"
+#include "shaderop_softmax.h"
+#include "shaderop_norm.h"
+#include "shaderop_rmsnorm.h"
+#include "shaderop_diagmask.h"
+#include "shaderop_mul_mat_f16.h"
+#include "shaderop_mul_mat_q8_0.h"
+#include "shaderop_mul_mat_q4_0.h"
+#include "shaderop_mul_mat_q4_1.h"
+#include "shaderop_mul_mat_q6_k.h"
+#include "shaderop_mul_mat_mat_f32.h"
+#include "shaderop_getrows_f16.h"
+#include "shaderop_getrows_q4_0.h"
+#include "shaderop_getrows_q4_1.h"
+#include "shaderop_getrows_q6_k.h"
+#include "shaderop_rope_f16.h"
+#include "shaderop_rope_f32.h"
+#include "shaderop_cpy_f16_f16.h"
+#include "shaderop_cpy_f16_f32.h"
+#include "shaderop_cpy_f32_f16.h"
+#include "shaderop_cpy_f32_f32.h"
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cstdint>
+#include <cstdio>
+#include <cstring>
+#include <iostream>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include <kompute/Kompute.hpp>
+#include <vulkan/vulkan.hpp>
+
+#ifdef __linux__
+#include <cstdlib> // for setenv
+#endif
+
+#define QK4_0 32
+#define QR4_0 2
+#define QK4_1 32
+#define QK_NL 16
+
+typedef ggml_fp16_t half;
+
+static std::string ggml_kompute_format_name(int device) {
+    return "Kompute" + std::to_string(device);
+}
+
+struct ggml_kompute_context {
+    int device;
+    std::string name;
+    std::shared_ptr<vk::DescriptorPool> pool;
+
+    ggml_kompute_context(int device)
+        : device(device), name(ggml_kompute_format_name(device)) {}
+};
+
+// FIXME: It would be good to consolidate the kompute manager and the kompute context into one object
+// and consolidate the init functions and simplify object lifetime management. As it currently stands,
+// we *have* to have the kompute manager no matter what for device discovery, but the kompute context
+// is only created when a device is set and vulkan is explicitly turned on.
+static ggml_kompute_context *s_kompute_context = nullptr;
+
+class kompute_manager {
+    kp::Manager *s_mgr = nullptr;
+
+public:
+    kp::Manager *operator()() {
+        if (s_mgr && !s_mgr->hasInstance()) {
+            destroy();
+        }
+        if (!s_mgr) {
+            s_mgr = new kp::Manager;
+        }
+        return s_mgr;
+    }
+
+    void destroy() {
+        delete s_mgr;
+        s_mgr = nullptr;
+    }
+};
+
+static kompute_manager komputeManager;
+
+struct ggml_vk_memory {
+    void *data = nullptr;
+    size_t size = 0;
+    vk::DeviceMemory *primaryMemory = nullptr;
+    vk::Buffer *primaryBuffer = nullptr;
+    vk::DeviceMemory *stagingMemory = nullptr;
+    vk::Buffer *stagingBuffer = nullptr;
+};
+
+#ifdef __linux__
+__attribute__((constructor))
+static void enable_sam() {
+    setenv("RADV_PERFTEST", "sam", false);
+}
+#endif
+
+static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_device) {
+    vk::PhysicalDeviceFeatures availableFeatures;
+    physical_device.getFeatures(&availableFeatures);
+
+    if (!availableFeatures.shaderInt16)
+        return false;
+
+    vk::PhysicalDeviceVulkan11Features availableFeatures11;
+    vk::PhysicalDeviceVulkan12Features availableFeatures12;
+
+    availableFeatures11.pNext = &availableFeatures12;
+    availableFeatures12.pNext = nullptr;
+
+    vk::PhysicalDeviceFeatures2 features2;
+    features2.pNext = &availableFeatures11;
+
+    physical_device.getFeatures2(&features2);
+
+    if (!availableFeatures11.uniformAndStorageBuffer16BitAccess ||
+        !availableFeatures11.storageBuffer16BitAccess) {
+        return false;
+    }
+
+    if (!availableFeatures12.storageBuffer8BitAccess ||
+        !availableFeatures12.uniformAndStorageBuffer8BitAccess ||
+        !availableFeatures12.shaderFloat16 ||
+        !availableFeatures12.shaderInt8) {
+        return false;
+    }
+
+    return true;
+}
+
+static const char * ggml_vk_getVendorName(uint32_t vendorID) {
+    switch (vendorID) {
+        case 0x10DE:
+            return "nvidia";
+        case 0x1002:
+            return "amd";
+        case 0x8086:
+            return "intel";
+        default:
+            return "unknown";
+    }
+}
+
+static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t memoryRequired) {
+    std::vector<ggml_vk_device> results;
+    if (!komputeManager()->hasVulkan() || !komputeManager()->hasInstance())
+        return results;
+
+    std::vector<vk::PhysicalDevice> physical_devices;
+    try {
+        physical_devices = komputeManager()->listDevices();
+    } catch (vk::SystemError & err) {
+        std::cerr << __func__ << ": ignoring Vulkan exception: " << err.what() << "\n";
+        return results;
+    }
+
+    uint32_t deviceCount = physical_devices.size();
+    if (deviceCount == 0)
+        return results;
+
+    std::unordered_map<std::string, size_t> count_by_name;
+
+    for (uint32_t i = 0; i < deviceCount; i++) {
+        const auto & physical_device = physical_devices[i];
+
+        VkPhysicalDeviceProperties dev_props = physical_device.getProperties();
+        VkPhysicalDeviceMemoryProperties memoryProperties = physical_device.getMemoryProperties();
+        const uint32_t major = VK_VERSION_MAJOR(dev_props.apiVersion);
+        const uint32_t minor = VK_VERSION_MINOR(dev_props.apiVersion);
+        if (major < 1 || minor < 2)
+            continue;
+
+        if (!ggml_vk_checkPhysicalDeviceFeatures(physical_device))
+            continue;
+
+        size_t heapSize = 0;
+        for (uint32_t j = 0; j < memoryProperties.memoryHeapCount; ++j) {
+            VkMemoryHeap heap = memoryProperties.memoryHeaps[j];
+            if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) {
+                heapSize = heap.size;
+                break;
+            }
+        }
+
+        if (heapSize < memoryRequired)
+            continue;
+
+        auto ext_props = physical_device.enumerateDeviceExtensionProperties();
+        bool has_maintenance4 = false;
+
+        // Check if maintenance4 is supported
+        for (const auto & properties : ext_props) {
+            if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
+                has_maintenance4 = true;
+            }
+        }
+
+        vk::PhysicalDeviceSubgroupProperties subgroup_props;
+        vk::PhysicalDeviceProperties2 dev_props2;
+        vk::PhysicalDeviceMaintenance3Properties dev_props3;
+        vk::PhysicalDeviceMaintenance4Properties dev_props4;
+        dev_props2.pNext = &dev_props3;
+        dev_props3.pNext = &subgroup_props;
+        if (has_maintenance4) {
+            subgroup_props.pNext = &dev_props4;
+        }
+        physical_device.getProperties2(&dev_props2);
+
+        if (subgroup_props.subgroupSize < 32)
+            continue;
+
+        ggml_vk_device d;
+        d.index = i;
+        d.type = dev_props.deviceType;
+        d.heapSize = heapSize;
+        d.vendor = strdup(ggml_vk_getVendorName(dev_props.vendorID));
+        d.subgroupSize = subgroup_props.subgroupSize;
+        d.bufferAlignment = dev_props.limits.minStorageBufferOffsetAlignment;
+
+        if (has_maintenance4) {
+            d.maxAlloc = std::min(dev_props3.maxMemoryAllocationSize, dev_props4.maxBufferSize);
+        } else {
+            d.maxAlloc = dev_props3.maxMemoryAllocationSize;
+        }
+
+        std::string name(dev_props.deviceName);
+        size_t n_idx = ++count_by_name[name];
+        if (n_idx > 1) {
+            name += " (" + std::to_string(n_idx) + ")";
+        }
+        d.name = strdup(name.c_str());
+
+        results.push_back(d);
+    }
+
+    std::stable_sort(results.begin(), results.end(),
+        [](const ggml_vk_device& lhs, const ggml_vk_device& rhs) -> bool {
+            if (lhs.type != rhs.type) {
+                if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return true;
+                if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return false;
+
+                if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return true;
+                if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return false;
+            }
+            return lhs.heapSize < rhs.heapSize;
+        }
+    );
+
+    return results;
+}
+
+// public API returns a C-style array
+ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count) {
+    auto devices = ggml_vk_available_devices_internal(memoryRequired);
+    *count = devices.size();
+    if (devices.empty()) {
+        return nullptr;
+    }
+
+    size_t nbytes = sizeof (ggml_vk_device) * (devices.size());
+    auto * arr = static_cast<ggml_vk_device *>(malloc(nbytes));
+    memcpy(arr, devices.data(), nbytes);
+    return arr;
+}
+
+static void ggml_vk_filterByVendor(std::vector<ggml_vk_device>& devices, const std::string& targetVendor) {
+    devices.erase(
+        std::remove_if(devices.begin(), devices.end(),
+            [&targetVendor](const ggml_vk_device& device) {
+                return device.vendor != targetVendor;
+            }),
+        devices.end()
+    );
+}
+
+static void ggml_vk_filterByName(std::vector<ggml_vk_device>& devices, const std::string& targetName) {
+    devices.erase(
+        std::remove_if(devices.begin(), devices.end(),
+            [&targetName](const ggml_vk_device& device) {
+                return device.name != targetName;
+            }),
+        devices.end()
+    );
+}
+
+static bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const std::string & name) {
+    if (name.empty())
+        return false;
+
+    auto devices = ggml_vk_available_devices_internal(memoryRequired);
+    if (name == "amd" || name == "nvidia" || name == "intel") {
+        ggml_vk_filterByVendor(devices, name);
+    } else if (name != "gpu") {
+        ggml_vk_filterByName(devices, name);
+    }
+
+    if (devices.empty())
+        return false;
+
+    *device = devices.front();
+    return true;
+}
+
+bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const char * name) {
+    return ggml_vk_get_device(device, memoryRequired, std::string(name));
+}
+
+bool ggml_vk_has_vulkan() {
+    return komputeManager()->hasVulkan();
+}
+
+bool ggml_vk_has_device() {
+    return komputeManager()->hasDevice();
+}
+
+ggml_vk_device ggml_vk_current_device() {
+    if (!komputeManager()->hasDevice())
+        return ggml_vk_device();
+
+    auto devices = ggml_vk_available_devices_internal(0);
+    ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data());
+    GGML_ASSERT(!devices.empty());
+    return devices.front();
+}
+
+static
+void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t size) {
+    std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
+        vk::DescriptorPoolSize(
+          vk::DescriptorType::eStorageBuffer,
+          3 * size // Descriptor count is number of possible tensors to pass into an algorithm
+          )
+    };
+
+    vk::DescriptorPoolCreateInfo descriptorPoolInfo(
+      vk::DescriptorPoolCreateFlags(),
+      size, // Max sets
+      static_cast<uint32_t>(descriptorPoolSizes.size()),
+      descriptorPoolSizes.data());
+
+    ctx->pool = std::make_shared<vk::DescriptorPool>();
+    vk::Result r = komputeManager()->device()->createDescriptorPool(
+      &descriptorPoolInfo, nullptr, ctx->pool.get());
+    if (r != vk::Result::eSuccess)
+        std::cerr << "Error allocating descriptor pool" << vk::to_string(r);
+}
+
+static
+void ggml_vk_free_descriptor_pool(struct ggml_kompute_context * ctx) {
+    if (ctx->pool) {
+        komputeManager()->device()->destroy(
+          *ctx->pool,
+          (vk::Optional<const vk::AllocationCallbacks>)nullptr);
+        ctx->pool = nullptr;
+    }
+}
+
+static
+vk::Buffer *ggml_vk_allocate_buffer(size_t size) {
+    vk::BufferCreateInfo bufferCreateInfo;
+    bufferCreateInfo.size = size;
+    bufferCreateInfo.usage = vk::BufferUsageFlagBits::eStorageBuffer |
+                             vk::BufferUsageFlagBits::eTransferSrc |
+                             vk::BufferUsageFlagBits::eTransferDst;
+    bufferCreateInfo.sharingMode = vk::SharingMode::eExclusive;
+
+    vk::Buffer *vkBuffer = new vk::Buffer;
+    vk::Result r = komputeManager()->device()->createBuffer(&bufferCreateInfo, nullptr, vkBuffer);
+    if (r != vk::Result::eSuccess)
+        std::cerr << "Error allocating buffer " << vk::to_string(r) << std::endl;
+    return vkBuffer;
+}
+
+static
+vk::DeviceMemory *ggml_vk_allocate(size_t size, vk::MemoryPropertyFlags flags, vk::MemoryRequirements requirements, bool *isHostVisible) {
+
+    uint32_t memoryTypeIndex = -1;
+    bool memoryTypeIndexFound = false;
+    vk::PhysicalDeviceMemoryProperties memoryProperties = komputeManager()->physicalDevice()->getMemoryProperties();
+    for (uint32_t i = 0; i < memoryProperties.memoryTypeCount; i++) {
+        const vk::MemoryType &memoryType = memoryProperties.memoryTypes[i];
+        const vk::MemoryHeap &memoryHeap = memoryProperties.memoryHeaps[memoryType.heapIndex];
+        if (memoryHeap.size < size) {
+            continue;
+        }
+
+        if (requirements.memoryTypeBits & (1 << i)) {
+            if (((memoryProperties.memoryTypes[i]).propertyFlags &
+                 flags) == flags) {
+                memoryTypeIndex = i;
+                memoryTypeIndexFound = true;
+                if (isHostVisible && (memoryProperties.memoryTypes[i].propertyFlags & vk::MemoryPropertyFlagBits::eHostVisible)) {
+                    *isHostVisible = true;
+                }
+                break;
+            }
+        }
+    }
+    if (!memoryTypeIndexFound) {
+        throw std::runtime_error(
+          "Memory type index for buffer creation not found");
+    }
+
+    vk::MemoryAllocateInfo allocInfo;
+    allocInfo.allocationSize = size;
+    allocInfo.memoryTypeIndex = memoryTypeIndex;
+    vk::DeviceMemory *vkDeviceMemory =  new vk::DeviceMemory;
+    vk::Result r = komputeManager()->device()->allocateMemory(&allocInfo, nullptr, vkDeviceMemory);
+    if (r != vk::Result::eSuccess) {
+        std::cerr << "Error allocating memory " << vk::to_string(r) << std::endl;
+        throw std::runtime_error("Error allocating vulkan memory.");
+    }
+    return vkDeviceMemory;
+}
+
+static size_t ggml_vk_aligned_offset(ggml_backend_buffer_t buffer, size_t offset) {
+    size_t minStorageBufferOffsetAlignment = ggml_backend_buffer_get_alignment(buffer);
+
+    // If offset is already aligned, return it directly
+    if (offset % minStorageBufferOffsetAlignment == 0) {
+        return offset;
+    }
+
+    // Otherwise, return the largest multiple of minStorageBufferOffsetAlignment less than offset
+    return (offset / minStorageBufferOffsetAlignment) * minStorageBufferOffsetAlignment;
+}
+
+static ggml_vk_memory ggml_vk_allocate(size_t size) {
+    ggml_vk_memory memory;
+    bool isHostVisible = false;
+    {
+        memory.primaryBuffer = ggml_vk_allocate_buffer(size);
+        vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.primaryBuffer);
+        vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eDeviceLocal;
+        memory.primaryMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
+        komputeManager()->device()->bindBufferMemory(*memory.primaryBuffer, *memory.primaryMemory, 0);
+        if (isHostVisible) {
+            vk::Result r = komputeManager()->device()->mapMemory(*memory.primaryMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
+            if (r != vk::Result::eSuccess)
+                std::cerr << "Error mapping memory" << vk::to_string(r);
+        }
+    }
+
+    if (!isHostVisible) {
+        memory.stagingBuffer = ggml_vk_allocate_buffer(size);
+        vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.stagingBuffer);
+        vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eHostVisible |
+                                                      vk::MemoryPropertyFlagBits::eHostCoherent |
+                                                      vk::MemoryPropertyFlagBits::eHostCached;
+        memory.stagingMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
+        komputeManager()->device()->bindBufferMemory(*memory.stagingBuffer, *memory.stagingMemory, 0);
+        vk::Result r = komputeManager()->device()->mapMemory(*memory.stagingMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
+        if (r != vk::Result::eSuccess)
+            std::cerr << "Error mapping memory" << vk::to_string(r);
+    }
+
+    memory.size = size;
+    return memory;
+}
+
+static void ggml_vk_free_memory(ggml_vk_memory &memory)
+{
+    komputeManager()->device()->destroy(
+      *memory.primaryBuffer,
+      (vk::Optional<const vk::AllocationCallbacks>)nullptr);
+    if (memory.stagingBuffer) {
+        komputeManager()->device()->destroy(
+          *memory.stagingBuffer,
+          (vk::Optional<const vk::AllocationCallbacks>)nullptr);
+    }
+    komputeManager()->device()->freeMemory(
+      *memory.primaryMemory,
+      (vk::Optional<const vk::AllocationCallbacks>)nullptr);
+    if (memory.stagingMemory) {
+        komputeManager()->device()->freeMemory(
+          *memory.stagingMemory,
+          (vk::Optional<const vk::AllocationCallbacks>)nullptr);
+    }
+}
+
+static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft);
+
+static
+ggml_vk_memory * ggml_vk_find_tensor(const struct ggml_tensor * t, uint64_t & offset) {
+    ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
+
+    // compatibility with ggml-backend
+    GGML_ASSERT(buffer && buffer->buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name);
+
+    ggml_vk_memory * buf_ctx = static_cast<ggml_vk_memory *>(buffer->context);
+
+    const intptr_t ioffs = intptr_t(t->data) - intptr_t(buf_ctx->data);
+
+    GGML_ASSERT(ioffs >= 0 && ioffs + int64_t(ggml_nbytes(t)) <= int64_t(buffer->size));
+
+    offset = uint64_t(ioffs);
+    return buf_ctx;
+}
+
+static
+const std::shared_ptr<kp::Tensor> ggml_vk_get_tensor(const struct ggml_tensor * t, uint32_t * alignedOffset = nullptr) {
+    uint64_t originalOffset = 0;
+    auto * res = ggml_vk_find_tensor(t, originalOffset);
+    if (!res) {
+        static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
+        return nullTensor;
+    }
+
+    // Create a tensor whose memory will be composed of our buffers at the correct offset
+    const size_t nelements = ggml_nelements(t);
+    size_t nbytes = ggml_nbytes(t);
+
+    size_t vulkanOffset = ggml_vk_aligned_offset(t->buffer, originalOffset);
+    if (alignedOffset) {
+        *alignedOffset = originalOffset - vulkanOffset;
+        nbytes += *alignedOffset;
+    }
+
+    return komputeManager()->tensor(
+        t->data,
+        nelements,
+        nbytes, kp::Tensor::TensorDataTypes::eFloat,
+        res->primaryMemory, res->primaryBuffer,
+        res->stagingMemory, res->stagingBuffer,
+        vulkanOffset);
+}
+
+static std::vector<uint32_t> getSpirvShader(const unsigned char* rawData, size_t size) {
+    if (size % sizeof(uint32_t) != 0) {
+        throw std::runtime_error("Invalid size: must be divisible by sizeof(uint32_t)");
+    }
+
+    const uint32_t* data_ptr = reinterpret_cast<const uint32_t*>(rawData);
+    size_t count = size / sizeof(uint32_t);
+    return std::vector<uint32_t>(data_ptr, data_ptr + count);
+}
+
+inline static
+uint32_t safe_divide(uint32_t a, uint32_t b) {
+    if (b <= 1) {
+        return a;
+    }
+    if ((a % b) != 0) {
+        fprintf(stderr, "((%u %% %u) == %u) != 0\n", a, b, a % b);
+        GGML_ASSERT(!"safe_divide result would've had remainder");
+    }
+    return a / b;
+}
+
+static void ggml_vk_add(
+    kp::Sequence& seq,
+    const std::shared_ptr<kp::Tensor>& inA,
+    const std::shared_ptr<kp::Tensor>& inB,
+    const std::shared_ptr<kp::Tensor>& out,
+    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
+    int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
+    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+    int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
+    int32_t ne0,
+    int32_t nb0,  int32_t nb1,  int32_t nb2,  int32_t nb3
+) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_add_comp_spv,
+        kp::shader_data::op_add_comp_spv_len);
+
+    struct PushConstants {
+        uint32_t inAOff, inBOff, outOff;
+        int32_t ne00;
+        int32_t nb00, nb01, nb02, nb03;
+        int32_t ne10, ne11, ne12, ne13;
+        int32_t nb10, nb11, nb12, nb13;
+        int32_t ne0;
+        int32_t nb0, nb1, nb2, nb3;
+    } const pushConsts {
+        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+        ne00,
+        nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, ne13,
+        nb10, nb11, nb12, nb13,
+        ne0,
+        nb0, nb1, nb2, nb3
+    };
+
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(__func__)) {
+        s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
+    } else {
+        s_algo = komputeManager()->getAlgorithm(__func__);
+        s_algo->setTensors({inA, inB, out});
+        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_addrow(kp::Sequence& seq,
+                 const std::shared_ptr<kp::Tensor>& inA,
+                 const std::shared_ptr<kp::Tensor>& inB,
+                 const std::shared_ptr<kp::Tensor>& out,
+                 uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+                 uint32_t size, uint32_t row = 0) {
+
+    const static auto spirv = getSpirvShader(kp::shader_data::op_addrow_comp_spv,
+        kp::shader_data::op_addrow_comp_spv_len);
+
+    struct PushConstants {
+        uint32_t inAOff, inBOff, outOff;
+        uint32_t row;
+    } const pushConsts {
+        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+        row
+    };
+
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(__func__))
+        s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
+    else {
+        s_algo = komputeManager()->getAlgorithm(__func__);
+        s_algo->setTensors({inA, inB, out});
+        s_algo->setWorkgroup({size});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_mul(
+    kp::Sequence& seq,
+    const std::shared_ptr<kp::Tensor>& inA,
+    const std::shared_ptr<kp::Tensor>& inB,
+    const std::shared_ptr<kp::Tensor>& out,
+    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
+    int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
+    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+    int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
+    int32_t ne0,
+    int32_t nb0,  int32_t nb1,  int32_t nb2,  int32_t nb3
+) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_comp_spv,
+        kp::shader_data::op_mul_comp_spv_len);
+
+    struct PushConstants {
+        uint32_t inAOff, inBOff, outOff;
+        int32_t ne00;
+        int32_t nb00, nb01, nb02, nb03;
+        int32_t ne10, ne11, ne12, ne13;
+        int32_t nb10, nb11, nb12, nb13;
+        int32_t ne0;
+        int32_t nb0, nb1, nb2, nb3;
+    } const pushConsts {
+        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+        ne00,
+        nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, ne13,
+        nb10, nb11, nb12, nb13,
+        ne0,
+        nb0, nb1, nb2, nb3
+    };
+
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(__func__)) {
+        s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
+    } else {
+        s_algo = komputeManager()->getAlgorithm(__func__);
+        s_algo->setTensors({inA, inB, out});
+        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_scale(kp::Sequence& seq,
+                   const std::shared_ptr<kp::Tensor>& in,
+                   const std::shared_ptr<kp::Tensor>& out,
+                   uint32_t inOff, uint32_t outOff,
+                   uint32_t size, float scale) {
+    const static auto spirv_1 = getSpirvShader(
+        kp::shader_data::op_scale_comp_spv, kp::shader_data::op_scale_comp_spv_len
+    );
+    const static auto spirv_8 = getSpirvShader(
+        kp::shader_data::op_scale_8_comp_spv, kp::shader_data::op_scale_8_comp_spv_len
+    );
+
+    struct PushConstants {
+        uint32_t inOff, outOff;
+        float scale;
+    } const pushConsts {
+        safe_divide(inOff, 4), safe_divide(outOff, 4),
+        scale
+    };
+
+    const auto * spirv = &spirv_1;
+    std::string name(__func__);
+    if (size % 8 == 0) {
+        size /= 8;
+        name += "_8";
+        spirv = &spirv_8;
+    }
+
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(name)) {
+        s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, *spirv, {size}, {}, {pushConsts});
+    } else {
+        s_algo = komputeManager()->getAlgorithm(name);
+        s_algo->setTensors({in, out});
+        s_algo->setWorkgroup({size});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_xxlu(
+    const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
+    const std::shared_ptr<kp::Tensor>& in,
+    const std::shared_ptr<kp::Tensor>& out,
+    uint32_t inOff, uint32_t outOff,
+    uint32_t size
+) {
+    struct PushConstants {
+        uint32_t inOff, outOff;
+    } const pushConsts {
+        safe_divide(inOff, 4), safe_divide(outOff, 4),
+    };
+
+    auto name = std::string(__func__) + "_" + suffix;
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(name)) {
+        s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts});
+    } else {
+        s_algo = komputeManager()->getAlgorithm(name);
+        s_algo->setTensors({in, out});
+        s_algo->setWorkgroup({size});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+template <typename... Args>
+static void ggml_vk_silu(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_silu_comp_spv,
+        kp::shader_data::op_silu_comp_spv_len);
+
+    ggml_vk_xxlu(spirv, "silu", std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_relu(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_relu_comp_spv,
+        kp::shader_data::op_relu_comp_spv_len);
+
+    ggml_vk_xxlu(spirv, "relu", std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_gelu(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_gelu_comp_spv,
+        kp::shader_data::op_gelu_comp_spv_len);
+
+    ggml_vk_xxlu(spirv, "gelu", std::forward<Args>(args)...);
+}
+
+static void ggml_vk_soft_max(
+    kp::Sequence& seq,
+    const std::shared_ptr<kp::Tensor>& inA,
+    const std::shared_ptr<kp::Tensor>& inB,
+    const std::shared_ptr<kp::Tensor>& out,
+    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+    int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
+    float scale
+) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
+        kp::shader_data::op_softmax_comp_spv_len);
+
+    struct PushConstants {
+        uint32_t inAOff, inBOff, outOff;
+        int32_t ne00, ne01, ne02;
+        float scale;
+        int32_t mask;
+    } pushConsts {
+        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+        ne00, ne01, ne02,
+        scale,
+        bool(inB)
+    };
+
+    auto & inB_ = inB ? inB : inA;
+
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(__func__)) {
+        // FIXME: The softmax kernel needs to be fixed to use the subgroupsize which can vary by device
+        const uint32_t local_x = 32;
+        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB_, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts});
+    } else {
+        s_algo = komputeManager()->getAlgorithm(__func__);
+        s_algo->setTensors({inA, inB_, out});
+        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_norm_(
+    const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
+    const std::shared_ptr<kp::Tensor>& in,
+    const std::shared_ptr<kp::Tensor>& out,
+    uint32_t inOff, uint32_t outOff,
+    int32_t ne00, int32_t nb01,
+    int32_t nrows, float epsilon
+) {
+    GGML_ASSERT(nb01%sizeof(float) == 0);
+    GGML_ASSERT(ne00%sizeof(float) == 0);
+
+    struct PushConstants {
+        uint32_t inOff, outOff;
+        uint32_t ne00, nb01;
+        float eps;
+    } pushConsts {
+        safe_divide(inOff, 4), safe_divide(outOff, 4),
+        (uint32_t)ne00, (uint32_t)nb01, epsilon
+    };
+
+    auto name = std::string(__func__) + "_" + suffix;
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(name)) {
+        s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {(uint32_t)nrows}, {}, {pushConsts});
+    } else {
+        s_algo = komputeManager()->getAlgorithm(name);
+        s_algo->setTensors({in, out});
+        s_algo->setWorkgroup({(uint32_t)nrows});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+template <typename... Args>
+static void ggml_vk_norm(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_norm_comp_spv,
+        kp::shader_data::op_norm_comp_spv_len);
+
+    ggml_vk_norm_(spirv, "norm", std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_rms_norm(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_rmsnorm_comp_spv,
+        kp::shader_data::op_rmsnorm_comp_spv_len);
+
+    ggml_vk_norm_(spirv, "rms", std::forward<Args>(args)...);
+}
+
+static void ggml_vk_diag_mask_inf(kp::Sequence& seq,
+                           const std::shared_ptr<kp::Tensor>& in,
+                           const std::shared_ptr<kp::Tensor>& out,
+                           uint32_t inOff, uint32_t outOff,
+                           uint32_t n_past,
+                           int32_t ne00, int32_t ne01, int32_t ne02) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_diagmask_comp_spv,
+        kp::shader_data::op_diagmask_comp_spv_len);
+
+    struct PushConstants {
+        uint32_t inOff, outOff;
+        uint32_t n_past;
+        int32_t ne00, ne01;
+    } pushConsts {
+        safe_divide(inOff, 4), safe_divide(outOff, 4),
+        n_past,
+        ne00, ne01
+    };
+
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(__func__))
+        s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne00), unsigned(ne01), unsigned(ne02)}, {}, {pushConsts});
+    else {
+        s_algo = komputeManager()->getAlgorithm(__func__);
+        s_algo->setTensors({in, out});
+        s_algo->setWorkgroup({unsigned(ne00), unsigned(ne01), unsigned(ne02)});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_mul_mat_f16(
+    kp::Sequence& seq,
+    const std::shared_ptr<kp::Tensor>& inA,
+    const std::shared_ptr<kp::Tensor>& inB,
+    const std::shared_ptr<kp::Tensor>& out,
+    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+    int32_t ne00, int32_t ne01, int32_t ne02,
+    uint32_t nb00, uint32_t nb01, uint32_t nb02,
+    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+    uint32_t nb10, uint32_t nb11, uint32_t nb12,
+    int32_t ne0, int32_t ne1,
+    uint32_t r2, uint32_t r3
+) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_f16_comp_spv,
+        kp::shader_data::op_mul_mat_f16_comp_spv_len);
+
+    struct PushConstants {
+        uint32_t inAOff, inBOff, outOff;
+        int32_t ne00, ne01, ne02;
+        uint32_t nb00, nb01, nb02;
+        int32_t ne10, ne11, ne12;
+        uint32_t nb10, nb11, nb12;
+        int32_t ne0, ne1;
+        uint32_t r2, r3;
+    } pushConsts {
+        safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+        ne00, ne01, ne02,
+        nb00, nb01, nb02,
+        ne10, ne11, ne12,
+        nb10, nb11, nb12,
+        ne0, ne1,
+        r2, r3
+    };
+
+    const unsigned ny = unsigned((ne11 + 4 - 1)/4);
+
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(__func__)) {
+        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
+        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), ny, unsigned(ne12*ne13)}, {local_x}, {pushConsts});
+    } else {
+        s_algo = komputeManager()->getAlgorithm(__func__);
+        s_algo->setTensors({inA, inB, out});
+        s_algo->setWorkgroup({unsigned(ne01), ny, unsigned(ne12*ne13)});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq,
+                         const std::shared_ptr<kp::Tensor>& inA,
+                         const std::shared_ptr<kp::Tensor>& inB,
+                         const std::shared_ptr<kp::Tensor>& out,
+                         uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+                         int32_t ne00, int32_t ne01, int32_t ne02,
+                         uint32_t nb01, uint32_t nb02,
+                         int32_t ne11, int32_t ne12,
+                         uint32_t nb11, uint32_t nb12,
+                         uint32_t nb1, uint32_t nb2) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_mat_f32_comp_spv,
+        kp::shader_data::op_mul_mat_mat_f32_comp_spv_len);
+
+    struct PushConstants {
+        uint32_t inAOff, inBOff, outOff;
+        int32_t ne00, ne01, ne02, ne11, ne12;
+        uint32_t nb01, nb02;
+        uint32_t nb11, nb12;
+        uint32_t nb1, nb2;
+    } pushConsts {
+        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+        ne00, ne01, ne02, ne11, ne12,
+        nb01, nb02, nb11, nb12,
+        nb1, nb2
+    };
+
+    const uint32_t local_x = ggml_vk_current_device().subgroupSize;
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(__func__)) {
+        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(),
+        {inA, inB, out}, spirv,
+        {unsigned(ne01),
+         unsigned(ne11),
+         unsigned(std::max(ne12, ne02))
+         },
+        {local_x},
+        {pushConsts});
+    } else {
+        s_algo = komputeManager()->getAlgorithm(__func__);
+        s_algo->setTensors({inA, inB, out});
+        s_algo->setWorkgroup({unsigned(ne01),
+                              unsigned(ne11),
+                              unsigned(std::max(ne12, ne02)),
+                              });
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_mul_mat_impl(
+    const std::vector<uint32_t>& spirv, const char * suffix, uint32_t block_size, kp::Sequence& seq,
+    const std::shared_ptr<kp::Tensor>& inA,
+    const std::shared_ptr<kp::Tensor>& inB,
+    const std::shared_ptr<kp::Tensor>& out,
+    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+    int32_t ne00, int32_t ne01, int32_t ne02,
+    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+    int32_t ne0, int32_t ne1,
+    uint32_t r2, uint32_t r3
+) {
+    struct PushConstants {
+        uint32_t inAOff, inBOff, outOff;
+        int32_t ne00, ne01, ne02;
+        int32_t ne10, ne12;
+        int32_t ne0, ne1;
+        uint32_t r2, r3;
+    } pushConsts {
+        safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+        ne00, ne01, ne02,
+        ne10, ne12,
+        ne0, ne1,
+        r2, r3
+    };
+
+    auto name = std::string(__func__) + "_" + suffix;
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(name)) {
+        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
+        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
+    } else {
+        s_algo = komputeManager()->getAlgorithm(name);
+        s_algo->setTensors({inA, inB, out});
+        s_algo->setWorkgroup({unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+template <typename... Args>
+static void ggml_vk_mul_mat_q4_0(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_0_comp_spv,
+        kp::shader_data::op_mul_mat_q4_0_comp_spv_len);
+
+    ggml_vk_mul_mat_impl(spirv, "q4_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_mul_mat_q4_1(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_1_comp_spv,
+        kp::shader_data::op_mul_mat_q4_1_comp_spv_len);
+
+    ggml_vk_mul_mat_impl(spirv, "q4_1", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_mul_mat_q8_0(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q8_0_comp_spv,
+        kp::shader_data::op_mul_mat_q8_0_comp_spv_len);
+
+    ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
+}
+
+static void ggml_vk_mul_mat_q6_k(
+    kp::Sequence& seq,
+    const std::shared_ptr<kp::Tensor>& inA,
+    const std::shared_ptr<kp::Tensor>& inB,
+    const std::shared_ptr<kp::Tensor>& out,
+    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+    int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
+    int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
+) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
+        kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
+
+    struct PushConstants {
+        uint32_t inAOff, inBOff, outOff;
+        int32_t ne00, ne10, ne0, ne1, ne01, gqa;
+    } pushConsts {
+        inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
+        ne00, ne10, ne0, ne1, ne01, ne12/ne02
+    };
+
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(__func__)) {
+        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
+        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
+    } else {
+        s_algo = komputeManager()->getAlgorithm(__func__);
+        s_algo->setTensors({inA, inB, out});
+        s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_get_rows(
+    const std::vector<uint32_t>& spirv,
+    const char * suffix,
+    unsigned element_size, unsigned qk,
+    kp::Sequence& seq,
+    const std::shared_ptr<kp::Tensor>& inA,
+    const std::shared_ptr<kp::Tensor>& inB,
+    const std::shared_ptr<kp::Tensor>& out,
+    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+    int32_t ne00, int32_t nb01, int32_t nb1,
+    uint32_t size
+) {
+    GGML_ASSERT(nb01%element_size == 0);
+    GGML_ASSERT(nb1%sizeof(float) == 0);
+    if (qk) GGML_ASSERT(ne00%qk == 0);
+
+    struct PushConstants {
+        uint32_t inAOff, inBOff, outOff;
+        int32_t ne00, nb01, nb1;
+    } pushConsts {
+        safe_divide(inAOff, element_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+        ne00, nb01, nb1
+    };
+
+    auto name = std::string(__func__) + "_" + suffix;
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(name)) {
+        s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
+    } else {
+        s_algo = komputeManager()->getAlgorithm(name);
+        s_algo->setTensors({inA, inB, out});
+        s_algo->setWorkgroup({size});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+template <typename... Args>
+static void ggml_vk_get_rows_f16(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
+        kp::shader_data::op_getrows_f16_comp_spv_len);
+
+    ggml_vk_get_rows(spirv, "f16", sizeof(half), 0, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_get_rows_q4_0(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_0_comp_spv,
+        kp::shader_data::op_getrows_q4_0_comp_spv_len);
+
+    ggml_vk_get_rows(spirv, "q4_0", 1/*We access blocks unaligned*/, QK4_0, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_get_rows_q4_1(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_1_comp_spv,
+        kp::shader_data::op_getrows_q4_1_comp_spv_len);
+
+    ggml_vk_get_rows(spirv, "q4_1", 1/*We access blocks unaligned*/, QK4_1, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_get_rows_q6_k(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q6_k_comp_spv,
+        kp::shader_data::op_getrows_q6_k_comp_spv_len);
+    ggml_vk_get_rows(spirv, "q6_k", 1/*We access blocks unaligned*/, QK_NL, std::forward<Args>(args)...);
+}
+
+static void ggml_vk_rope(
+    kp::Sequence& seq,
+    const std::shared_ptr<kp::Tensor>& inA,
+    const std::shared_ptr<kp::Tensor>& inB,
+    const std::shared_ptr<kp::Tensor>& out,
+    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+    ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx,
+    float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
+    int32_t ne01, int32_t ne02, int32_t ne03,
+    uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
+    int32_t ne0,
+    uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
+) {
+    GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
+
+    static const auto spirv_f16 = getSpirvShader(
+        kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
+    );
+    static const auto spirv_f32 = getSpirvShader(
+        kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
+    );
+
+    int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
+
+    GGML_ASSERT(nb03 % type_size == 0);
+    GGML_ASSERT(nb02 % type_size == 0);
+    GGML_ASSERT(nb01 % type_size == 0);
+    GGML_ASSERT(nb00 % type_size == 0);
+    GGML_ASSERT(nb3  % type_size == 0);
+    GGML_ASSERT(nb2  % type_size == 0);
+    GGML_ASSERT(nb1  % type_size == 0);
+    GGML_ASSERT(nb0  % type_size == 0);
+
+    struct PushConstants {
+        uint32_t inAOff, inBOff, outOff;
+        int32_t n_dims, mode, n_orig_ctx;
+        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+        uint32_t nb00, nb01, nb02, nb03;
+        int32_t ne0;
+        uint32_t nb0, nb1, nb2, nb3;
+    } pushConsts {
+        safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
+        n_dims, mode, n_orig_ctx,
+        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+        nb00, nb01, nb02, nb03,
+        ne0,
+        nb0, nb1, nb2, nb3
+    };
+
+    auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(name)) {
+        s_algo = komputeManager()->algorithm<float, PushConstants>(
+            name, s_kompute_context->pool.get(), {inA, inB, out},
+            src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
+            {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
+        );
+    } else {
+        s_algo = komputeManager()->getAlgorithm(name);
+        s_algo->setTensors({inA, inB, out});
+        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_cpy(
+    const std::vector<uint32_t>& spirv,
+    uint32_t in_element_size, uint32_t out_element_size,
+    kp::Sequence& seq,
+    const std::shared_ptr<kp::Tensor>& in,
+    const std::shared_ptr<kp::Tensor>& out,
+    uint32_t inOff, uint32_t outOff,
+    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
+    uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
+    int32_t ne0, int32_t ne1, int32_t ne2,
+    uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
+) {
+    struct PushConstants {
+        uint32_t inOff, outOff;
+        int32_t ne00, ne01, ne02;
+        uint32_t nb00, nb01, nb02, nb03;
+        int32_t ne0, ne1, ne2;
+        uint32_t nb0, nb1, nb2, nb3;
+    } pushConsts {
+        safe_divide(inOff, in_element_size), safe_divide(outOff, out_element_size),
+        ne00, ne01, ne02,
+        nb00, nb01, nb02, nb03,
+        ne0, ne1, ne2,
+        nb0, nb1, nb2, nb3
+    };
+
+    std::string name = std::string(__func__)
+                       + "_i_" + std::to_string(in_element_size)
+                       + "_o_" + std::to_string(out_element_size);
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(name))
+        s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
+    else {
+        s_algo = komputeManager()->getAlgorithm(name);
+        s_algo->setTensors({in, out});
+        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+template <typename... Args>
+static void ggml_vk_cpy_f32_f16(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f16_comp_spv,
+        kp::shader_data::op_cpy_f32_f16_comp_spv_len);
+    ggml_vk_cpy(spirv, 4, 2, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_cpy_f32_f32(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f32_comp_spv,
+        kp::shader_data::op_cpy_f32_f32_comp_spv_len);
+    ggml_vk_cpy(spirv, 4, 4, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_cpy_f16_f16(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f16_comp_spv,
+        kp::shader_data::op_cpy_f16_f16_comp_spv_len);
+    ggml_vk_cpy(spirv, 2, 2, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_cpy_f16_f32(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f32_comp_spv,
+        kp::shader_data::op_cpy_f16_f32_comp_spv_len);
+    ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
+}
+
+static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
+    switch (op->type) {
+        case GGML_TYPE_F16:
+        case GGML_TYPE_F32:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            break;
+        default:
+            return false;
+    }
+
+    switch (op->op) {
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_SILU:
+                    return true;
+                default:
+                    ;
+            }
+            break;
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_ADD:
+        case GGML_OP_MUL:
+        case GGML_OP_SCALE:
+        case GGML_OP_SOFT_MAX:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_NORM:
+        case GGML_OP_ROPE:
+            return true;
+        case GGML_OP_DUP:
+        case GGML_OP_CPY:
+        case GGML_OP_CONT:
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_F16:
+                    break;
+                default:
+                    return false;
+            }
+            switch (op->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_F16:
+                    break;
+                default:
+                    return false;
+            }
+            return true;
+        case GGML_OP_DIAG_MASK_INF:
+            return op->ne[3] == 1;
+        case GGML_OP_GET_ROWS:
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F16:
+                case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q4_1:
+                case GGML_TYPE_Q6_K:
+                    return op->ne[2] == 1 && op->ne[3] == 1;
+                default:
+                    ;
+            }
+            return false;
+        case GGML_OP_MUL_MAT:
+            if (op->src[1]->type != GGML_TYPE_F32 || ggml_is_transposed(op->src[0]) || ggml_is_transposed(op->src[1]))
+                return false;
+
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_Q6_K:
+                    return op->ne[3] == 1;
+                case GGML_TYPE_F16:
+                case GGML_TYPE_Q8_0:
+                case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q4_1:
+                    return true;
+                default:
+                    ;
+            }
+        default:
+            ;
+    }
+    return false;
+}
+
+static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
+    const int n_seq = 8;
+
+    // FIXME: Figure out if we can somehow optimize the size of the pool... right now we're setting
+    // it to the size of the graph, but I think it can be made smaller?
+    ggml_vk_allocate_descriptor_pool(ctx, gf->n_nodes);
+
+    std::vector<std::shared_ptr<kp::Sequence>> sequences(n_seq);
+
+    for (auto& sequence : sequences) {
+        sequence = komputeManager()->sequence();
+    }
+    for (int seq_idx = 0; seq_idx < n_seq; ++seq_idx) {
+        const int n_nodes_per_seq = (gf->n_nodes + n_seq - 1) / n_seq;
+
+        auto& seq = *sequences[seq_idx];
+
+        const int node_start = (seq_idx + 0) * n_nodes_per_seq;
+        const int node_end   = std::min((seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq, gf->n_nodes);
+
+        bool any_commands_recorded = false;
+
+        for (int i = node_start; i < node_end; ++i) {
+            struct ggml_tensor * src0 = gf->nodes[i]->src[0];
+            struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+            struct ggml_tensor * dst = gf->nodes[i];
+            GGML_ASSERT(dst->data != nullptr);
+
+            switch (dst->op) {
+                case GGML_OP_NONE:
+                case GGML_OP_RESHAPE:
+                case GGML_OP_VIEW:
+                case GGML_OP_TRANSPOSE:
+                case GGML_OP_PERMUTE:
+                    continue; // noop -> next node
+                default:
+                    break;
+            }
+
+            any_commands_recorded = true;
+
+            if (!ggml_vk_supports_op(dst)) {
+                 fprintf(stderr, "%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+                 GGML_ASSERT(!"unsupported op");
+            }
+
+            const int32_t ne00 = src0 ? src0->ne[0] : 0;
+            const int32_t ne01 = src0 ? src0->ne[1] : 0;
+            const int32_t ne02 = src0 ? src0->ne[2] : 0;
+            const int32_t ne03 = src0 ? src0->ne[3] : 0;
+
+            const uint32_t nb00 = src0 ? src0->nb[0] : 0;
+            const uint32_t nb01 = src0 ? src0->nb[1] : 0;
+            const uint32_t nb02 = src0 ? src0->nb[2] : 0;
+            const uint32_t nb03 = src0 ? src0->nb[3] : 0;
+
+            const int32_t ne10 = src1 ? src1->ne[0] : 0;
+            const int32_t ne11 = src1 ? src1->ne[1] : 0;
+            const int32_t ne12 = src1 ? src1->ne[2] : 0;
+            const int32_t ne13 = src1 ? src1->ne[3] : 0;
+
+            const uint32_t nb10 = src1 ? src1->nb[0] : 0;
+            const uint32_t nb11 = src1 ? src1->nb[1] : 0;
+            const uint32_t nb12 = src1 ? src1->nb[2] : 0;
+            const uint32_t nb13 = src1 ? src1->nb[3] : 0;
+
+            const int32_t ne0 = dst ? dst->ne[0] : 0;
+            const int32_t ne1 = dst ? dst->ne[1] : 0;
+            const int32_t ne2 = dst ? dst->ne[2] : 0;
+//            const int32_t ne3 = dst ? dst->ne[3] : 0;
+
+            const uint32_t nb0 = dst ? dst->nb[0] : 0;
+            const uint32_t nb1 = dst ? dst->nb[1] : 0;
+            const uint32_t nb2 = dst ? dst->nb[2] : 0;
+            const uint32_t nb3 = dst ? dst->nb[3] : 0;
+
+            const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+            const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+            const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
+
+            const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
+            uint32_t off_src0 = 0;
+            uint32_t off_src1 = 0;
+            uint32_t off_dst  = 0;
+            const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
+            const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
+            const std::shared_ptr<kp::Tensor>& id_dst  = dst  ? ggml_vk_get_tensor(dst,  &off_dst)  : nullTensor;
+
+            switch (dst->op) {
+                case GGML_OP_ADD:
+                    {
+                        if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                            // src1 is a row
+                            ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
+                        } else {
+                            ggml_vk_add(
+                                seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+                                ne00, ne01, ne02, ne03,
+                                nb00, nb01, nb02, nb03,
+                                ne10, ne11, ne12, ne13,
+                                nb10, nb11, nb12, nb13,
+                                ne0,
+                                nb0, nb1, nb2, nb3
+                            );
+                        }
+                    } break;
+                case GGML_OP_MUL:
+                    {
+                        ggml_vk_mul(
+                            seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+                            ne00, ne01, ne02, ne03,
+                            nb00, nb01, nb02, nb03,
+                            ne10, ne11, ne12, ne13,
+                            nb10, nb11, nb12, nb13,
+                            ne0,
+                            nb0, nb1, nb2, nb3
+                        );
+                    } break;
+                case GGML_OP_SCALE:
+                    {
+                        float scale; memcpy(&scale, dst->op_params, sizeof(float));
+
+                        ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
+                    } break;
+                case GGML_OP_UNARY:
+                    {
+                        int64_t n = ggml_nelements(dst);
+                        GGML_ASSERT(n % 4 == 0);
+                        switch (ggml_get_unary_op(gf->nodes[i])) {
+                            case GGML_UNARY_OP_SILU:
+                                {
+                                    ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
+                                } break;
+                            case GGML_UNARY_OP_RELU:
+                                {
+                                    ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
+                                } break;
+                            case GGML_UNARY_OP_GELU:
+                                {
+                                    GGML_ASSERT(n % 8 == 0);
+                                    ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, n/8);
+                                } break;
+                            default:
+                                {
+                                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                                    GGML_ASSERT(false);
+                                }
+                        }
+                    } break;
+                case GGML_OP_SOFT_MAX:
+                    {
+                        float scale;
+                        memcpy(&scale, dst->op_params, sizeof(float));
+                        ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
+                    } break;
+                case GGML_OP_DIAG_MASK_INF:
+                    {
+                        const int n_past = ((int32_t *)(dst->op_params))[0];
+                        ggml_vk_diag_mask_inf(seq, id_src0, id_dst, off_src0, off_dst, n_past, ne00, ne01, ne02);
+                    } break;
+                case GGML_OP_NORM:
+                    {
+                        float eps;
+                        memcpy(&eps, dst->op_params, sizeof(float));
+                        ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
+                    } break;
+                case GGML_OP_RMS_NORM:
+                    {
+                        GGML_ASSERT(ne00 % 4 == 0);
+
+                        float eps;
+                        memcpy(&eps, dst->op_params, sizeof(float));
+                        ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
+                    } break;
+                case GGML_OP_MUL_MAT:
+                    {
+                        GGML_ASSERT(ne00 == ne10);
+
+                        // TODO: assert that dim2 and dim3 are contiguous
+                        GGML_ASSERT(ne12 % ne02 == 0);
+                        GGML_ASSERT(ne13 % ne03 == 0);
+
+                        const uint32_t r2 = ne12/ne02;
+                        const uint32_t r3 = ne13/ne03;
+
+                        if (src1t != GGML_TYPE_F32) {
+                            fprintf(stderr, "%s: %s: Unsupported src1 type: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
+                            goto not_implemented;
+                        }
+
+                        if (ggml_is_transposed(src0) ||
+                            ggml_is_transposed(src1)) {
+                            fprintf(stderr, "%s: %s: matmul on tranposed tensor not supported: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
+                            goto not_implemented;
+                        }
+
+                        switch (src0t) {
+                            case GGML_TYPE_F32:
+                                ggml_vk_mul_mat_mat_f32(
+                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+                                    ne00, ne01, ne02, nb01, nb02, ne11, ne12, nb11, nb12, nb1, nb2
+                                );
+                                break;
+                            case GGML_TYPE_F16:
+                                ggml_vk_mul_mat_f16(
+                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+                                    ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12,
+                                    ne0, ne1, r2, r3
+                                );
+                                break;
+                            case GGML_TYPE_Q8_0:
+                                ggml_vk_mul_mat_q8_0(
+                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+                                );
+                                break;
+                            case GGML_TYPE_Q4_0:
+                                ggml_vk_mul_mat_q4_0(
+                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+                                );
+                                break;
+                            case GGML_TYPE_Q4_1:
+                                ggml_vk_mul_mat_q4_1(
+                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+                                );
+                                break;
+                            case GGML_TYPE_Q6_K:
+                                ggml_vk_mul_mat_q6_k(
+                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+                                    ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02
+                                );
+                                break;
+                            default: {
+                                fprintf(stderr, "%s: %s: Unsupported quantization: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
+                                goto not_implemented;
+                            }
+                        }
+
+                    } break;
+                case GGML_OP_GET_ROWS:
+                    {
+                        if (src0t == GGML_TYPE_F16) {
+                            ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+                        } else if (src0t == GGML_TYPE_Q4_0) {
+                            ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+                        } else if (src0t == GGML_TYPE_Q4_1) {
+                            ggml_vk_get_rows_q4_1(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+                        } else if (src0t == GGML_TYPE_Q6_K) {
+                            ggml_vk_get_rows_q6_k(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+                        } else {
+                            fprintf(stderr, "%s: %s: Unsupported quantization: %u\n", __func__, ggml_op_name(dst->op), src0t);
+                            goto not_implemented;
+                        }
+                    } break;
+                case GGML_OP_ROPE:
+                    {
+                        GGML_ASSERT(ne10 == ne02);
+                        GGML_ASSERT(src0t == dstt);
+                        // const int n_past = ((int32_t *) dst->op_params)[0];
+                        const int n_dims     = ((int32_t *) dst->op_params)[1];
+                        const int mode       = ((int32_t *) dst->op_params)[2];
+                        // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
+                        const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+
+                        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+                        memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+                        memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+                        memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+                        memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+                        memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+                        memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+                        ggml_vk_rope(
+                            seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx,
+                            freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+                            ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
+                        );
+                    } break;
+                case GGML_OP_DUP:
+                case GGML_OP_CPY:
+                case GGML_OP_CONT:
+                    {
+                        switch (src0t) {
+                            case GGML_TYPE_F32:
+                                {
+                                    switch (dstt) {
+                                        case GGML_TYPE_F16: ggml_vk_cpy_f32_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
+                                        case GGML_TYPE_F32: ggml_vk_cpy_f32_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
+                                        default: goto not_implemented;
+                                    }
+                                } break;
+                            case GGML_TYPE_F16:
+                                {
+                                    switch (dstt) {
+                                        case GGML_TYPE_F16: ggml_vk_cpy_f16_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
+                                        case GGML_TYPE_F32: ggml_vk_cpy_f16_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
+                                    default: goto not_implemented;
+                                } break;
+                            default: goto not_implemented;
+                            }
+                        }
+                    } break;
+                default: goto not_implemented;
+            }
+            continue;
+            not_implemented: {}
+            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+            //GGML_ASSERT(false);
+        }
+
+        // Evaluate sequence
+        if (any_commands_recorded) {
+            seq.evalAsync();
+        }
+    }
+
+    // Wait for all sequences to finish
+    for (auto& sequence : sequences) {
+        if (sequence->isRunning())
+            sequence->evalAwait();
+    }
+
+    ggml_vk_free_descriptor_pool(ctx);
+}
+
+template<>
+kp::Tensor::TensorDataTypes
+kp::TensorT<half>::dataType()
+{
+    return TensorDataTypes::eFloat;
+}
+
+template<>
+kp::Tensor::TensorDataTypes
+kp::TensorT<uint8_t>::dataType()
+{
+    return TensorDataTypes::eUnsignedInt;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// backend interface
+
+struct ggml_backend_kompute_buffer_type_context {
+    int         device;
+    int         device_ref = 0;
+    uint64_t    buffer_alignment;
+    uint64_t    max_alloc;
+    std::string name;
+
+    ggml_backend_kompute_buffer_type_context(int device, uint64_t buffer_alignment, uint64_t max_alloc)
+        : device(device), buffer_alignment(buffer_alignment), max_alloc(max_alloc), name(ggml_kompute_format_name(device)) {}
+};
+
+static void ggml_backend_kompute_device_ref(ggml_backend_buffer_type_t buft) {
+    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+
+    if (!ctx->device_ref) {
+        komputeManager()->initializeDevice(
+            ctx->device, {}, {
+                "VK_KHR_shader_float16_int8", "VK_KHR_8bit_storage",
+                "VK_KHR_16bit_storage", "VK_KHR_shader_non_semantic_info"
+            }
+        );
+    }
+
+    assert(ggml_vk_has_device());
+    ctx->device_ref++;
+}
+
+static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) {
+    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+
+    assert(ctx->device_ref > 0);
+
+    ctx->device_ref--;
+
+    if (!ctx->device_ref) {
+        komputeManager.destroy();
+    }
+}
+
+static const char * ggml_backend_kompute_buffer_get_name(ggml_backend_buffer_t buffer) {
+    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buffer->buft->context);
+    return ctx->name.c_str();
+}
+
+static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    auto * memory = (ggml_vk_memory *)buffer->context;
+    if (ggml_vk_has_device()) {
+        ggml_vk_free_memory(*memory);
+    }
+    delete memory;
+}
+
+static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) {
+    return ((ggml_vk_memory *)buffer->context)->data;
+}
+
+static void ggml_backend_kompute_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_UNUSED(buffer);
+
+    const auto res = ggml_vk_get_tensor(tensor);
+    GGML_ASSERT(res);
+
+    memcpy((char *)tensor->data + offset, data, size);
+
+    komputeManager()->sequence()->eval<kp::OpTensorSyncDevice>({res});
+}
+
+static void ggml_backend_kompute_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_UNUSED(buffer);
+
+    const auto res = ggml_vk_get_tensor(tensor);
+    GGML_ASSERT(res);
+
+    komputeManager()->sequence()->eval<kp::OpTensorSyncLocal>({res});
+
+    memcpy(data, (const char *)tensor->data + offset, size);
+}
+
+static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    auto * memory = (ggml_vk_memory *)buffer->context;
+    memset(memory->data, value, buffer->size);
+
+    if (memory->stagingBuffer)
+        komputeManager()->sequence()->eval<kp::OpBufferSyncDevice>(memory->primaryBuffer, memory->stagingBuffer, memory->size);
+}
+
+static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
+    /* .get_name        = */ ggml_backend_kompute_buffer_get_name,
+    /* .free_buffer     = */ ggml_backend_kompute_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_kompute_buffer_get_base,
+    /* .init_tensor     = */ NULL,
+    /* .set_tensor      = */ ggml_backend_kompute_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_kompute_buffer_get_tensor,
+    /* .cpy_tensor      = */ NULL,
+    /* .clear           = */ ggml_backend_kompute_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+// default buffer type
+
+static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+    return ctx->name.c_str();
+}
+
+static ggml_backend_buffer_t ggml_backend_kompute_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    ggml_backend_kompute_device_ref(buft);
+    auto * ctx = new ggml_vk_memory(ggml_vk_allocate(size));
+    return ggml_backend_buffer_init(buft, ggml_backend_kompute_buffer_i, ctx, size);
+}
+
+static size_t ggml_backend_kompute_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+    return ctx->buffer_alignment;
+}
+
+static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+    return ctx->max_alloc;
+}
+
+static bool ggml_backend_kompute_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    GGML_UNUSED(buft);
+    return ggml_backend_is_kompute(backend);
+}
+
+static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_kompute_buffer_type_get_name,
+    /* .alloc_buffer     = */ ggml_backend_kompute_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_kompute_buffer_type_get_alignment,
+    /* .get_max_size     = */ ggml_backend_vk_buffer_type_get_max_size,
+    /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+    /* .supports_backend = */ ggml_backend_kompute_buffer_type_supports_backend,
+    /* .is_host          = */ NULL,
+};
+
+ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
+    static std::vector<ggml_backend_buffer_type> bufts = []() {
+        std::vector<ggml_backend_buffer_type> vec;
+        auto devices = ggml_vk_available_devices_internal(0);
+        vec.reserve(devices.size());
+
+        for (const auto & dev : devices) {
+            vec.push_back({
+                /* .iface   = */ ggml_backend_kompute_buffer_type_interface,
+                /* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc)
+            });
+        }
+        return vec;
+    }();
+
+    auto it = std::find_if(bufts.begin(), bufts.end(), [device](const ggml_backend_buffer_type & t) {
+        return device == static_cast<ggml_backend_kompute_buffer_type_context *>(t.context)->device;
+    });
+    return it < bufts.end() ? &*it : nullptr;
+}
+
+// backend
+
+static const char * ggml_backend_kompute_name(ggml_backend_t backend) {
+    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
+    return ctx->name.c_str();
+}
+
+static void ggml_backend_kompute_free(ggml_backend_t backend) {
+    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
+
+    assert(ctx == s_kompute_context);
+    s_kompute_context = nullptr;
+    if (ctx != nullptr) {
+        delete ctx;
+    }
+
+    delete backend;
+}
+
+static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(ggml_backend_t backend) {
+    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
+    return ggml_backend_kompute_buffer_type(ctx->device);
+}
+
+static bool ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
+    ggml_vk_graph_compute(ctx, cgraph);
+    return true;
+}
+
+static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    GGML_UNUSED(backend);
+    return ggml_vk_supports_op(op);
+}
+
+static struct ggml_backend_i kompute_backend_i = {
+    /* .get_name                = */ ggml_backend_kompute_name,
+    /* .free                    = */ ggml_backend_kompute_free,
+    /* .get_default_buffer_type = */ ggml_backend_kompute_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_async        = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_kompute_graph_compute,
+    /* .supports_op             = */ ggml_backend_kompute_supports_op,
+};
+
+ggml_backend_t ggml_backend_kompute_init(int device) {
+    GGML_ASSERT(s_kompute_context == nullptr);
+    s_kompute_context = new ggml_kompute_context(device);
+
+    ggml_backend_t kompute_backend = new ggml_backend {
+        /* .interface = */ kompute_backend_i,
+        /* .context   = */ s_kompute_context,
+    };
+
+    return kompute_backend;
+}
+
+bool ggml_backend_is_kompute(ggml_backend_t backend) {
+    return backend && backend->iface.get_name == ggml_backend_kompute_name;
+}
+
+static ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data) {
+    GGML_UNUSED(params);
+    return ggml_backend_kompute_init(intptr_t(user_data));
+}
+
+extern "C" int ggml_backend_kompute_reg_devices();
+
+int ggml_backend_kompute_reg_devices() {
+    auto devices = ggml_vk_available_devices_internal(0);
+    for (const auto & device : devices) {
+        ggml_backend_register(
+            ggml_kompute_format_name(device.index).c_str(),
+            ggml_backend_reg_kompute_init,
+            ggml_backend_kompute_buffer_type(device.index),
+            reinterpret_cast<void *>(intptr_t(device.index))
+        );
+    }
+    return devices.size();
+}
diff --git a/ggml-kompute.h b/ggml-kompute.h
new file mode 100644 (file)
index 0000000..1714654
--- /dev/null
@@ -0,0 +1,46 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct ggml_vk_device {
+    int index;
+    int type; // same as VkPhysicalDeviceType
+    size_t heapSize;
+    const char * name;
+    const char * vendor;
+    int subgroupSize;
+    uint64_t bufferAlignment;
+    uint64_t maxAlloc;
+};
+
+struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
+bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
+bool ggml_vk_has_vulkan(void);
+bool ggml_vk_has_device(void);
+struct ggml_vk_device ggml_vk_current_device(void);
+
+//
+// backend API
+//
+
+// forward declaration
+typedef struct ggml_backend * ggml_backend_t;
+
+GGML_API ggml_backend_t ggml_backend_kompute_init(int device);
+
+GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);
+
+GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/kompute b/kompute
new file mode 160000 (submodule)
index 0000000..4565194
--- /dev/null
+++ b/kompute
@@ -0,0 +1 @@
+Subproject commit 4565194ed7c32d1d2efa32ceab4d3c6cae006306
diff --git a/kompute-shaders/common.comp b/kompute-shaders/common.comp
new file mode 100644 (file)
index 0000000..62d62b0
--- /dev/null
@@ -0,0 +1,102 @@
+#extension GL_EXT_shader_16bit_storage: require
+#extension GL_EXT_shader_8bit_storage: require
+#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
+#extension GL_EXT_control_flow_attributes: enable
+#extension GL_KHR_shader_subgroup_arithmetic : require
+#extension GL_EXT_debug_printf : enable
+
+#define QK4_0 32
+#define QK4_1 32
+
+#define GELU_COEF_A 0.044715
+#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
+#define TWOPI_F 6.283185307179586f
+
+#define QK_K 256
+
+#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
+#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
+#define u8BufToU32(buf, idx) (((uint32_t u8BufToU16(buf, idx + 2) << 8 | buf[idx + 1]) << 8) | buf[idx])
+#define u8BufToFloat(buf, idx) uintBitsToFloat u8BufToU32(buf, idx)
+
+#define sizeof_block_q4_0 0x12
+struct block_q4_0 {
+    float16_t d;
+    uint8_t qs[QK4_0 / 2];
+};
+mat4 dequantize_q4_0(const block_q4_0 xb, uint il) {
+    const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.f * xb.d;
+    const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
+    const uint16_t mask1 = mask0 << 8;
+
+    mat4 reg;
+    for (int i=0;i<8;i++) {
+        uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
+        reg[i/2][2*(i%2)+0] = d1 * (b & mask0) + md;
+        reg[i/2][2*(i%2)+1] = d2 * (b & mask1) + md;
+    }
+    return reg;
+}
+
+#define sizeof_block_q4_1 0x14
+struct block_q4_1 {
+    float16_t d;
+    float16_t m;
+    uint8_t qs[QK4_1 / 2];
+};
+mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
+    const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
+    const float d2 = d1 / 256.f;
+    const float  m = xb.m;
+    const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
+    const uint16_t mask1 = mask0 << 8;
+
+    mat4 reg;
+    for (int i=0;i<8;i++) {
+        uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
+        reg[i/2][2*(i%2)+0] = ((b & mask0) * d1) + m;
+        reg[i/2][2*(i%2)+1] = ((b & mask1) * d2) + m;
+    }
+    return reg;
+}
+
+#define sizeof_block_q6_k 210
+struct block_q6_k {
+    uint8_t ql[QK_K/2];      // quants, lower 4 bits
+    uint8_t qh[QK_K/4];      // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
+    float16_t d;             // super-block scale
+};
+mat4 dequantize_q6_k(const block_q6_k xb, uint il) {
+    const float16_t d_all = xb.d;
+
+    const uint qlIndex = 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+    const uint qhIndex = 32*(il/8) + 16*(il&1);
+    float16_t sc = xb.scales[(il%2) + 2 * ((il/2))];
+    il = (il/2) & 3;
+
+    const uint16_t  kmask1 = il>1 ? uint16_t(il>2 ? 192 : 48) : uint16_t(il>0 ? 12 : 3);
+    const uint16_t  kmask2 = il>1 ? uint8_t(0xF0)             : uint8_t(0x0F);
+    const float16_t coef   = il>1 ? float16_t(1.f/16.f)       : float16_t(1.f);
+    const float16_t ml = float16_t(d_all * sc * 32.f);
+    const float16_t dl = float16_t(d_all * sc * coef);
+    mat4 reg;
+    for (int i = 0; i < 16; ++i) {
+        const float16_t q = (il&1) != 0 ? ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 2))
+                                        : ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 4));
+        reg[i/4][i%4] = dl * q - ml;
+    }
+    return reg;
+}
+
+
+#define QK8_0 32
+// struct block_q8_0 {
+//     float16_t d;         // delta
+//     int8_t    qs[QK8_0]; // quants
+// };
+#define sizeof_block_q8_0 34
diff --git a/kompute-shaders/op_add.comp b/kompute-shaders/op_add.comp
new file mode 100644 (file)
index 0000000..b7b76a7
--- /dev/null
@@ -0,0 +1,58 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1024) in;
+
+layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
+layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    int ne00;
+    int nb00;
+    int nb01;
+    int nb02;
+    int nb03;
+    int ne10;
+    int ne11;
+    int ne12;
+    int ne13;
+    int nb10;
+    int nb11;
+    int nb12;
+    int nb13;
+    int ne0;
+    int nb0;
+    int nb1;
+    int nb2;
+    int nb3;
+  //int offs; // TODO: needed for GGML_OP_ACC, see metal code
+} pcs;
+
+// general-purpose kernel for addition of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
+// cons: not very efficient
+void main() {
+    const uint i03 = gl_WorkGroupID.z;
+    const uint i02 = gl_WorkGroupID.y;
+    const uint i01 = gl_WorkGroupID.x;
+
+    const uint i13 = i03 % pcs.ne13;
+    const uint i12 = i02 % pcs.ne12;
+    const uint i11 = i01 % pcs.ne11;
+
+    int offs = 0; // TMP (see above)
+
+    uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + offs) / 4);
+    uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11       ) / 4);
+    uint dst_off  = uint((i03*pcs.nb3  + i02*pcs.nb2  + i01*pcs.nb1  + offs) / 4);
+
+    for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
+        const uint i10 = i0 % pcs.ne10;
+        out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] + inB[pcs.inBOff + src1_off + i10];
+    }
+}
diff --git a/kompute-shaders/op_addrow.comp b/kompute-shaders/op_addrow.comp
new file mode 100644 (file)
index 0000000..2376a6b
--- /dev/null
@@ -0,0 +1,25 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
+layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    uint row;
+} pcs;
+
+void main() {
+    const uint baseIndex = gl_WorkGroupID.x * 4;
+
+    for (uint x = 0; x < 4; x++) {
+        const uint i = baseIndex + x;
+        out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff];
+    }
+}
diff --git a/kompute-shaders/op_cpy_f16_f16.comp b/kompute-shaders/op_cpy_f16_f16.comp
new file mode 100644 (file)
index 0000000..d57247d
--- /dev/null
@@ -0,0 +1,52 @@
+#version 450
+
+#include "common.comp"
+
+#define IN_TYPE float16_t
+#define IN_TYPE_SIZE 2
+#define OUT_TYPE float16_t
+#define OUT_TYPE_SIZE 2
+
+layout(local_size_x = 1024) in;
+
+layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
+layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
+
+layout (push_constant) uniform parameter {
+    uint inOff;
+    uint outOff;
+    int ne00;
+    int ne01;
+    int ne02;
+    uint nb00;
+    uint nb01;
+    uint nb02;
+    uint nb03;
+    int ne0;
+    int ne1;
+    int ne2;
+    uint nb0;
+    uint nb1;
+    uint nb2;
+    uint nb3;
+} pcs;
+
+void main() {
+    const uint i03 = gl_WorkGroupID.z;
+    const uint i02 = gl_WorkGroupID.y;
+    const uint i01 = gl_WorkGroupID.x;
+
+    const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
+
+    const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
+    const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
+    const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
+    const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
+
+    const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
+
+    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+        const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
+        out_[dst_data+i00] = OUT_TYPE(in_[src]);
+    }
+}
diff --git a/kompute-shaders/op_cpy_f16_f32.comp b/kompute-shaders/op_cpy_f16_f32.comp
new file mode 100644 (file)
index 0000000..b568bcd
--- /dev/null
@@ -0,0 +1,52 @@
+#version 450
+
+#include "common.comp"
+
+#define IN_TYPE float16_t
+#define IN_TYPE_SIZE 2
+#define OUT_TYPE float
+#define OUT_TYPE_SIZE 4
+
+layout(local_size_x = 1024) in;
+
+layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
+layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
+
+layout (push_constant) uniform parameter {
+    uint inOff;
+    uint outOff;
+    int ne00;
+    int ne01;
+    int ne02;
+    uint nb00;
+    uint nb01;
+    uint nb02;
+    uint nb03;
+    int ne0;
+    int ne1;
+    int ne2;
+    uint nb0;
+    uint nb1;
+    uint nb2;
+    uint nb3;
+} pcs;
+
+void main() {
+    const uint i03 = gl_WorkGroupID.z;
+    const uint i02 = gl_WorkGroupID.y;
+    const uint i01 = gl_WorkGroupID.x;
+
+    const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
+
+    const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
+    const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
+    const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
+    const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
+
+    const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
+
+    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+        const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
+        out_[dst_data+i00] = OUT_TYPE(in_[src]);
+    }
+}
diff --git a/kompute-shaders/op_cpy_f32_f16.comp b/kompute-shaders/op_cpy_f32_f16.comp
new file mode 100644 (file)
index 0000000..99b2283
--- /dev/null
@@ -0,0 +1,52 @@
+#version 450
+
+#include "common.comp"
+
+#define IN_TYPE float
+#define IN_TYPE_SIZE 4
+#define OUT_TYPE float16_t
+#define OUT_TYPE_SIZE 2
+
+layout(local_size_x = 1024) in;
+
+layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
+layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
+
+layout (push_constant) uniform parameter {
+    uint inOff;
+    uint outOff;
+    int ne00;
+    int ne01;
+    int ne02;
+    uint nb00;
+    uint nb01;
+    uint nb02;
+    uint nb03;
+    int ne0;
+    int ne1;
+    int ne2;
+    uint nb0;
+    uint nb1;
+    uint nb2;
+    uint nb3;
+} pcs;
+
+void main() {
+    const uint i03 = gl_WorkGroupID.z;
+    const uint i02 = gl_WorkGroupID.y;
+    const uint i01 = gl_WorkGroupID.x;
+
+    const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
+
+    const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
+    const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
+    const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
+    const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
+
+    const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
+
+    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+        const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
+        out_[dst_data+i00] = OUT_TYPE(in_[src]);
+    }
+}
diff --git a/kompute-shaders/op_cpy_f32_f32.comp b/kompute-shaders/op_cpy_f32_f32.comp
new file mode 100644 (file)
index 0000000..2fc9984
--- /dev/null
@@ -0,0 +1,52 @@
+#version 450
+
+#include "common.comp"
+
+#define IN_TYPE float
+#define IN_TYPE_SIZE 4
+#define OUT_TYPE float
+#define OUT_TYPE_SIZE 4
+
+layout(local_size_x = 1024) in;
+
+layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
+layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
+
+layout (push_constant) uniform parameter {
+    uint inOff;
+    uint outOff;
+    int ne00;
+    int ne01;
+    int ne02;
+    uint nb00;
+    uint nb01;
+    uint nb02;
+    uint nb03;
+    int ne0;
+    int ne1;
+    int ne2;
+    uint nb0;
+    uint nb1;
+    uint nb2;
+    uint nb3;
+} pcs;
+
+void main() {
+    const uint i03 = gl_WorkGroupID.z;
+    const uint i02 = gl_WorkGroupID.y;
+    const uint i01 = gl_WorkGroupID.x;
+
+    const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
+
+    const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
+    const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
+    const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
+    const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
+
+    const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
+
+    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+        const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
+        out_[dst_data+i00] = OUT_TYPE(in_[src]);
+    }
+}
diff --git a/kompute-shaders/op_diagmask.comp b/kompute-shaders/op_diagmask.comp
new file mode 100644 (file)
index 0000000..291c3fc
--- /dev/null
@@ -0,0 +1,30 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+    uint inOff;
+    uint outOff;
+    uint n_past;
+    int ne00;
+    int ne01;
+} pcs;
+
+void main() {
+    const uint i02 = gl_WorkGroupID.z;
+    const uint i01 = gl_WorkGroupID.y;
+    const uint i00 = gl_WorkGroupID.x;
+
+    const uint index = i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00;
+
+    if (i00 > pcs.n_past + i01) {
+        out_[index + pcs.outOff] = uintBitsToFloat(0xFF800000);
+    } else {
+        out_[index + pcs.outOff] = in_[index + pcs.inOff];
+    }
+}
diff --git a/kompute-shaders/op_gelu.comp b/kompute-shaders/op_gelu.comp
new file mode 100644 (file)
index 0000000..9d8c537
--- /dev/null
@@ -0,0 +1,22 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
+layout(push_constant) uniform PushConstants {
+    uint inOff;
+    uint outOff;
+} pcs;
+
+void main() {
+    const uint baseIndex = gl_WorkGroupID.x * 8;
+
+    for (uint x = 0; x < 8; x++) {
+        const uint i = baseIndex + x;
+        const float y = in_[i + pcs.inOff];
+        out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(clamp(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y), -15.0, 15.0)));
+    }
+}
diff --git a/kompute-shaders/op_getrows.comp b/kompute-shaders/op_getrows.comp
new file mode 100644 (file)
index 0000000..1a5581b
--- /dev/null
@@ -0,0 +1,17 @@
+void main() {
+    const uint i = gl_WorkGroupID.x;
+    const int r = inB[i + pcs.inBOff];
+
+    int z = 0;
+    for (uint ind = gl_LocalInvocationID.x; ind < pcs.ne00/16; ind += gl_WorkGroupSize.x) {
+        const uint inIndex = (r * pcs.nb01 + pcs.inAOff) + ind/NL * SIZE_OF_BLOCK;
+        const mat4 result = dequantize_block(inIndex, ind%NL);
+        for (uint j = 0; j < 4; ++j) {
+            for (uint k = 0; k < 4; ++k) {
+                const uint outIndex = i * pcs.nb1/BYTES_FOR_TYPE + pcs.outOff + z;
+                out_[outIndex] = result[j][k];
+                ++z;
+            }
+        }
+    }
+}
diff --git a/kompute-shaders/op_getrows_f16.comp b/kompute-shaders/op_getrows_f16.comp
new file mode 100644 (file)
index 0000000..48c9361
--- /dev/null
@@ -0,0 +1,31 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { int inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    int ne00;
+    int nb01;
+    int nb1;
+} pcs;
+
+void dequantize_row_f16(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
+    for (int j = 0; j < k; j++) {
+        out_[y + j] = inA[x + j];
+    }
+}
+
+void main() {
+    const uint i = gl_WorkGroupID.x;
+    const int r = inB[i + pcs.inBOff];
+
+    dequantize_row_f16(r*pcs.nb01/2/*bytes for float16*/ + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
+}
diff --git a/kompute-shaders/op_getrows_q4_0.comp b/kompute-shaders/op_getrows_q4_0.comp
new file mode 100644 (file)
index 0000000..32b2e89
--- /dev/null
@@ -0,0 +1,38 @@
+#version 450
+
+#include "common.comp"
+
+#define NL 2
+#define BYTES_FOR_TYPE 4 /*bytes for float*/
+#define SIZE_OF_BLOCK sizeof_block_q4_0
+
+layout(local_size_x = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { int inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    int ne00;
+    int nb01;
+    int nb1;
+} pcs;
+
+block_q4_0 get_unaligned_block_q4_0(uint index) {
+    block_q4_0 fres;
+    fres.d = u8BufToFloat16(inA, index);
+    [[unroll]] for (uint it = 0; it != QK4_0 / 2; it++) {
+        fres.qs[it] = inA[index+2+it];
+    }
+    return fres;
+}
+
+mat4 dequantize_block(uint index, uint il) {
+    const block_q4_0 block = get_unaligned_block_q4_0(index);
+    return dequantize_q4_0(block, il);
+}
+
+#include "op_getrows.comp"
diff --git a/kompute-shaders/op_getrows_q4_1.comp b/kompute-shaders/op_getrows_q4_1.comp
new file mode 100644 (file)
index 0000000..87f2fbe
--- /dev/null
@@ -0,0 +1,39 @@
+#version 450
+
+#include "common.comp"
+
+#define NL 2
+#define BYTES_FOR_TYPE 4 /*bytes for float*/
+#define SIZE_OF_BLOCK sizeof_block_q4_1
+
+layout(local_size_x = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { int inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    int ne00;
+    int nb01;
+    int nb1;
+} pcs;
+
+block_q4_1 get_unaligned_block_q4_1(uint index) {
+    block_q4_1 fres;
+    fres.d = u8BufToFloat16(inA, index);
+    fres.m = u8BufToFloat16(inA, index+2);
+    [[unroll]] for (uint it = 0; it != QK4_1 / 2; it++) {
+        fres.qs[it] = inA[index+4+it];
+    }
+    return fres;
+}
+
+mat4 dequantize_block(uint index, uint il) {
+    const block_q4_1 block = get_unaligned_block_q4_1(index);
+    return dequantize_q4_1(block, il);
+}
+
+#include "op_getrows.comp"
diff --git a/kompute-shaders/op_getrows_q6_k.comp b/kompute-shaders/op_getrows_q6_k.comp
new file mode 100644 (file)
index 0000000..9ce3545
--- /dev/null
@@ -0,0 +1,44 @@
+#version 450
+
+#include "common.comp"
+
+#define NL 16
+#define BYTES_FOR_TYPE 4 /*bytes for float*/
+#define SIZE_OF_BLOCK sizeof_block_q6_k
+
+layout(local_size_x = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { int inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    int ne00;
+    int nb01;
+    int nb1;
+} pcs;
+
+block_q6_k get_unaligned_block_q6_k(uint index) {
+    block_q6_k fres;
+    [[unroll]] for (uint it = 0; it != QK_K / 2; it++) {
+        fres.ql[it] = inA[index + it];
+    }
+    [[unroll]] for (uint it = 0; it != QK_K / 4; it++) {
+        fres.qh[it] = inA[index + QK_K/2 + it];
+    }
+    [[unroll]] for (uint it = 0; it != QK_K / 16; it++) {
+        fres.scales[it] = int8_t(inA[index + QK_K/2 + QK_K/4 + it]);
+    }
+    fres.d = u8BufToFloat16(inA, index + QK_K/2 + QK_K/4 + QK_K/16);
+    return fres;
+}
+
+mat4 dequantize_block(uint index, uint il) {
+    const block_q6_k block = get_unaligned_block_q6_k(index);
+    return dequantize_q6_k(block, il);
+}
+
+#include "op_getrows.comp"
diff --git a/kompute-shaders/op_mul.comp b/kompute-shaders/op_mul.comp
new file mode 100644 (file)
index 0000000..c92647c
--- /dev/null
@@ -0,0 +1,52 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1024) in;
+
+layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
+layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    int ne00;
+    int nb00;
+    int nb01;
+    int nb02;
+    int nb03;
+    int ne10;
+    int ne11;
+    int ne12;
+    int ne13;
+    int nb10;
+    int nb11;
+    int nb12;
+    int nb13;
+    int ne0;
+    int nb0;
+    int nb1;
+    int nb2;
+    int nb3;
+} pcs;
+
+void main() {
+    const uint i03 = gl_WorkGroupID.z;
+    const uint i02 = gl_WorkGroupID.y;
+    const uint i01 = gl_WorkGroupID.x;
+
+    const uint i13 = i03 % pcs.ne13;
+    const uint i12 = i02 % pcs.ne12;
+    const uint i11 = i01 % pcs.ne11;
+
+    uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01) / 4);
+    uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11) / 4);
+    uint dst_off  = uint((i03*pcs.nb3  + i02*pcs.nb2  + i01*pcs.nb1)  / 4);
+
+    for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
+        const uint i10 = i0 % pcs.ne10;
+        out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] * inB[pcs.inBOff + src1_off + i10];
+    }
+}
diff --git a/kompute-shaders/op_mul_mat_f16.comp b/kompute-shaders/op_mul_mat_f16.comp
new file mode 100644 (file)
index 0000000..8f0a903
--- /dev/null
@@ -0,0 +1,67 @@
+#version 450
+
+#include "common.comp"
+
+#extension GL_KHR_shader_subgroup_arithmetic : require
+
+layout(local_size_x_id = 0) in;
+
+layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { float inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    int ne00;
+    int ne01;
+    int ne02;
+    uint nb00;
+    uint nb01;
+    uint nb02;
+    int ne10;
+    int ne11;
+    int ne12;
+    uint nb10;
+    uint nb11;
+    uint nb12;
+    int ne0;
+    int ne1;
+    uint r2;
+    uint r3;
+} pcs;
+
+#define N_F16_F32 4
+
+void main() {
+    const uint r0 = gl_WorkGroupID.x;
+    const uint rb = gl_WorkGroupID.y*N_F16_F32;
+    const uint im = gl_WorkGroupID.z;
+
+    const uint i12 = im%pcs.ne12;
+    const uint i13 = im/pcs.ne12;
+
+    const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb02*pcs.ne02;
+
+    const uint x = offset0 / 2 + pcs.inAOff; // Based from inA
+
+    for (uint row = 0; row < N_F16_F32; ++row) {
+        uint r1 = rb + row;
+        if (r1 >= pcs.ne11) {
+            break;
+        }
+
+        const uint y = (r1*pcs.nb11 + im*pcs.nb12) / 4 + pcs.inBOff; // Based from inB
+
+        float sumf = 0;
+        for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
+            sumf += float(inA[x+i]) * float(inB[y+i]);
+        }
+
+        const float all_sum = subgroupAdd(sumf);
+        if (subgroupElect()) {
+            out_[im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0 + pcs.outOff] = all_sum;
+        }
+    }
+}
diff --git a/kompute-shaders/op_mul_mat_mat_f32.comp b/kompute-shaders/op_mul_mat_mat_f32.comp
new file mode 100644 (file)
index 0000000..d1ca4ad
--- /dev/null
@@ -0,0 +1,51 @@
+#version 450
+
+#include "common.comp"
+
+#extension GL_KHR_shader_subgroup_arithmetic : require
+#extension GL_EXT_debug_printf : enable
+
+// device subgroup size
+layout (local_size_x_id = 0) in;
+
+layout(binding = 0) readonly buffer tensorInA { float inA[]; };
+layout(binding = 1) readonly buffer tensorInB { float inB[]; };
+layout(binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout(push_constant) uniform parameter {
+  uint inAOff;
+  uint inBOff;
+  uint outOff;
+  int ne00;
+  int ne01;
+  int ne02;
+  int ne11;
+  int ne12;
+  uint nb01;
+  uint nb02;
+  uint nb11;
+  uint nb12;
+  uint nb1;
+  uint nb2;
+}
+pcs;
+
+
+void main() {
+  uvec3 gid = gl_WorkGroupID;
+
+  uint bc_ab = pcs.ne12 > pcs.ne02 ? gid.z / (pcs.ne12 / pcs.ne02) : gid.z;
+  uint bc_ba = pcs.ne02 > pcs.ne12 ? gid.z / (pcs.ne02 / pcs.ne12) : gid.z;
+
+  const uint x = (gid.x*pcs.nb01 + bc_ab*pcs.nb02) / 4 + pcs.inAOff; // Based from inA
+  const uint y = (gid.y*pcs.nb11 + bc_ba*pcs.nb12) / 4 + pcs.inBOff; // based from inB
+  float sum = 0.0f;
+  for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
+      sum += float(inA[x+i]) * float(inB[y+i]);
+  }
+
+  const float all_sum = subgroupAdd(sum);
+  if (subgroupElect()) {
+    out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = all_sum;
+  }
+}
diff --git a/kompute-shaders/op_mul_mat_q4_0.comp b/kompute-shaders/op_mul_mat_q4_0.comp
new file mode 100644 (file)
index 0000000..b0cea8b
--- /dev/null
@@ -0,0 +1,33 @@
+#version 450
+
+#include "common.comp"
+
+#define BLOCKS_IN_QUANT QK4_0
+#define SIZE_OF_BLOCK sizeof_block_q4_0
+#define N_ROWS 4
+
+#include "op_mul_mv_q_n_pre.comp"
+
+// The q4_0 version of this function
+float block_q_n_dot_y(uint block_index, uint yb, uint il) {
+    vec2 acc = vec2(0.0, 0.0);
+    const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
+    float d = float(u8BufToFloat16(inA, index));
+    float sumy = 0.0f;
+    for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
+        const uint16_t b = u8BufToU16(inA, index + 2 + il + i);
+
+        const float yl0 = inB[yb + i];
+        const float yl1 = inB[yb + i + 1];
+        const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
+        const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
+
+        sumy += yl0 + yl1 + yl8 + yl9;
+
+        acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
+        acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
+    }
+    return d * (sumy * -8.f + acc[0] + acc[1]);
+}
+
+#include "op_mul_mv_q_n.comp"
diff --git a/kompute-shaders/op_mul_mat_q4_1.comp b/kompute-shaders/op_mul_mat_q4_1.comp
new file mode 100644 (file)
index 0000000..8582c61
--- /dev/null
@@ -0,0 +1,35 @@
+#version 450
+
+#include "common.comp"
+
+#define BLOCKS_IN_QUANT QK4_1
+#define SIZE_OF_BLOCK sizeof_block_q4_1
+#define N_ROWS 4
+
+#include "op_mul_mv_q_n_pre.comp"
+
+// The q4_1 version of this function
+float block_q_n_dot_y(uint block_index, uint yb, uint il) {
+    vec2 acc = vec2(0.0, 0.0);
+    const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
+    float d = float(u8BufToFloat16(inA, index));
+    float m = float(u8BufToFloat16(inA, index+2));
+
+    float sumy = 0.0f;
+    for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
+        const uint16_t b = u8BufToU16(inA, index + 4 + il + i);
+
+        const float yl0 = inB[yb + i];
+        const float yl1 = inB[yb + i + 1];
+        const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
+        const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
+
+        sumy += yl0 + yl1 + yl8 + yl9;
+
+        acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
+        acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
+    }
+    return d * (acc[0] + acc[1]) + sumy * m;
+}
+
+#include "op_mul_mv_q_n.comp"
diff --git a/kompute-shaders/op_mul_mat_q6_k.comp b/kompute-shaders/op_mul_mat_q6_k.comp
new file mode 100644 (file)
index 0000000..c9baebd
--- /dev/null
@@ -0,0 +1,94 @@
+#version 450
+
+#include "common.comp"
+
+#define SIZE_OF_BLOCK sizeof_block_q6_k
+
+layout(local_size_x_id = 0) in;
+layout(local_size_y_id = 1) in;
+layout(local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { float inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    int ne00;
+    int ne10;
+    int ne0;
+    int ne1;
+    int ne01;
+    int gqa;
+} pcs;
+
+void main() {
+    const uint8_t kmask1 = uint8_t(0x03);
+    const uint8_t kmask2 = uint8_t(0x0C);
+    const uint8_t kmask3 = uint8_t(0x30);
+    const uint8_t kmask4 = uint8_t(0xC0);
+
+    const uint nb = pcs.ne00/QK_K;
+
+    const uint r0 = gl_WorkGroupID.x;
+    const uint r1 = gl_WorkGroupID.y;
+    const uint r2 = gl_WorkGroupID.z;
+
+    const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
+    const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);
+    const uint x = row * nb + offset0; // Based from inA without base offset
+    const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
+
+    float sumf = 0;
+
+    // bits of invocation ID for gl_SubgroupSize=32:
+    //  x   x   x   x   x
+    //  4   3   2   1   0
+    // (     tid     ) ix
+    //  ip (   il    )
+
+    const uint block_stride = gl_SubgroupSize / 16;         // number of blocks each subgroup processes
+    const uint tid  = gl_SubgroupInvocationID/block_stride; // first block_stride groups have tid=0
+    const uint ix   = gl_SubgroupInvocationID%block_stride; // first block is 0..block_stride-1
+    const uint ip   = tid/8;        // first or second half of block (0 or 1)
+    const uint il   = tid%8;        // each half has 8 parts, one per scale
+    const uint n    = 4;            // 4 scales at a time (and 4 sums)
+    const uint l0   = n*il;         // offset into half-block, 0..28
+    const uint is   = 8*ip + l0/16; // 0, 1, 8, 9
+
+    const uint y_offset = 128*ip + l0;
+    const uint q_offset_l = 64*ip + l0;
+    const uint q_offset_h = 32*ip + l0;
+
+    for (uint i = ix; i < nb; i += block_stride) {
+
+        const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff;
+
+        const uint qlIndex = q_offset_l;
+        const uint q2Index = qlIndex + QK_K/8;
+        const uint qhIndex = q_offset_h;
+        const uint y = yy + i * QK_K + y_offset;
+
+        float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f};
+        for (uint l = 0; l < n; ++l) {
+            const uint8_t currentQ1 = inA[baseIndex + qlIndex + l];
+            const uint8_t currentQ2 = inA[baseIndex + q2Index + l];
+            const uint8_t currentQh = inA[baseIndex + QK_K/2 + qhIndex + l];
+
+            sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32);
+            sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32);
+            sums[2] += inB[y+l+64] * (int8_t((currentQ1  >> 4) | ((currentQh & kmask3) << 0)) - 32);
+            sums[3] += inB[y+l+96] * (int8_t((currentQ2  >> 4) | ((currentQh & kmask4) >> 2)) - 32);
+        }
+
+        float d = u8BufToFloat16(inA, baseIndex + QK_K/2 + QK_K/4 + QK_K/16);
+        sumf += d * (sums[0] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + is]) + sums[1] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 2 + is]) + sums[2] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 4 + is]) + sums[3] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 6 + is]));
+    }
+
+    const float tot = subgroupAdd(sumf);
+    if (subgroupElect()) {
+        out_[r1*pcs.ne0 + r2*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
+    }
+}
diff --git a/kompute-shaders/op_mul_mat_q8_0.comp b/kompute-shaders/op_mul_mat_q8_0.comp
new file mode 100644 (file)
index 0000000..34d015e
--- /dev/null
@@ -0,0 +1,73 @@
+#version 450
+
+#include "common.comp"
+
+#include "op_mul_mv_q_n_pre.comp"
+
+#define SIZE_OF_D 2
+
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
+#define NB_Q8_0 8
+
+void main() {
+    // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
+    if (gl_SubgroupInvocationID > 31)
+        return;
+
+    const int nr  = N_DST;
+    const int nsg = N_SIMDGROUP;
+    const int nw  = N_SIMDWIDTH;
+
+    const int nb = pcs.ne00/QK8_0;
+    const uint r0 = gl_WorkGroupID.x;
+    const uint r1 = gl_WorkGroupID.y;
+    const uint im = gl_WorkGroupID.z;
+
+    const uint first_row = (r0 * nsg + gl_SubgroupID) * nr;
+
+    const uint i12 = im%pcs.ne12;
+    const uint i13 = im/pcs.ne12;
+
+    const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
+
+    const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // Based from inA
+    const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // based from inB
+
+    float yl[NB_Q8_0];
+    float sumf[N_DST]={0.f, 0.f, 0.f, 0.f};
+
+    const uint ix = gl_SubgroupInvocationID.x/4;
+    const uint il = gl_SubgroupInvocationID.x%4;
+
+    uint yb = y + ix * QK8_0 + NB_Q8_0*il;
+
+    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+    for (uint ib = ix; ib < nb; ib += nw/4) {
+        for (int i = 0; i < NB_Q8_0; ++i) {
+            yl[i] = inB[yb + i];
+        }
+
+        for (int row = 0; row < nr; row++) {
+            const uint block_offset = (ib+row*nb) * sizeof_block_q8_0;
+            float sumq = 0.f;
+            for (int iq = 0; iq < NB_Q8_0; ++iq) {
+                const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]);
+                sumq += qs_iq * yl[iq];
+            }
+            const float16_t d = u8BufToFloat16(inA, x + block_offset);
+            sumf[row] += sumq*d;
+        }
+
+        yb += NB_Q8_0 * nw;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = subgroupAdd(sumf[row]);
+        if (subgroupElect() && first_row + row < pcs.ne01) {
+            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row] = tot;
+        }
+    }
+}
diff --git a/kompute-shaders/op_mul_mv_q_n.comp b/kompute-shaders/op_mul_mv_q_n.comp
new file mode 100644 (file)
index 0000000..440b5ab
--- /dev/null
@@ -0,0 +1,48 @@
+void main() {
+    // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
+    if (gl_SubgroupInvocationID > 31)
+        return;
+
+    const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT);
+
+    const uint r0 = gl_WorkGroupID.x;
+    const uint r1 = gl_WorkGroupID.y;
+    const uint im = gl_WorkGroupID.z;
+
+    const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS;
+
+    const uint i12 = im%pcs.ne12;
+    const uint i13 = im/pcs.ne12;
+
+    const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
+
+    const uint x = offset0; // Based from inA without base offset
+    const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
+
+    float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};
+
+    const uint ix = gl_SubgroupInvocationID/2;
+    const uint il = (BLOCKS_IN_QUANT/4)*(gl_SubgroupInvocationID%2);
+
+    uint yb = y + ix * BLOCKS_IN_QUANT + il;
+
+    //debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n",
+    //    gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize,
+    //    gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z);
+
+    for (uint ib = ix; ib < nb; ib += 16) {
+        for (int row = 0; row < N_ROWS; row++) {
+            const uint block_index = x + ib + row * nb;
+            sumf[row] += block_q_n_dot_y(block_index, yb, il);
+        }
+
+        yb += BLOCKS_IN_QUANT * 16;
+    }
+
+    for (int row = 0; row < N_ROWS; ++row) {
+        const float tot = subgroupAdd(sumf[row]);
+        if (first_row + row < pcs.ne01 && subgroupElect()) {
+            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot;
+        }
+    }
+}
diff --git a/kompute-shaders/op_mul_mv_q_n_pre.comp b/kompute-shaders/op_mul_mv_q_n_pre.comp
new file mode 100644 (file)
index 0000000..7912b09
--- /dev/null
@@ -0,0 +1,22 @@
+layout(local_size_x_id = 0) in;
+layout(local_size_y = 1) in;
+layout(local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { float inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    int  ne00;
+    int  ne01;
+    int  ne02;
+    int  ne10;
+    int  ne12;
+    int  ne0;
+    int  ne1;
+    uint r2;
+    uint r3;
+} pcs;
diff --git a/kompute-shaders/op_norm.comp b/kompute-shaders/op_norm.comp
new file mode 100644 (file)
index 0000000..ad0c3c0
--- /dev/null
@@ -0,0 +1,84 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 256) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+    uint inOff;
+    uint outOff;
+    uint ne00;
+    uint nb01;
+    float eps;
+} pcs;
+
+shared float sum[gl_WorkGroupSize.x];
+
+void main() {
+    const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
+    // MEAN
+    // parallel sum
+    sum[gl_LocalInvocationID.x] = 0.0;
+    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+        sum[gl_LocalInvocationID.x] += in_[x+i00];
+    }
+
+    // reduce
+    barrier();
+    memoryBarrierShared();
+    [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
+        if (gl_LocalInvocationID.x < i) {
+            sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
+        }
+        barrier();
+        memoryBarrierShared();
+    }
+
+    // broadcast
+    if (gl_LocalInvocationID.x == 0) {
+        sum[0] /= float(pcs.ne00);
+    }
+    barrier();
+    memoryBarrierShared();
+    const float mean = sum[0];
+
+    // recenter
+    const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
+    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+        out_[y+i00] = in_[x+i00] - mean;
+    }
+
+    // VARIANCE
+    // parallel sum
+    sum[gl_LocalInvocationID.x] = 0.0;
+    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+        sum[gl_LocalInvocationID.x] += out_[y+i00] * out_[y+i00];
+    }
+
+    // reduce
+    barrier();
+    memoryBarrierShared();
+    [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
+        if (gl_LocalInvocationID.x < i) {
+            sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
+        }
+        barrier();
+        memoryBarrierShared();
+    }
+
+    // broadcast
+    if (gl_LocalInvocationID.x == 0) {
+        sum[0] /= float(pcs.ne00);
+    }
+    barrier();
+    memoryBarrierShared();
+    const float variance = sum[0];
+
+    const float scale = 1.0f/sqrt(variance + pcs.eps);
+    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+        out_[y+i00] *= scale;
+    }
+}
diff --git a/kompute-shaders/op_relu.comp b/kompute-shaders/op_relu.comp
new file mode 100644 (file)
index 0000000..52a601f
--- /dev/null
@@ -0,0 +1,21 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
+layout(push_constant) uniform PushConstants {
+    uint inOff;
+    uint outOff;
+} pcs;
+
+void main() {
+    const uint baseIndex = gl_WorkGroupID.x * 4;
+
+    for (uint x = 0; x < 4; x++) {
+        const uint i = baseIndex + x;
+        out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
+    }
+}
diff --git a/kompute-shaders/op_rmsnorm.comp b/kompute-shaders/op_rmsnorm.comp
new file mode 100644 (file)
index 0000000..da658c1
--- /dev/null
@@ -0,0 +1,53 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 512) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+    uint inOff;
+    uint outOff;
+    uint ne00;
+    uint nb01;
+    float eps;
+} pcs;
+
+shared float sum[gl_WorkGroupSize.x];
+
+void main() {
+    const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
+
+    // parallel sum
+    sum[gl_LocalInvocationID.x] = 0.0;
+    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+        sum[gl_LocalInvocationID.x] += in_[x+i00] * in_[x+i00];
+    }
+
+    // reduce
+    barrier();
+    memoryBarrierShared();
+    [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
+        if (gl_LocalInvocationID.x < i) {
+            sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
+        }
+        barrier();
+        memoryBarrierShared();
+    }
+
+    // broadcast
+    if (gl_LocalInvocationID.x == 0) {
+        sum[0] /= float(pcs.ne00);
+    }
+    barrier();
+    memoryBarrierShared();
+
+    const float scale = 1.0f/sqrt(sum[0] + pcs.eps);
+
+    const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
+    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+        out_[y+i00] = in_[x+i00] * scale;
+    }
+}
diff --git a/kompute-shaders/op_rope_f16.comp b/kompute-shaders/op_rope_f16.comp
new file mode 100644 (file)
index 0000000..b446225
--- /dev/null
@@ -0,0 +1,73 @@
+#version 450
+
+#include "rope_common.comp"
+
+layout(binding = 0) buffer restrict readonly  tensorInA { float16_t inA[]; };
+layout(binding = 1) buffer restrict readonly  tensorInB { int       inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; };
+
+void main() {
+    const uint i3 = gl_WorkGroupID.z;
+    const uint i2 = gl_WorkGroupID.y;
+    const uint i1 = gl_WorkGroupID.x;
+
+    const bool is_neox = (pcs.mode & 2) != 0;
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
+
+    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
+
+    const int p = inB[pcs.inBOff + i2];
+
+    float theta = float(p);
+
+    if (!is_neox) {
+        for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
+            float cos_theta, sin_theta;
+            rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
+
+            theta *= theta_scale;
+
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
+
+            const float x0 = float(inA[src]);
+            const float x1 = float(inA[src+1]);
+
+            out_[dst_data]   = float16_t(x0*cos_theta - x1*sin_theta);
+            out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
+        }
+    } else {
+        const float inv_ndims = -1.f/pcs.n_dims;
+        for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
+            const uint cur_rot = ic;
+
+            float cos_theta, sin_theta;
+            rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
+
+            theta *= theta_scale;
+
+            const uint i0 = ic/2;
+
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
+
+            const float x0 = float(inA[src]);
+            const float x1 = float(inA[src+pcs.n_dims/2]);
+
+            out_[dst_data]              = float16_t(x0*cos_theta - x1*sin_theta);
+            out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
+        }
+
+        for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
+            const uint i0 = ic;
+
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
+
+            out_[dst_data + 0] = inA[src + 0];
+            out_[dst_data + 1] = inA[src + 1];
+        }
+    }
+}
diff --git a/kompute-shaders/op_rope_f32.comp b/kompute-shaders/op_rope_f32.comp
new file mode 100644 (file)
index 0000000..2c0235d
--- /dev/null
@@ -0,0 +1,73 @@
+#version 450
+
+#include "rope_common.comp"
+
+layout(binding = 0) buffer restrict readonly  tensorInA { float inA[]; };
+layout(binding = 1) buffer restrict readonly  tensorInB { int   inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
+
+void main() {
+    const uint i3 = gl_WorkGroupID.z;
+    const uint i2 = gl_WorkGroupID.y;
+    const uint i1 = gl_WorkGroupID.x;
+
+    const bool is_neox = (pcs.mode & 2) != 0;
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
+
+    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
+
+    const int p = inB[pcs.inBOff + i2];
+
+    float theta = float(p);
+
+    if (!is_neox) {
+        for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
+            float cos_theta, sin_theta;
+            rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
+
+            theta *= theta_scale;
+
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
+
+            const float x0 = inA[src];
+            const float x1 = inA[src+1];
+
+            out_[dst_data]   = x0*cos_theta - x1*sin_theta;
+            out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
+        }
+    } else {
+        const float inv_ndims = -1.f/pcs.n_dims;
+        for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
+            const uint cur_rot = ic;
+
+            float cos_theta, sin_theta;
+            rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
+
+            theta *= theta_scale;
+
+            const uint i0 = ic/2;
+
+            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
+
+            const float x0 = inA[src];
+            const float x1 = inA[src+pcs.n_dims/2];
+
+            out_[dst_data] = x0*cos_theta - x1*sin_theta;
+            out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
+        }
+
+        for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
+            const uint i0 = ic;
+
+            const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
+            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
+
+            out_[dst_data + 0] = inA[src + 0];
+            out_[dst_data + 1] = inA[src + 1];
+        }
+    }
+}
diff --git a/kompute-shaders/op_scale.comp b/kompute-shaders/op_scale.comp
new file mode 100644 (file)
index 0000000..bdae267
--- /dev/null
@@ -0,0 +1,19 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+    uint inOff;
+    uint outOff;
+    float scale;
+} pcs;
+
+void main() {
+    const uint i = gl_WorkGroupID.x;
+    out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
+}
diff --git a/kompute-shaders/op_scale_8.comp b/kompute-shaders/op_scale_8.comp
new file mode 100644 (file)
index 0000000..ada6975
--- /dev/null
@@ -0,0 +1,23 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+    uint inOff;
+    uint outOff;
+    float scale;
+} pcs;
+
+void main() {
+    const uint baseIndex = gl_WorkGroupID.x * 8;
+
+    for (uint x = 0; x < 8; x++) {
+        const uint i = baseIndex + x;
+        out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
+    }
+}
diff --git a/kompute-shaders/op_silu.comp b/kompute-shaders/op_silu.comp
new file mode 100644 (file)
index 0000000..0fb8e4b
--- /dev/null
@@ -0,0 +1,22 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
+layout(push_constant) uniform PushConstants {
+    uint inOff;
+    uint outOff;
+} pcs;
+
+void main() {
+    const uint baseIndex = gl_WorkGroupID.x * 4;
+
+    for (uint x = 0; x < 4; x++) {
+        const uint i = baseIndex + x;
+        const float y = in_[i + pcs.inOff];
+        out_[i + pcs.outOff] = y / (1.0 + exp(-y));
+    }
+}
diff --git a/kompute-shaders/op_softmax.comp b/kompute-shaders/op_softmax.comp
new file mode 100644 (file)
index 0000000..7bc9176
--- /dev/null
@@ -0,0 +1,56 @@
+// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4)
+
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x_id = 0) in;
+
+layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
+layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    int ne00;
+    int ne01;
+    int ne02;
+    float scale;
+    int mask;
+} pcs;
+
+void main() {
+    if (gl_SubgroupInvocationID > 31)
+        return;
+
+    const uint i03 = gl_WorkGroupID.z;
+    const uint i02 = gl_WorkGroupID.y;
+    const uint i01 = gl_WorkGroupID.x;
+
+    const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
+    const uint psrc0 = extra_off + pcs.inAOff; // Based from inA
+    const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
+    const uint pdst = extra_off + pcs.outOff; // Based from out_
+
+    // parallel max
+    float localMax = uintBitsToFloat(0xFF800000);
+    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
+        localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f));
+    }
+    float max_ = subgroupMax(localMax);
+
+    // parallel sum
+    float localSum = 0.0f;
+    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
+        const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_);
+        localSum += exp_psrc0;
+        out_[pdst + i00] = exp_psrc0;
+    }
+
+    const float sum = subgroupAdd(localSum);
+    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
+        out_[pdst + i00] /= sum;
+    }
+}
diff --git a/kompute-shaders/rope_common.comp b/kompute-shaders/rope_common.comp
new file mode 100644 (file)
index 0000000..57ba659
--- /dev/null
@@ -0,0 +1,67 @@
+#include "common.comp"
+
+// TODO: use a local size of 32 or more (Metal uses 1024)
+layout(local_size_x = 1) in;
+
+layout (push_constant) uniform parameter {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    int n_dims;
+    int mode;
+    int n_orig_ctx;
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+    uint nb00;
+    uint nb01;
+    uint nb02;
+    uint nb03;
+    int ne0;
+    uint nb0;
+    uint nb1;
+    uint nb2;
+    uint nb3;
+} pcs;
+
+float rope_yarn_ramp(const float low, const float high, const float i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
+    out float cos_theta, out float sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+    }
+    cos_theta = cos(theta) * mscale;
+    sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * log(n_orig_ctx / (n_rot * TWOPI_F)) / (2 * log(base));
+}
+
+void rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
index 45569f7d378b09cb00b4a2c9aae6279e7b963e9e..9631506c6fd4e9499f7b4efa8aecbeb680454996 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -15,6 +15,8 @@
 #  include "ggml-vulkan.h"
 #elif defined(GGML_USE_SYCL)
 #  include "ggml-sycl.h"
+#elif defined(GGML_USE_KOMPUTE)
+#   include "ggml-kompute.h"
 #endif
 
 #ifdef GGML_USE_METAL
@@ -1313,6 +1315,11 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(int gpu) {
     buft = ggml_backend_sycl_buffer_type(gpu);
 #elif defined(GGML_USE_CLBLAST)
     buft = ggml_backend_opencl_buffer_type();
+#elif defined(GGML_USE_KOMPUTE)
+    buft = ggml_backend_kompute_buffer_type(gpu);
+    if (buft == nullptr) {
+        LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, gpu);
+    }
 #endif
 
     if (buft == nullptr) {
@@ -4107,7 +4114,7 @@ static bool llm_load_tensors(
 }
 
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
-static int llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
+static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
     try {
         llama_model_loader ml(fname, params.use_mmap, params.kv_overrides);
 
@@ -4128,6 +4135,22 @@ static int llama_model_load(const std::string & fname, llama_model & model, cons
             return 0;
         }
 
+#ifdef GGML_USE_KOMPUTE
+        if (ggml_vk_has_device() && params.n_gpu_layers > 0 && (
+            !(model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON)
+            || !(
+                model.ftype == LLAMA_FTYPE_ALL_F32 ||
+                model.ftype == LLAMA_FTYPE_MOSTLY_F16 ||
+                model.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ||
+                model.ftype == LLAMA_FTYPE_MOSTLY_Q4_1
+            )
+        )) {
+            // disable Vulkan due to unsupported model architecture or quantization type
+            // TODO(cebtenzzre): propagate this error outside of llama_load_model_from_file
+            params.n_gpu_layers = 0;
+        }
+#endif
+
         if (!llm_load_tensors(
             ml, model, params.n_gpu_layers, params.split_mode,  params.main_gpu, params.tensor_split, params.use_mlock,
             params.progress_callback, params.progress_callback_user_data
@@ -10259,6 +10282,16 @@ struct llama_context * llama_new_context_with_model(
             }
             ctx->backends.push_back(backend);
         }
+#elif defined(GGML_USE_KOMPUTE)
+        if (model->n_gpu_layers > 0) {
+            auto * backend = ggml_backend_kompute_init(model->main_gpu);
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize Kompute backend\n", __func__);
+                llama_free(ctx);
+                return nullptr;
+            }
+            ctx->backends.push_back(backend);
+        }
 #endif
         ctx->backend_cpu = ggml_backend_cpu_init();
         if (ctx->backend_cpu == nullptr) {
diff --git a/llama.h b/llama.h
index 3e33072c68c17edd2a61f45d47ffece144788a7b..01b293e64977aaba8c757306419ed8c1978a88d2 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -49,7 +49,8 @@
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
 #define LLAMA_SESSION_VERSION 4
 
-#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || defined(GGML_USE_SYCL)
+#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \
+    defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
 #define LLAMA_SUPPORTS_GPU_OFFLOAD
 #endif
index 01593910584d40a3af94fb13b591abc0a28cfae0..775147d42bf1581ce6ace1dba578528466d87cc9 100644 (file)
@@ -370,12 +370,15 @@ struct test_case {
         printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
         fflush(stdout);
 
-        // check if backends support op
+        // check if the backends support the ops
         bool supported = true;
         for (ggml_backend_t backend : {backend1, backend2}) {
-            if (!ggml_backend_supports_op(backend, out)) {
-                printf("not supported [%s] ", ggml_backend_name(backend));
-                supported = false;
+            for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+                if (!ggml_backend_supports_op(backend, t)) {
+                    printf("not supported [%s] ", ggml_backend_name(backend));
+                    supported = false;
+                    break;
+                }
             }
         }
         if (!supported) {
@@ -626,6 +629,13 @@ struct test_unary : public test_case {
         ggml_tensor * out = ggml_unary(ctx, in, op);
         return out;
     }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // test extended range of values to check for NaNs in GELU
+            init_tensor_uniform(t, -150.f, 150.f);
+        }
+    }
 };
 
 // GGML_OP_GET_ROWS
@@ -1066,18 +1076,24 @@ struct test_diag_mask_inf : public test_case {
 struct test_soft_max : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
+    const float scale;
+    const bool mask;
 
     std::string vars() override {
-        return VARS_TO_STR2(type, ne);
+        return VARS_TO_STR4(type, ne, scale, mask);
     }
 
     test_soft_max(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10})
-        : type(type), ne(ne) {}
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            float scale = 1.0f,
+            bool mask = false)
+        : type(type), ne(ne), scale(scale), mask(mask) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * out = ggml_soft_max(ctx, a);
+        ggml_tensor * b = nullptr;
+        if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); }
+        ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale);
         return out;
     }
 };
@@ -1474,6 +1490,393 @@ struct test_moe : public test_case {
     }
 };
 
+
+enum llm_norm_type {
+    LLM_NORM,
+    LLM_NORM_RMS,
+};
+
+struct llama_hparams {
+    uint32_t n_vocab;
+    uint32_t n_embd;
+    uint32_t n_head;
+    uint32_t n_head_kv;
+    static constexpr uint32_t n_layer = 1;
+    uint32_t n_rot;
+    uint32_t n_embd_head; // dimension of values (d_v)
+    uint32_t n_ff;
+
+    float f_norm_eps;
+    float f_norm_rms_eps;
+
+    // cparams
+    static constexpr uint32_t n_ctx = 512; // user-specified context size
+    static constexpr uint32_t n_orig_ctx = n_ctx;
+
+    // batch
+    int32_t n_tokens;
+
+    // llm_build_context
+    static constexpr int32_t n_kv    = 32; // size of KV cache to consider (n_kv <= n_ctx
+    static constexpr int32_t kv_head = 1;  // index of where we store new KV data in the cache
+
+    uint32_t n_embd_gqa() const { // dimension of key embeddings across all k-v heads
+        return n_embd_head * n_head_kv;
+    }
+};
+
+// LLM base class
+struct test_llm : public test_case {
+    llama_hparams hp;
+
+protected:
+    test_llm(llama_hparams hp)
+        : hp(std::move(hp)) {
+    }
+
+public:
+    struct ggml_tensor * llm_build_norm(
+            struct ggml_context * ctx,
+             struct ggml_tensor * cur,
+             struct ggml_tensor * mw,
+             struct ggml_tensor * mb,
+                  llm_norm_type   type) {
+        switch (type) {
+            case LLM_NORM:     cur = ggml_norm    (ctx, cur, hp.f_norm_eps); break;
+            case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hp.f_norm_rms_eps); break;
+        }
+        cur = ggml_mul(ctx, cur, mw);
+        if (mb) {
+            cur = ggml_add(ctx, cur, mb);
+        }
+        return cur;
+    }
+
+    void llm_build_kv_store(
+            struct ggml_context * ctx,
+             struct ggml_tensor * k_l,
+             struct ggml_tensor * v_l,
+             struct ggml_tensor * k_cur,
+             struct ggml_tensor * v_cur) {
+        // compute the transposed [n_tokens, n_embd] V matrix
+        struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, hp.n_embd_gqa(), hp.n_tokens));
+
+        struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, k_l, hp.n_tokens*hp.n_embd_gqa(),
+                (ggml_row_size(k_l->type, hp.n_embd_gqa()))*hp.kv_head);
+
+        struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, v_l, hp.n_tokens, hp.n_embd_gqa(),
+                (  hp.n_ctx)*ggml_element_size(v_l),
+                (hp.kv_head)*ggml_element_size(v_l));
+
+        // important: storing RoPE-ed version of K in the KV cache!
+        ggml_cpy(ctx, k_cur,   k_cache_view);
+        ggml_cpy(ctx, v_cur_t, v_cache_view);
+    }
+
+    // if max_alibi_bias > 0 then apply ALiBi
+    struct ggml_tensor * llm_build_kqv(
+            struct ggml_context * ctx,
+             struct ggml_tensor * k_l,
+             struct ggml_tensor * v_l,
+             struct ggml_tensor * q_cur,
+             struct ggml_tensor * kq_mask,
+                        float     kq_scale) {
+        struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
+
+        struct ggml_tensor * k =
+            ggml_view_3d(ctx, k_l,
+                    hp.n_embd_head, hp.n_kv, hp.n_head_kv,
+                    ggml_row_size(k_l->type, hp.n_embd_gqa()),
+                    ggml_row_size(k_l->type, hp.n_embd_head),
+                    0);
+
+        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
+
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
+
+        // split cached v into n_head heads
+        struct ggml_tensor * v =
+            ggml_view_3d(ctx, v_l,
+                    hp.n_kv, hp.n_embd_head, hp.n_head_kv,
+                    ggml_element_size(v_l)*hp.n_ctx,
+                    ggml_element_size(v_l)*hp.n_ctx*hp.n_embd_head,
+                    0);
+
+        struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
+
+        struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
+
+        struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, hp.n_embd_head*hp.n_head, hp.n_tokens);
+
+        struct ggml_tensor * wo = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
+        cur = ggml_mul_mat(ctx, wo, cur);
+
+        return cur;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // pos
+                std::vector<int> data(hp.n_tokens);
+                for (int i = 0; i < hp.n_tokens; i++) {
+                    data[i] = rand() % hp.n_ctx;
+                }
+                ggml_backend_tensor_set(t, data.data(), 0, hp.n_tokens * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
+
+// Llama
+struct test_llama : public test_llm {
+    static constexpr float freq_base = 10000.0f;
+    static constexpr float freq_scale = 1.0f;
+    static constexpr float ext_factor = 0.0f;
+    static constexpr float attn_factor = 1.0f;
+    static constexpr float beta_fast = 32.0f;
+    static constexpr float beta_slow = 1.0f;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "LLAMA";
+    }
+
+    std::string vars() override {
+        auto n_tokens = hp.n_tokens;
+        return VARS_TO_STR1(n_tokens);
+    }
+
+    double max_nmse_err() override {
+        return 2e-3;
+    }
+
+    test_llama(int n_tokens = 1)
+        : test_llm({
+            /*n_vocab        =*/ 32000,
+            /*n_embd         =*/ 3200,
+            /*n_head         =*/ 32,
+            /*n_head_kv      =*/ 32,
+            /*n_rot          =*/ 100,
+            /*n_embd_head    =*/ 100,
+            /*n_ff           =*/ 8640,
+            /*f_norm_eps     =*/ 0.f,
+            /*f_norm_rms_eps =*/ 1e-5f,
+            /*n_tokens       =*/ n_tokens,
+        }) {
+    }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
+
+        ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
+        ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
+
+        for (uint32_t il = 0; il < hp.n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            ggml_tensor * attn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+            cur = llm_build_norm(ctx, inpL, attn_norm, nullptr, LLM_NORM_RMS);
+
+            // self-attention
+            {
+                ggml_tensor * wq = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
+                ggml_tensor * wk = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
+                ggml_tensor * wv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx, wq, cur);
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur);
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);
+
+                Qcur = ggml_rope_custom(
+                    ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head,    hp.n_tokens), inp_pos,
+                    hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+                Kcur = ggml_rope_custom(
+                    ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos,
+                    hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+                llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
+
+                cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx, cur, inpSA);
+
+            // feed-forward network
+            ggml_tensor * ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+            cur = llm_build_norm(ctx, ffn_inp, ffn_norm, nullptr, LLM_NORM_RMS);
+
+            ggml_tensor * ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
+            ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff,   hp.n_embd);
+            ggml_tensor * ffn_up   = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
+            struct ggml_tensor * tmp = ggml_mul_mat(ctx, ffn_up, cur);
+            cur = ggml_mul_mat(ctx, ffn_gate, cur);
+            cur = ggml_silu(ctx, cur);
+            cur = ggml_mul(ctx, cur, tmp);
+            cur = ggml_mul_mat(ctx, ffn_down, cur);
+
+            cur = ggml_add(ctx, cur, ffn_inp);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+        cur = llm_build_norm(ctx, cur, output_norm, nullptr, LLM_NORM_RMS);
+
+        // lm_head
+        ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_vocab);
+        cur = ggml_mul_mat(ctx, output, cur);
+
+        return cur;
+    }
+};
+
+// Falcon
+struct test_falcon : public test_llm {
+    static constexpr float freq_base = 10000.0f;
+    static constexpr float freq_scale = 1.0f;
+    static constexpr float ext_factor = 0.0f;
+    static constexpr float attn_factor = 1.0f;
+    static constexpr float beta_fast = 32.0f;
+    static constexpr float beta_slow = 1.0f;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "FALCON";
+    }
+
+    std::string vars() override {
+        auto n_tokens = hp.n_tokens;
+        return VARS_TO_STR1(n_tokens);
+    }
+
+    double max_nmse_err() override {
+        return 2e-3;
+    }
+
+    test_falcon(int n_tokens = 1)
+        : test_llm({
+            /*n_vocab        =*/ 32000,
+            /*n_embd         =*/ 3200,
+            /*n_head         =*/ 50,
+            /*n_head_kv      =*/ 1,
+            /*n_rot          =*/ 64,
+            /*n_embd_head    =*/ 64,
+            /*n_ff           =*/ 8640,
+            /*f_norm_eps     =*/ 1e-5f,
+            /*f_norm_rms_eps =*/ 0.f,
+            /*n_tokens       =*/ n_tokens,
+        }) {
+    }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
+
+        ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
+        ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
+
+        for (uint32_t il = 0; il < hp.n_layer; ++il) {
+            // norm
+            ggml_tensor * attn_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+            ggml_tensor * attn_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+            ggml_tensor * attn_norm = llm_build_norm(ctx, inpL, attn_norm_w, attn_norm_b, LLM_NORM);
+
+            // self-attention
+            {
+                cur = attn_norm;
+
+                ggml_tensor * wqkv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd + 2*hp.n_embd_gqa());
+
+                cur = ggml_mul_mat(ctx, wqkv, cur);
+
+                struct ggml_tensor * Qcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd,     hp.n_tokens, cur->nb[1], 0*sizeof(float)*(hp.n_embd)));
+                struct ggml_tensor * Kcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd + hp.n_embd_gqa())));
+
+                Qcur = ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head,    hp.n_tokens);
+                Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens);
+
+                // using mode = 2 for neox mode
+                Qcur = ggml_rope_custom(
+                    ctx, Qcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+                Kcur = ggml_rope_custom(
+                    ctx, Kcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+                llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
+
+                cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
+            }
+
+            struct ggml_tensor * ffn_inp = cur;
+
+            // feed forward
+            {
+                ggml_tensor * ffn_up   = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
+                ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff, hp.n_embd);
+                cur = attn_norm;
+                cur = ggml_mul_mat(ctx, ffn_up, cur);
+                cur = ggml_gelu(ctx, cur);
+                cur = ggml_mul_mat(ctx, ffn_down, cur);
+            }
+
+            cur = ggml_add(ctx, cur, ffn_inp);
+
+            cur = ggml_add(ctx, cur, inpL);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        ggml_tensor * output_norm   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+        ggml_tensor * output_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
+        cur = llm_build_norm(ctx, cur, output_norm, output_norm_b, LLM_NORM);
+
+        // lm_head
+        ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, hp.n_embd, hp.n_vocab);
+        cur = ggml_mul_mat(ctx, output, cur);
+
+        return cur;
+    }
+};
+
 static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     std::vector<std::unique_ptr<test_case>> test_cases;
     std::default_random_engine rng(0);
@@ -1626,6 +2029,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         exponent <<= 1;
     }
 
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, 0.1f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, 0.1f, true));
+
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
         test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512)); // llama 7B
         test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512)); // llama 13B
@@ -1662,6 +2068,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     //test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
 #endif
 
+    // these tests are disabled to save execution time, but they can be handy for debugging
+#if 0
+    test_cases.emplace_back(new test_llama(1));
+    test_cases.emplace_back(new test_llama(2));
+    test_cases.emplace_back(new test_falcon(1));
+    test_cases.emplace_back(new test_falcon(2));
+#endif
+
     // run tests
     if (mode == MODE_TEST) {
         ggml_backend_t backend_cpu = ggml_backend_cpu_init();
index a05071080a1dffa4e22ad908371f6b9e2ba8077c..95ba73df39a3c0341d1cdba53ffec2999f7893d9 100644 (file)
@@ -1,3 +1,7 @@
 #include "llama.h"
 
+#ifdef GGML_USE_KOMPUTE
+#include "ggml-kompute.h"
+#endif
+
 int main(void) {}