]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : sync resolve (skip) (#0)
authorGeorgi Gerganov <redacted>
Tue, 19 Nov 2024 17:03:47 +0000 (19:03 +0200)
committerGeorgi Gerganov <redacted>
Wed, 20 Nov 2024 19:00:08 +0000 (21:00 +0200)
120 files changed:
ggml/CMakeLists.txt
ggml/src/ggml-amx.cpp [deleted file]
ggml/src/ggml-backend.cpp
ggml/src/ggml-blas.cpp [deleted file]
ggml/src/ggml-cann.cpp [deleted file]
ggml/src/ggml-cpu-impl.h [deleted file]
ggml/src/ggml-cpu.c [deleted file]
ggml/src/ggml-cuda.cu [deleted file]
ggml/src/ggml-cuda/CMakeLists.txt [new file with mode: 0644]
ggml/src/ggml-kompute.cpp [deleted file]
ggml/src/ggml-metal.m [deleted file]
ggml/src/ggml-metal.metal [deleted file]
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-musa/CMakeLists.txt [new file with mode: 0644]
ggml/src/ggml-rpc.cpp [deleted file]
ggml/src/ggml-sycl.cpp [deleted file]
ggml/src/ggml-vulkan.cpp [deleted file]
ggml/src/kompute-shaders/common.comp [deleted file]
ggml/src/kompute-shaders/op_add.comp [deleted file]
ggml/src/kompute-shaders/op_addrow.comp [deleted file]
ggml/src/kompute-shaders/op_cpy_f16_f16.comp [deleted file]
ggml/src/kompute-shaders/op_cpy_f16_f32.comp [deleted file]
ggml/src/kompute-shaders/op_cpy_f32_f16.comp [deleted file]
ggml/src/kompute-shaders/op_cpy_f32_f32.comp [deleted file]
ggml/src/kompute-shaders/op_diagmask.comp [deleted file]
ggml/src/kompute-shaders/op_gelu.comp [deleted file]
ggml/src/kompute-shaders/op_getrows.comp [deleted file]
ggml/src/kompute-shaders/op_getrows_f16.comp [deleted file]
ggml/src/kompute-shaders/op_getrows_f32.comp [deleted file]
ggml/src/kompute-shaders/op_getrows_q4_0.comp [deleted file]
ggml/src/kompute-shaders/op_getrows_q4_1.comp [deleted file]
ggml/src/kompute-shaders/op_getrows_q6_k.comp [deleted file]
ggml/src/kompute-shaders/op_mul.comp [deleted file]
ggml/src/kompute-shaders/op_mul_mat_f16.comp [deleted file]
ggml/src/kompute-shaders/op_mul_mat_mat_f32.comp [deleted file]
ggml/src/kompute-shaders/op_mul_mat_q4_0.comp [deleted file]
ggml/src/kompute-shaders/op_mul_mat_q4_1.comp [deleted file]
ggml/src/kompute-shaders/op_mul_mat_q6_k.comp [deleted file]
ggml/src/kompute-shaders/op_mul_mat_q8_0.comp [deleted file]
ggml/src/kompute-shaders/op_mul_mv_q_n.comp [deleted file]
ggml/src/kompute-shaders/op_mul_mv_q_n_pre.comp [deleted file]
ggml/src/kompute-shaders/op_norm.comp [deleted file]
ggml/src/kompute-shaders/op_relu.comp [deleted file]
ggml/src/kompute-shaders/op_rmsnorm.comp [deleted file]
ggml/src/kompute-shaders/op_rope_f16.comp [deleted file]
ggml/src/kompute-shaders/op_rope_f32.comp [deleted file]
ggml/src/kompute-shaders/op_scale.comp [deleted file]
ggml/src/kompute-shaders/op_scale_8.comp [deleted file]
ggml/src/kompute-shaders/op_silu.comp [deleted file]
ggml/src/kompute-shaders/op_softmax.comp [deleted file]
ggml/src/kompute-shaders/rope_common.comp [deleted file]
ggml/src/sgemm.cpp [deleted file]
ggml/src/sgemm.h [deleted file]
ggml/src/vulkan-shaders/CMakeLists.txt [deleted file]
ggml/src/vulkan-shaders/acc.comp [deleted file]
ggml/src/vulkan-shaders/add.comp [deleted file]
ggml/src/vulkan-shaders/argsort.comp [deleted file]
ggml/src/vulkan-shaders/clamp.comp [deleted file]
ggml/src/vulkan-shaders/concat.comp [deleted file]
ggml/src/vulkan-shaders/contig_copy.comp [deleted file]
ggml/src/vulkan-shaders/copy.comp [deleted file]
ggml/src/vulkan-shaders/cos.comp [deleted file]
ggml/src/vulkan-shaders/dequant_f32.comp [deleted file]
ggml/src/vulkan-shaders/dequant_funcs.comp [deleted file]
ggml/src/vulkan-shaders/dequant_head.comp [deleted file]
ggml/src/vulkan-shaders/dequant_iq4_nl.comp [deleted file]
ggml/src/vulkan-shaders/dequant_q2_k.comp [deleted file]
ggml/src/vulkan-shaders/dequant_q3_k.comp [deleted file]
ggml/src/vulkan-shaders/dequant_q4_0.comp [deleted file]
ggml/src/vulkan-shaders/dequant_q4_1.comp [deleted file]
ggml/src/vulkan-shaders/dequant_q4_k.comp [deleted file]
ggml/src/vulkan-shaders/dequant_q5_0.comp [deleted file]
ggml/src/vulkan-shaders/dequant_q5_1.comp [deleted file]
ggml/src/vulkan-shaders/dequant_q5_k.comp [deleted file]
ggml/src/vulkan-shaders/dequant_q6_k.comp [deleted file]
ggml/src/vulkan-shaders/dequant_q8_0.comp [deleted file]
ggml/src/vulkan-shaders/diag_mask_inf.comp [deleted file]
ggml/src/vulkan-shaders/div.comp [deleted file]
ggml/src/vulkan-shaders/gelu.comp [deleted file]
ggml/src/vulkan-shaders/gelu_quick.comp [deleted file]
ggml/src/vulkan-shaders/generic_binary_head.comp [deleted file]
ggml/src/vulkan-shaders/generic_head.comp [deleted file]
ggml/src/vulkan-shaders/generic_unary_head.comp [deleted file]
ggml/src/vulkan-shaders/get_rows.comp [deleted file]
ggml/src/vulkan-shaders/get_rows_quant.comp [deleted file]
ggml/src/vulkan-shaders/group_norm.comp [deleted file]
ggml/src/vulkan-shaders/im2col.comp [deleted file]
ggml/src/vulkan-shaders/leaky_relu.comp [deleted file]
ggml/src/vulkan-shaders/mul.comp [deleted file]
ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp [deleted file]
ggml/src/vulkan-shaders/mul_mat_vec.comp [deleted file]
ggml/src/vulkan-shaders/mul_mat_vec_base.comp [deleted file]
ggml/src/vulkan-shaders/mul_mat_vec_nc.comp [deleted file]
ggml/src/vulkan-shaders/mul_mat_vec_p021.comp [deleted file]
ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp [deleted file]
ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp [deleted file]
ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp [deleted file]
ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp [deleted file]
ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp [deleted file]
ggml/src/vulkan-shaders/mul_mm.comp [deleted file]
ggml/src/vulkan-shaders/norm.comp [deleted file]
ggml/src/vulkan-shaders/pad.comp [deleted file]
ggml/src/vulkan-shaders/pool2d.comp [deleted file]
ggml/src/vulkan-shaders/relu.comp [deleted file]
ggml/src/vulkan-shaders/repeat.comp [deleted file]
ggml/src/vulkan-shaders/rms_norm.comp [deleted file]
ggml/src/vulkan-shaders/rope_head.comp [deleted file]
ggml/src/vulkan-shaders/rope_neox.comp [deleted file]
ggml/src/vulkan-shaders/rope_norm.comp [deleted file]
ggml/src/vulkan-shaders/scale.comp [deleted file]
ggml/src/vulkan-shaders/silu.comp [deleted file]
ggml/src/vulkan-shaders/sin.comp [deleted file]
ggml/src/vulkan-shaders/soft_max.comp [deleted file]
ggml/src/vulkan-shaders/square.comp [deleted file]
ggml/src/vulkan-shaders/sum_rows.comp [deleted file]
ggml/src/vulkan-shaders/tanh.comp [deleted file]
ggml/src/vulkan-shaders/timestep_embedding.comp [deleted file]
ggml/src/vulkan-shaders/types.comp [deleted file]
ggml/src/vulkan-shaders/upscale.comp [deleted file]
ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp [deleted file]

index f5f556cdcf88d0f435f3385fa1c53b1e36ae4e1f..9ab91421a7d2532945f4170b79a70c1295f1476b 100644 (file)
@@ -225,6 +225,7 @@ set(GGML_PUBLIC_HEADERS
     include/ggml-cann.h
     include/ggml-cuda.h
     include/ggml-kompute.h
+    include/ggml-opt.h
     include/ggml-metal.h
     include/ggml-rpc.h
     include/ggml-sycl.h
@@ -241,7 +242,7 @@ install(TARGETS ggml-base LIBRARY)
 if (GGML_METAL)
     # FIXME: does this need to be installed with GGML_METAL_EMBED_LIBRARY?
     install(
-        FILES ggml/src/ggml-metal/ggml-metal.metal
+        FILES src/ggml-metal/ggml-metal.metal
         PERMISSIONS
             OWNER_READ
             OWNER_WRITE
diff --git a/ggml/src/ggml-amx.cpp b/ggml/src/ggml-amx.cpp
deleted file mode 100644 (file)
index 144dc9d..0000000
+++ /dev/null
@@ -1,436 +0,0 @@
-#include "ggml-amx.h"
-#include "ggml-amx/common.h"
-#include "ggml-amx/mmq.h"
-#include "ggml-backend-impl.h"
-#include "ggml-impl.h"
-
-#if defined(__gnu_linux__)
-#include <sys/syscall.h>
-#include <unistd.h>
-#endif
-
-#include <cstdlib>
-#include <cstring>
-#include <memory>
-
-#if defined(__AMX_INT8__)
-
-// AMX buffer interface
-static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    free(buffer->context);
-}
-
-static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
-    return (void *)(buffer->context);
-}
-
-static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
-    memset((char *)tensor->data + offset, value, size);
-
-    GGML_UNUSED(buffer);
-}
-
-static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    if (qtype_has_amx_kernels(tensor->type)) {
-        ggml_backend_amx_convert_weight(tensor, data, offset, size);
-    } else {
-        memcpy((char *)tensor->data + offset, data, size);
-    }
-
-    GGML_UNUSED(buffer);
-}
-
-static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
-    memcpy(data, (const char *)tensor->data + offset, size);
-
-    GGML_UNUSED(buffer);
-}
-
-static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
-    if (ggml_backend_buffer_is_host(src->buffer)) {
-        if (qtype_has_amx_kernels(src->type)) {
-            ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_backend_amx_get_alloc_size(dst));
-        } else {
-            memcpy(dst->data, src->data, ggml_nbytes(src));
-        }
-        return true;
-    }
-    return false;
-
-    GGML_UNUSED(buffer);
-}
-
-static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
-    memset(buffer->context, value, buffer->size);
-}
-
-static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
-    /* .free_buffer     = */ ggml_backend_amx_buffer_free_buffer,
-    /* .get_base        = */ ggml_backend_amx_buffer_get_base,
-    /* .init_tensor     = */ NULL, // no initialization required
-    /* .memset_tensor   = */ ggml_backend_amx_buffer_memset_tensor,
-    /* .set_tensor      = */ ggml_backend_amx_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_amx_buffer_get_tensor,
-    /* .cpy_tensor      = */ ggml_backend_amx_buffer_cpy_tensor,
-    /* .clear           = */ ggml_backend_amx_buffer_clear,
-    /* .reset           = */ NULL,
-};
-
-static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    return "AMX";
-
-    GGML_UNUSED(buft);
-}
-
-static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    void * data = aligned_alloc(TENSOR_ALIGNMENT, size);
-    if (data == NULL) {
-        fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
-        return NULL;
-    }
-
-    return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
-}
-
-static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    return TENSOR_ALIGNMENT;
-
-    GGML_UNUSED(buft);
-}
-
-static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
-    return ggml_backend_amx_get_alloc_size(tensor);
-
-    GGML_UNUSED(buft);
-}
-
-static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
-    return false;
-
-    GGML_UNUSED(buft);
-}
-
-ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
-    static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
-        /* .iface = */ {
-            /* .get_name         = */ ggml_backend_amx_buffer_type_get_name,
-            /* .alloc_buffer     = */ ggml_backend_amx_buffer_type_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_amx_buffer_type_get_alignment,
-            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
-            /* .get_alloc_size   = */ ggml_backend_amx_buffer_type_get_alloc_size,
-            /* .is_host          = */ ggml_backend_amx_buffer_type_is_host,
-        },
-        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
-        /* .context = */ NULL,
-    };
-
-    return &ggml_backend_buffer_type_amx;
-}
-
-// backend interface
-
-static const char * ggml_backend_amx_name(ggml_backend_t backend) {
-    return "AMX";
-
-    GGML_UNUSED(backend);
-}
-
-static void ggml_backend_amx_free(ggml_backend_t backend) {
-    ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
-    delete ctx;
-    delete backend;
-}
-
-static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-    ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
-
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        struct ggml_tensor * node = cgraph->nodes[i];
-
-        switch (node->op) {
-        case GGML_OP_MUL_MAT:
-            ggml_backend_amx_mul_mat(ctx, node);
-            break;
-
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_TRANSPOSE:
-            break;
-
-        default:
-            fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
-            GGML_ASSERT(false);
-        }
-    }
-
-    return GGML_STATUS_SUCCESS;
-
-    GGML_UNUSED(backend);
-}
-
-static struct ggml_backend_i ggml_backend_amx_i = {
-    /* .get_name                = */ ggml_backend_amx_name,
-    /* .free                    = */ ggml_backend_amx_free,
-    /* .set_tensor_async        = */ NULL,
-    /* .get_tensor_async        = */ NULL,
-    /* .cpy_tensor_async        = */ NULL,
-    /* .synchronize             = */ NULL,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_amx_graph_compute,
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_amx_guid() {
-    static ggml_guid guid = { 0x13, 0xb8, 0xa4, 0xc4, 0xba, 0xfe, 0x51, 0x67, 0x87, 0x44, 0x55, 0x15, 0xb2, 0x35, 0x62, 0x3e };
-    return &guid;
-}
-
-#define ARCH_GET_XCOMP_PERM     0x1022
-#define ARCH_REQ_XCOMP_PERM     0x1023
-#define XFEATURE_XTILECFG       17
-#define XFEATURE_XTILEDATA      18
-
-static bool ggml_amx_init() {
-#if defined(__gnu_linux__)
-    if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
-        fprintf(stderr, "AMX is not ready to be used!\n");
-        return false;
-    }
-    return true;
-#elif defined(_WIN32)
-    return true;
-#endif
-}
-
-ggml_backend_t ggml_backend_amx_init() {
-
-    // invoke a Linux system call to request access to AMX features
-    ggml_amx_init();
-
-    // backend context
-    ggml_backend_amx_context * ctx = new ggml_backend_amx_context;
-
-    // ggml amx backend
-    ggml_backend_t backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_amx_guid(),
-        /* .interface = */ ggml_backend_amx_i,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
-        /* .context   = */ ctx,
-    };
-
-    return backend;
-}
-
-bool ggml_backend_is_amx(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_amx_guid());
-}
-
-void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
-    GGML_ASSERT(ggml_backend_is_amx(backend_amx));
-
-    ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend_amx->context;
-    ctx->n_threads = n_threads;
-}
-
-// device interface
-
-static const char * ggml_backend_amx_device_get_name(ggml_backend_dev_t dev) {
-    return "AMX";
-
-    GGML_UNUSED(dev);
-}
-
-static const char * ggml_backend_amx_device_get_description(ggml_backend_dev_t dev) {
-    return "Intel Advanced Matrix Extensions";
-
-    GGML_UNUSED(dev);
-}
-
-static void ggml_backend_amx_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
-    // TODO
-    *free = 0;
-    *total = 0;
-
-    GGML_UNUSED(dev);
-}
-
-static enum ggml_backend_dev_type ggml_backend_amx_device_get_type(ggml_backend_dev_t dev) {
-    return GGML_BACKEND_DEVICE_TYPE_ACCEL;
-
-    GGML_UNUSED(dev);
-}
-
-static void ggml_backend_amx_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
-    props->name        = ggml_backend_amx_device_get_name(dev);
-    props->description = ggml_backend_amx_device_get_description(dev);
-    props->type        = ggml_backend_amx_device_get_type(dev);
-    ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total);
-
-    // `buffer_from_host_ptr` is intended to be used in mmap, when memory layout unchanged
-    props->caps = {
-        /* .async                 = */ false,
-        /* .host_buffer           = */ false,
-        /* .buffer_from_host_ptr  = */ false,
-        /* .events                = */ false,
-    };
-}
-
-static ggml_backend_t ggml_backend_amx_device_init(ggml_backend_dev_t dev, const char * params) {
-    return ggml_backend_amx_init();
-
-    GGML_UNUSED(dev);
-    GGML_UNUSED(params);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_amx_device_get_buffer_type(ggml_backend_dev_t dev) {
-    return ggml_backend_amx_buffer_type();
-
-    GGML_UNUSED(dev);
-}
-
-static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-
-    // handle only 2d gemm for now
-    auto is_contiguous_2d = [](const struct ggml_tensor * t) {
-        return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
-    };
-
-    switch (op->op) {
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_TRANSPOSE:
-            return true;
-
-        case GGML_OP_MUL_MAT: {
-            const struct ggml_tensor * src0 = op->src[0];
-            const struct ggml_tensor * src1 = op->src[1];
-
-            const enum ggml_type type = src0->type;
-            const int64_t ne0 = op->ne[0];
-
-            bool is_training = src0->grad || src1->grad;
-
-            // amx kernels enables for Q4_0, Q4_1, Q8_0, F16
-            // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
-            bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
-
-            bool can_use_amx =
-                is_contiguous_2d(src0) &&       // src0 must be contiguous
-                is_contiguous_2d(src1) &&       // src1 must be contiguous
-                !is_training &&                 // inference only
-                src1->type == GGML_TYPE_F32 &&  // src1 must be float32
-                has_amx_kernels &&              // with amx kernel impls
-                ne0 % (TILE_N * 2) == 0;        // out_features is 32x
-
-            return can_use_amx;
-        }
-        default:
-            return false;
-    }
-
-    GGML_UNUSED(dev);
-}
-
-static bool ggml_backend_amx_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;
-
-    GGML_UNUSED(dev);
-}
-
-static const struct ggml_backend_device_i ggml_backend_amx_device_i = {
-    /* .get_name             = */ ggml_backend_amx_device_get_name,
-    /* .get_description      = */ ggml_backend_amx_device_get_description,
-    /* .get_memory           = */ ggml_backend_amx_device_get_memory,
-    /* .get_type             = */ ggml_backend_amx_device_get_type,
-    /* .get_props            = */ ggml_backend_amx_device_get_props,
-    /* .init_backend         = */ ggml_backend_amx_device_init,
-    /* .get_buffer_type      = */ ggml_backend_amx_device_get_buffer_type,
-    /* .get_host_buffer_type = */ NULL,
-    /* .buffer_from_host_ptr = */ NULL,
-    /* .supports_op          = */ ggml_backend_amx_device_supports_op,
-    /* .supports_buft        = */ ggml_backend_amx_device_supports_buft,
-    /* .offload_op           = */ NULL,
-    /* .event_new            = */ NULL,
-    /* .event_free           = */ NULL,
-    /* .event_synchronize    = */ NULL,
-};
-
-// backend reg interface
-
-static const char * ggml_backend_amx_reg_get_name(ggml_backend_reg_t reg) {
-    return "AMX";
-
-    GGML_UNUSED(reg);
-}
-
-static size_t ggml_backend_amx_reg_get_device_count(ggml_backend_reg_t reg) {
-    return 1;
-
-    GGML_UNUSED(reg);
-}
-
-static ggml_backend_dev_t ggml_backend_amx_reg_get_device(ggml_backend_reg_t reg, size_t index) {
-    GGML_ASSERT(index == 0);
-
-    static ggml_backend_device ggml_backend_amx_device = {
-        /* .iface   = */ ggml_backend_amx_device_i,
-        /* .reg     = */ reg,
-        /* .context = */ nullptr,
-    };
-
-    return &ggml_backend_amx_device;
-
-    GGML_UNUSED(reg);
-    GGML_UNUSED(index);
-}
-
-static void * ggml_backend_amx_get_proc_address(ggml_backend_reg_t reg, const char * name) {
-    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
-        return (void *)ggml_backend_amx_set_n_threads;
-    }
-    return NULL;
-
-    GGML_UNUSED(reg);
-    GGML_UNUSED(name);
-}
-
-static const struct ggml_backend_reg_i ggml_backend_amx_reg_i = {
-    /* .get_name         = */ ggml_backend_amx_reg_get_name,
-    /* .get_device_count = */ ggml_backend_amx_reg_get_device_count,
-    /* .get_device       = */ ggml_backend_amx_reg_get_device,
-    /* .get_proc_address = */ ggml_backend_amx_get_proc_address,
-};
-
-ggml_backend_reg_t ggml_backend_amx_reg(void) {
-    static struct ggml_backend_reg ggml_backend_amx_reg = {
-        /* .iface   = */ ggml_backend_amx_reg_i,
-        /* .context = */ NULL,
-    };
-
-    return &ggml_backend_amx_reg;
-}
-
-#else // if defined(__AMX_INT8__)
-
-ggml_backend_t ggml_backend_amx_init(void) {
-    fprintf(stderr, "GGML is not compiled with AMX support!\n");
-    return ggml_backend_t{};
-}
-
-void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
-    fprintf(stderr, "GGML is not compiled with AMX support!\n");
-
-    GGML_UNUSED(backend_amx);
-    GGML_UNUSED(n_threads);
-}
-
-#endif
index e30a9a00d58e5de385115af906c63af7353f2218..9dcde8d11952ae4fd6c0d395081f75760c843912 100644 (file)
@@ -525,197 +525,6 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
     return reg->iface.get_proc_address(reg, name);
 }
 
-// Backend registry
-
-#ifdef GGML_USE_CUDA
-#include "ggml-cuda.h"
-#endif
-
-#ifdef GGML_USE_METAL
-#include "ggml-metal.h"
-#endif
-
-#ifdef GGML_USE_SYCL
-#include "ggml-sycl.h"
-#endif
-
-#ifdef GGML_USE_VULKAN
-#include "ggml-vulkan.h"
-#endif
-
-#ifdef GGML_USE_BLAS
-#include "ggml-blas.h"
-#endif
-
-#ifdef GGML_USE_RPC
-#include "ggml-rpc.h"
-#endif
-
-#ifndef __AMX_INT8__
-#undef GGML_USE_AMX
-#endif
-
-#ifdef GGML_USE_AMX
-#  include "ggml-amx.h"
-#endif
-
-#ifdef GGML_USE_CANN
-#include "ggml-cann.h"
-#endif
-
-#ifdef GGML_USE_KOMPUTE
-#include "ggml-kompute.h"
-#endif
-
-#include "ggml-cpu.h"
-
-struct ggml_backend_registry {
-    std::vector<ggml_backend_reg_t> backends;
-    std::vector<ggml_backend_dev_t> devices;
-
-    ggml_backend_registry() {
-#ifdef GGML_USE_CUDA
-        register_backend(ggml_backend_cuda_reg());
-#endif
-#ifdef GGML_USE_METAL
-        register_backend(ggml_backend_metal_reg());
-#endif
-#ifdef GGML_USE_SYCL
-        register_backend(ggml_backend_sycl_reg());
-#endif
-#ifdef GGML_USE_VULKAN
-        register_backend(ggml_backend_vk_reg());
-#endif
-#ifdef GGML_USE_CANN
-        register_backend(ggml_backend_cann_reg());
-#endif
-#ifdef GGML_USE_BLAS
-        register_backend(ggml_backend_blas_reg());
-#endif
-#ifdef GGML_USE_RPC
-        register_backend(ggml_backend_rpc_reg());
-#endif
-#ifdef GGML_USE_AMX
-        register_backend(ggml_backend_amx_reg());
-#endif
-#ifdef GGML_USE_KOMPUTE
-        register_backend(ggml_backend_kompute_reg());
-#endif
-
-        register_backend(ggml_backend_cpu_reg());
-    }
-
-    void register_backend(ggml_backend_reg_t reg) {
-#ifndef NDEBUG
-        GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n",
-            __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg));
-#endif
-        backends.push_back(reg);
-        for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
-            register_device(ggml_backend_reg_dev_get(reg, i));
-        }
-    }
-
-    void register_device(ggml_backend_dev_t device) {
-#ifndef NDEBUG
-        GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
-#endif
-        devices.push_back(device);
-    }
-};
-
-static ggml_backend_registry & get_reg() {
-    static ggml_backend_registry reg;
-    return reg;
-}
-
-// Internal API
-void ggml_backend_register(ggml_backend_reg_t reg) {
-    get_reg().register_backend(reg);
-}
-
-void ggml_backend_device_register(ggml_backend_dev_t device) {
-    get_reg().register_device(device);
-}
-
-// Backend (reg) enumeration
-size_t ggml_backend_reg_count() {
-    return get_reg().backends.size();
-}
-
-ggml_backend_reg_t ggml_backend_reg_get(size_t index) {
-    GGML_ASSERT(index < ggml_backend_reg_count());
-    return get_reg().backends[index];
-}
-
-ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) {
-    for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
-        ggml_backend_reg_t reg = ggml_backend_reg_get(i);
-        if (strcmp(ggml_backend_reg_name(reg), name) == 0) {
-            return reg;
-        }
-    }
-    return NULL;
-}
-
-// Device enumeration
-size_t ggml_backend_dev_count() {
-    return get_reg().devices.size();
-}
-
-ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
-    GGML_ASSERT(index < ggml_backend_dev_count());
-    return get_reg().devices[index];
-}
-
-ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {
-    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
-        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        if (strcmp(ggml_backend_dev_name(dev), name) == 0) {
-            return dev;
-        }
-    }
-    return NULL;
-}
-
-ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type) {
-    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
-        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        if (ggml_backend_dev_type(dev) == type) {
-            return dev;
-        }
-    }
-    return NULL;
-}
-
-// Convenience functions
-ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params) {
-    ggml_backend_dev_t dev = ggml_backend_dev_by_name(name);
-    if (!dev) {
-        return NULL;
-    }
-    return ggml_backend_dev_init(dev, params);
-}
-
-ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params) {
-    ggml_backend_dev_t dev = ggml_backend_dev_by_type(type);
-    if (!dev) {
-        return NULL;
-    }
-    return ggml_backend_dev_init(dev, params);
-}
-
-ggml_backend_t ggml_backend_init_best(void) {
-    ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
-    if (!dev) {
-        dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-    }
-    if (!dev) {
-        return NULL;
-    }
-    return ggml_backend_dev_init(dev, NULL);
-}
-
 // multi-buffer buffer
 
 struct ggml_backend_multi_buffer_context {
@@ -1641,7 +1450,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
         bool parallel) {
     GGML_ASSERT(n_backends > 0);
     GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
-    GGML_ASSERT(ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU
+    GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU);
 
     struct ggml_backend_sched * sched = (ggml_backend_sched *) calloc(1, sizeof(struct ggml_backend_sched));
 
@@ -2038,17 +1847,6 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
     return true;
 }
 
-
-
-#include "ggml-backend.h"
-#include "ggml-backend-impl.h"
-#include "ggml-cpu.h"
-#include "ggml-impl.h"
-#include <cctype>
-#include <string>
-
-// ggml-backend interface
-
 // CPU backend - buffer
 
 static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
@@ -2122,7 +1920,9 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = {
     /* .reset           = */ NULL,
 };
 
-// CPU backend - buffer type
+// CPU backend buffer type
+
+// this buffer type is defined here to make it available to all backends
 
 static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
     return "CPU";
@@ -2163,7 +1963,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
             /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
             /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,
         },
-        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .device  = */ NULL, // FIXME ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
         /* .context = */ NULL,
     };
 
@@ -2186,479 +1986,14 @@ static ggml_backend_buffer_type_t ggml_backend_cpu_buffer_from_ptr_type(void) {
             /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
             /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,
         },
-        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .device  = */ NULL, // FIXME ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
         /* .context = */ NULL,
     };
 
     return &ggml_backend_cpu_buffer_type;
 }
 
-#ifdef GGML_USE_CPU_HBM
-
-// buffer type HBM
-
-#include <hbwmalloc.h>
-
-static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    return "CPU_HBM";
-
-    GGML_UNUSED(buft);
-}
-
-static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    hbw_free(buffer->context);
-}
-
-static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    void * ptr;
-    int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);
-    if (result != 0) {
-        GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size);
-        return NULL;
-    }
-
-    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
-    buffer->buft = buft;
-    buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;
-
-    return buffer;
-}
-
-ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
-    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {
-        /* .iface    = */ {
-            /* .get_name         = */ ggml_backend_cpu_hbm_buffer_type_get_name,
-            /* .alloc_buffer     = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,
-            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
-            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
-            /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,
-        },
-        /* .context  = */ NULL,
-    };
-
-    return &ggml_backend_cpu_buffer_type_hbm;
-}
-#endif
-
-static ggml_backend_buffer_type_t * ggml_backend_cpu_get_extra_bufts(ggml_backend_dev_t device) {
-    static ggml_backend_buffer_type_t bufts[] = {
-#ifdef GGML_USE_CPU_HBM
-        ggml_backend_cpu_hbm_buffer_type(),
-#endif
-        NULL
-    };
-
-    return bufts;
-
-    GGML_UNUSED(device);
-}
-
-// CPU backend - backend (stream)
-
-struct ggml_backend_cpu_context {
-    int                 n_threads;
-    ggml_threadpool_t   threadpool;
-
-    uint8_t *           work_data;
-    size_t              work_size;
-
-    ggml_abort_callback abort_callback;
-    void *              abort_callback_data;
-};
-
-static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {
-    return "CPU";
-
-    GGML_UNUSED(backend);
-}
-
-static void ggml_backend_cpu_free(ggml_backend_t backend) {
-    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
-    delete[] cpu_ctx->work_data;
-    delete cpu_ctx;
-    delete backend;
-}
-
-struct ggml_backend_plan_cpu {
-    struct ggml_cplan cplan;
-    struct ggml_cgraph cgraph;
-};
-
-static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
-    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
-
-    struct ggml_backend_plan_cpu * cpu_plan = new ggml_backend_plan_cpu;
-
-    cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
-    cpu_plan->cgraph = *cgraph; // FIXME: deep copy
-
-    if (cpu_plan->cplan.work_size > 0) {
-        cpu_plan->cplan.work_data = new uint8_t[cpu_plan->cplan.work_size];
-        if (cpu_plan->cplan.work_data == NULL) {
-            delete cpu_plan;
-            return NULL;
-        }
-    }
-
-    cpu_plan->cplan.abort_callback      = cpu_ctx->abort_callback;
-    cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
-
-    return cpu_plan;
-}
-
-static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
-    struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
-
-    delete[] cpu_plan->cplan.work_data;
-    delete cpu_plan;
-
-    GGML_UNUSED(backend);
-}
-
-static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
-    struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
-
-    return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
-
-    GGML_UNUSED(backend);
-}
-
-static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
-
-    struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
-
-    if (cpu_ctx->work_size < cplan.work_size) {
-        delete[] cpu_ctx->work_data;
-        cpu_ctx->work_data = new uint8_t[cplan.work_size];
-        if (cpu_ctx->work_data == NULL) {
-            cpu_ctx->work_size = 0;
-            return GGML_STATUS_ALLOC_FAILED;
-        }
-        cpu_ctx->work_size = cplan.work_size;
-    }
-    cplan.work_data = (uint8_t *)cpu_ctx->work_data;
-
-    cplan.abort_callback      = cpu_ctx->abort_callback;
-    cplan.abort_callback_data = cpu_ctx->abort_callback_data;
-
-    return ggml_graph_compute(cgraph, &cplan);
-}
-
-static const struct ggml_backend_i ggml_backend_cpu_i = {
-    /* .get_name                = */ ggml_backend_cpu_get_name,
-    /* .free                    = */ ggml_backend_cpu_free,
-    /* .set_tensor_async        = */ NULL,
-    /* .get_tensor_async        = */ NULL,
-    /* .cpy_tensor_async        = */ NULL,
-    /* .synchronize             = */ NULL,
-    /* .graph_plan_create       = */ ggml_backend_cpu_graph_plan_create,
-    /* .graph_plan_free         = */ ggml_backend_cpu_graph_plan_free,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ ggml_backend_cpu_graph_plan_compute,
-    /* .graph_compute           = */ ggml_backend_cpu_graph_compute,
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_cpu_guid(void) {
-    static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 };
-    return &guid;
-}
-
-ggml_backend_t ggml_backend_cpu_init(void) {
-    // initialize CPU backend now to avoid slowing the first graph computation
-    ggml_cpu_init();
-
-    struct ggml_backend_cpu_context * ctx = new ggml_backend_cpu_context;
-    if (ctx == NULL) {
-        return NULL;
-    }
-
-    ctx->n_threads           = GGML_DEFAULT_N_THREADS;
-    ctx->threadpool          = NULL;
-    ctx->work_data           = NULL;
-    ctx->work_size           = 0;
-    ctx->abort_callback      = NULL;
-    ctx->abort_callback_data = NULL;
-
-    ggml_backend_t cpu_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_cpu_guid(),
-        /* .interface = */ ggml_backend_cpu_i,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
-        /* .context   = */ ctx,
-    };
-
-    if (cpu_backend == NULL) {
-        delete ctx;
-        return NULL;
-    }
-
-    return cpu_backend;
-}
-
-bool ggml_backend_is_cpu(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid());
-}
-
-void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
-    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
-
-    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
-    ctx->n_threads = n_threads;
-}
-
-void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) {
-    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
-
-    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
-
-    if (ctx->threadpool && ctx->threadpool != threadpool) {
-        // already had a different threadpool, pause/suspend it before switching
-        ggml_threadpool_pause(ctx->threadpool);
-    }
-    ctx->threadpool = threadpool;
-}
-
-void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
-    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
-
-    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
-    ctx->abort_callback = abort_callback;
-    ctx->abort_callback_data = abort_callback_data;
-}
-
 ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
     GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned");
     return ggml_backend_buffer_init(ggml_backend_cpu_buffer_from_ptr_type(), ggml_backend_cpu_buffer_from_ptr_i, ptr, size);
 }
-
-// CPU backend - device
-
-struct ggml_backend_cpu_device_context {
-    std::string description = "CPU";
-
-    ggml_backend_cpu_device_context() {
-#ifdef __APPLE__
-        size_t len = 0;
-        if (!sysctlbyname("machdep.cpu.brand_string", NULL, &len, NULL, 0)) {
-            description.resize(len);
-            sysctlbyname("machdep.cpu.brand_string", &description[0], &len, NULL, 0); // NOLINT
-        }
-#elif defined(__linux__)
-        FILE * f = fopen("/proc/cpuinfo", "r");
-        if (f) {
-            char buf[1024];
-            while (fgets(buf, sizeof(buf), f)) {
-                if (strncmp(buf, "model name", 10) == 0) {
-                    char * p = strchr(buf, ':');
-                    if (p) {
-                        p++;
-                        while (std::isspace(*p)) {
-                            p++;
-                        }
-                        while (std::isspace(p[strlen(p) - 1])) {
-                            p[strlen(p) - 1] = '\0';
-                        }
-                        description = p;
-                        break;
-                    }
-                }
-            }
-            fclose(f);
-        }
-#elif defined(_WIN32)
-        HKEY hKey;
-        if (RegOpenKeyEx(HKEY_LOCAL_MACHINE,
-                        TEXT("HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"),
-                        0,
-                        KEY_READ,
-                        &hKey) == ERROR_SUCCESS) {
-            DWORD cpu_brand_size = 0;
-            if (RegQueryValueExA(hKey,
-                                TEXT("ProcessorNameString"),
-                                NULL,
-                                NULL,
-                                NULL,
-                                &cpu_brand_size) == ERROR_SUCCESS) {
-                description.resize(cpu_brand_size);
-                if (RegQueryValueExA(hKey,
-                                    TEXT("ProcessorNameString"),
-                                    NULL,
-                                    NULL,
-                                    (LPBYTE)&description[0], // NOLINT
-                                    &cpu_brand_size) == ERROR_SUCCESS) {
-                    if (description.find('\0') != std::string::npos) {
-                        description.resize(description.find('\0'));
-                    }
-                }
-            }
-            RegCloseKey(hKey);
-        }
-#endif
-    }
-};
-
-static const char * ggml_backend_cpu_device_get_name(ggml_backend_dev_t dev) {
-    return "CPU";
-
-    GGML_UNUSED(dev);
-}
-
-static const char * ggml_backend_cpu_device_get_description(ggml_backend_dev_t dev) {
-    struct ggml_backend_cpu_device_context * ctx = (struct ggml_backend_cpu_device_context *)dev->context;
-
-    return ctx->description.c_str();
-}
-
-static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
-    // TODO
-    *free = 0;
-    *total = 0;
-
-    GGML_UNUSED(dev);
-}
-
-static enum ggml_backend_dev_type ggml_backend_cpu_device_get_type(ggml_backend_dev_t dev) {
-    return GGML_BACKEND_DEVICE_TYPE_CPU;
-
-    GGML_UNUSED(dev);
-}
-
-static void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
-    props->name        = ggml_backend_cpu_device_get_name(dev);
-    props->description = ggml_backend_cpu_device_get_description(dev);
-    props->type        = ggml_backend_cpu_device_get_type(dev);
-    ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
-    props->caps = {
-        /* .async                 = */ false,
-        /* .host_buffer           = */ false,
-        /* .buffer_from_host_ptr  = */ true,
-        /* .events                = */ false,
-    };
-}
-
-static ggml_backend_t ggml_backend_cpu_device_init_backend(ggml_backend_dev_t dev, const char * params) {
-    return ggml_backend_cpu_init();
-
-    GGML_UNUSED(dev);
-    GGML_UNUSED(params);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_cpu_device_get_buffer_type(ggml_backend_dev_t dev) {
-    return ggml_backend_cpu_buffer_type();
-
-    GGML_UNUSED(dev);
-}
-
-static ggml_backend_buffer_t ggml_backend_cpu_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
-    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
-
-    GGML_UNUSED(dev);
-    GGML_UNUSED(max_tensor_size);
-}
-
-static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-    switch (op->op) {
-        case GGML_OP_CPY:
-            return
-                op->type != GGML_TYPE_IQ2_XXS &&
-                op->type != GGML_TYPE_IQ2_XS  &&
-                op->type != GGML_TYPE_IQ1_S   &&
-                op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
-        case GGML_OP_MUL_MAT:
-            //return op->src[1]->type == GGML_TYPE_F32; // TMP: workaround until sync with latest ggml
-            return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_get_type_traits_cpu(op->src[0]->type)->vec_dot_type;
-        case GGML_OP_ROPE_BACK:
-            return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
-        case GGML_OP_IM2COL_BACK:
-            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
-        case GGML_OP_OUT_PROD:
-            return (op->src[0]->type == GGML_TYPE_F32 || ggml_is_quantized(op->src[0]->type)) && op->src[1]->type == GGML_TYPE_F32;
-        default:
-            return true;
-    }
-
-    GGML_UNUSED(dev);
-}
-
-static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    return ggml_backend_buft_is_host(buft);
-
-    GGML_UNUSED(dev);
-}
-
-static const struct ggml_backend_device_i ggml_backend_cpu_device_i = {
-    /* .get_name             = */ ggml_backend_cpu_device_get_name,
-    /* .get_description      = */ ggml_backend_cpu_device_get_description,
-    /* .get_memory           = */ ggml_backend_cpu_device_get_memory,
-    /* .get_type             = */ ggml_backend_cpu_device_get_type,
-    /* .get_props            = */ ggml_backend_cpu_device_get_props,
-    /* .init_backend         = */ ggml_backend_cpu_device_init_backend,
-    /* .get_buffer_type      = */ ggml_backend_cpu_device_get_buffer_type,
-    /* .get_host_buffer_type = */ NULL,
-    /* .buffer_from_host_ptr = */ ggml_backend_cpu_device_buffer_from_host_ptr,
-    /* .supports_op          = */ ggml_backend_cpu_device_supports_op,
-    /* .supports_buft        = */ ggml_backend_cpu_device_supports_buft,
-    /* .offload_op           = */ NULL,
-    /* .event_new            = */ NULL,
-    /* .event_free           = */ NULL,
-    /* .event_synchronize    = */ NULL,
-};
-
-// CPU backend - backend (reg)
-
-static const char * ggml_backend_cpu_reg_get_name(ggml_backend_reg_t reg) {
-    return "CPU";
-
-    GGML_UNUSED(reg);
-}
-
-static size_t ggml_backend_cpu_reg_get_device_count(ggml_backend_reg_t reg) {
-    return 1;
-
-    GGML_UNUSED(reg);
-}
-
-static ggml_backend_dev_t ggml_backend_cpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
-    GGML_ASSERT(index == 0);
-
-    static ggml_backend_cpu_device_context ctx;
-    static ggml_backend_device ggml_backend_cpu_device = {
-        /* .iface   = */ ggml_backend_cpu_device_i,
-        /* .reg     = */ reg,
-        /* .context = */ &ctx,
-    };
-
-    return &ggml_backend_cpu_device;
-}
-
-static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) {
-    if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
-        return (void *)ggml_backend_cpu_set_n_threads;
-    }
-    if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) {
-        return (void *)ggml_backend_cpu_get_extra_bufts;
-    }
-
-    return NULL;
-
-    GGML_UNUSED(reg);
-}
-
-static const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = {
-    /* .get_name         = */ ggml_backend_cpu_reg_get_name,
-    /* .get_device_count = */ ggml_backend_cpu_reg_get_device_count,
-    /* .get_device       = */ ggml_backend_cpu_reg_get_device,
-    /* .get_proc_address = */ ggml_backend_cpu_get_proc_address,
-};
-
-ggml_backend_reg_t ggml_backend_cpu_reg(void) {
-    static struct ggml_backend_reg ggml_backend_cpu_reg = {
-        /* .iface   = */ ggml_backend_cpu_reg_i,
-        /* .context = */ NULL,
-    };
-
-    return &ggml_backend_cpu_reg;
-}
diff --git a/ggml/src/ggml-blas.cpp b/ggml/src/ggml-blas.cpp
deleted file mode 100644 (file)
index 8d96220..0000000
+++ /dev/null
@@ -1,514 +0,0 @@
-#include "ggml-impl.h"
-#include "ggml-blas.h"
-#include "ggml-backend-impl.h"
-
-#include <future>
-#include <vector>
-#include <cstring>
-
-#if defined(GGML_USE_ACCELERATE)
-#   include <Accelerate/Accelerate.h>
-#elif defined(GGML_BLAS_USE_MKL)
-#   include <mkl.h>
-#elif defined(GGML_BLAS_USE_BLIS)
-#   include <blis.h>
-#elif defined(GGML_BLAS_USE_NVPL)
-#   include <nvpl_blas.h>
-#else
-#   include <cblas.h>
-#endif
-
-struct ggml_backend_blas_context {
-    int n_threads = GGML_DEFAULT_N_THREADS;
-    std::unique_ptr<char[]> work_data;
-    size_t work_size = 0;
-#ifndef GGML_USE_OPENMP
-    std::vector<std::future<void>> tasks;
-#endif
-};
-
-static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const enum ggml_type type = src0->type;
-
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne12);
-    GGML_ASSERT(ne3 == ne13);
-
-    // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == ggml_type_size(type));
-    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
-    const int64_t ne_plane      = ne01*ne00;
-    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
-
-    if (ctx->work_size < desired_wsize) {
-        ctx->work_data.reset(new char[desired_wsize]);
-        ctx->work_size = desired_wsize;
-    }
-    void * wdata = ctx->work_data.get();
-
-    // convert src0 to float
-    if (type != GGML_TYPE_F32) {
-        const auto * type_traits = ggml_get_type_traits(type);
-        ggml_to_float_t const to_float = type_traits->to_float;
-
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                const void  *       x      = (char *)  src0->data + i02*nb02          + i03*nb03;
-                      float * const wplane = (float *) wdata      + i02*ne_plane      + i03*ne02*ne_plane;
-
-                const int min_cols_per_thread = 4096;
-                const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
-                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);
-
-#ifdef GGML_USE_OPENMP
-                #pragma omp parallel for num_threads(n_threads)
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
-                }
-#else
-                for (int i = 1; i < n_threads; i++) {
-                    const int64_t start =       i*ne01/n_threads;
-                    const int64_t end   = (i + 1)*ne01/n_threads;
-                    if (start < end) {
-                        ctx->tasks.push_back(std::async(std::launch::async, [=]() {
-                            for (int64_t i01 = start; i01 < end; i01++) {
-                                to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
-                            }
-                        }));
-                    }
-                }
-                {
-                    // reuse the current thread for the first task
-                    const int64_t start = 0;
-                    const int64_t end   = ne01/n_threads;
-                    for (int64_t i01 = start; i01 < end; i01++) {
-                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
-                    }
-                }
-#endif
-            }
-        }
-
-#ifndef GGML_USE_OPENMP
-        // wait for all tasks to finish
-        for (auto & task : ctx->tasks) {
-            task.get();
-        }
-        ctx->tasks.clear();
-#endif
-    }
-
-#if defined(OPENBLAS_VERSION)
-    openblas_set_num_threads(ctx->n_threads);
-#endif
-
-#if defined(GGML_BLAS_USE_BLIS)
-    bli_thread_set_num_threads(ctx->n_threads);
-#endif
-
-#if defined(GGML_BLAS_USE_NVPL)
-    nvpl_blas_set_num_threads(ctx->n_threads);
-#endif
-
-    for (int64_t i13 = 0; i13 < ne13; i13++) {
-        for (int64_t i12 = 0; i12 < ne12; i12++) {
-            const int64_t i03 = i13/r3;
-            const int64_t i02 = i12/r2;
-
-            const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
-            const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
-                  float * d = (float *) ((char *)  dst->data + i12*nb2  + i13*nb3);
-
-            if (type != GGML_TYPE_F32) {
-                x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
-            }
-
-            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                        ne1, ne01, ne10,
-                        1.0f,   y, ne10,
-                                x, ne00,
-                        0.0f,   d, ne01);
-        }
-    }
-}
-
-static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(ne0  == ne00);
-    GGML_ASSERT(ne1  == ne10);
-    GGML_ASSERT(ne2  == ne02);
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne3  == ne13);
-    GGML_ASSERT(ne03 == ne13);
-
-    // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    // GGML_ASSERT(nb0 <= nb1);
-    // GGML_ASSERT(nb1 <= nb2);
-    // GGML_ASSERT(nb2 <= nb3);
-
-    // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
-    // src0: (k,n)
-    // src1: (k,m)
-    // dst:  (m,n)
-    //
-    // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
-    // Also expressed as (major,minor)
-    // a: (m,k): so src1 transposed
-    // b: (k,n): so src0
-    // c: (m,n)
-    //
-    // However, if ggml_is_transposed(src1) is true, then
-    // src1->data already contains a transposed version, so sgemm mustn't
-    // transpose it further.
-
-    int n = src0->ne[0];
-    int k = src0->ne[1];
-    int m = src1->ne[0];
-
-    CBLAS_TRANSPOSE transposeA;
-    int lda;
-
-    if (!ggml_is_transposed(src1)) {
-        transposeA = CblasTrans;
-        lda = m;
-    } else {
-        transposeA = CblasNoTrans;
-        lda = k;
-    }
-
-    float * a = (float *) ((char *) src1->data);
-    float * b = (float *) ((char *) src0->data);
-    float * c = (float *) ((char *) dst->data);
-
-    cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
-
-    GGML_UNUSED(ctx);
-}
-
-// backend interface
-
-static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
-    return "BLAS";
-
-    GGML_UNUSED(backend);
-}
-
-static void ggml_backend_blas_free(ggml_backend_t backend) {
-    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
-    delete ctx;
-    delete backend;
-}
-
-static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
-
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        struct ggml_tensor * node = cgraph->nodes[i];
-
-        switch (node->op) {
-            case GGML_OP_MUL_MAT:
-                ggml_backend_blas_mul_mat(ctx, node);
-                break;
-
-            case GGML_OP_OUT_PROD:
-                ggml_backend_blas_out_prod(ctx, node);
-                break;
-
-            case GGML_OP_NONE:
-            case GGML_OP_RESHAPE:
-            case GGML_OP_VIEW:
-            case GGML_OP_PERMUTE:
-            case GGML_OP_TRANSPOSE:
-                break;
-
-            default:
-                GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node));
-        }
-    }
-
-    return GGML_STATUS_SUCCESS;
-
-    GGML_UNUSED(backend);
-}
-
-static struct ggml_backend_i blas_backend_i = {
-    /* .get_name                = */ ggml_backend_blas_get_name,
-    /* .free                    = */ ggml_backend_blas_free,
-    /* .set_tensor_async        = */ NULL,
-    /* .get_tensor_async        = */ NULL,
-    /* .cpy_tensor_async        = */ NULL,
-    /* .synchronize             = */ NULL,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_blas_graph_compute,
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_blas_guid(void) {
-    static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
-    return &guid;
-}
-
-ggml_backend_t ggml_backend_blas_init(void) {
-    ggml_backend_blas_context * ctx = new ggml_backend_blas_context;
-
-    ggml_backend_t backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_blas_guid(),
-        /* .interface = */ blas_backend_i,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
-        /* .context   = */ ctx,
-    };
-
-#if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
-    if (openblas_get_parallel() != OPENBLAS_OPENMP) {
-        GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
-    }
-#endif
-
-#if defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
-    GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
-#endif
-
-    return backend;
-}
-
-bool ggml_backend_is_blas(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());
-}
-
-void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) {
-    GGML_ASSERT(ggml_backend_is_blas(backend_blas));
-
-    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
-    ctx->n_threads = n_threads;
-}
-
-// device interface
-
-static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
-    return "BLAS";
-
-    GGML_UNUSED(dev);
-}
-
-static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {
-    #if defined(GGML_USE_ACCELERATE)
-        return "Accelerate";
-    #elif defined(GGML_BLAS_USE_MKL)
-        return "MKL";
-    #elif defined(GGML_BLAS_USE_BLIS)
-        return "BLIS";
-    #elif defined(GGML_BLAS_USE_NVPL)
-        return "NVPL";
-    #elif defined(OPENBLAS_VERSION)
-        return "OpenBLAS";
-    #else
-        return "BLAS";
-    #endif
-
-    GGML_UNUSED(dev);
-}
-
-static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
-    // TODO
-    *free = 0;
-    *total = 0;
-
-    GGML_UNUSED(dev);
-}
-
-static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
-    return GGML_BACKEND_DEVICE_TYPE_ACCEL;
-
-    GGML_UNUSED(dev);
-}
-
-static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
-    props->name        = ggml_backend_blas_device_get_name(dev);
-    props->description = ggml_backend_blas_device_get_description(dev);
-    props->type        = ggml_backend_blas_device_get_type(dev);
-    ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);
-    props->caps = {
-        /* .async                 = */ false,
-        /* .host_buffer           = */ false,
-        /* .buffer_from_host_ptr  = */ true,
-        /* .events                = */ false,
-    };
-}
-
-static ggml_backend_t ggml_backend_blas_device_init_backend(ggml_backend_dev_t dev, const char * params) {
-    return ggml_backend_blas_init();
-
-    GGML_UNUSED(dev);
-    GGML_UNUSED(params);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {
-    return ggml_backend_cpu_buffer_type();
-
-    GGML_UNUSED(dev);
-}
-
-static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
-    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
-
-    GGML_UNUSED(dev);
-    GGML_UNUSED(max_tensor_size);
-}
-
-static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-    const struct ggml_tensor * src0 = op->src[0];
-    const struct ggml_tensor * src1 = op->src[1];
-
-    switch (op->op) {
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_TRANSPOSE:
-            return true;
-
-        case GGML_OP_MUL_MAT:
-        {
-            // BLAS usually is only faster for large matrices
-            const struct ggml_tensor * src0 = op->src[0];
-            const struct ggml_tensor * src1 = op->src[1];
-
-            const int64_t ne10 = src1->ne[0];
-
-            const int64_t ne0 = op->ne[0];
-            const int64_t ne1 = op->ne[1];
-
-            // TODO: find the optimal value
-            const int64_t min_batch = 32;
-
-            return ggml_is_contiguous(src0) &&
-                   ggml_is_contiguous(src1) &&
-                   src1->type == GGML_TYPE_F32 &&
-                   (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) &&
-                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
-        }
-
-        case GGML_OP_OUT_PROD:
-            return op->src[0]->type == GGML_TYPE_F32 &&
-                   op->src[1]->type == GGML_TYPE_F32 &&
-                   ggml_is_matrix(src0) &&
-                   ggml_is_matrix(src1) &&
-                   ggml_is_contiguous(src0) &&
-                   (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) &&
-                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
-
-        default:
-            return false;
-
-    }
-
-    GGML_UNUSED(dev);
-}
-
-static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    return ggml_backend_buft_is_host(buft);
-
-    GGML_UNUSED(dev);
-}
-
-static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
-    /* .get_name             = */ ggml_backend_blas_device_get_name,
-    /* .get_description      = */ ggml_backend_blas_device_get_description,
-    /* .get_memory           = */ ggml_backend_blas_device_get_memory,
-    /* .get_type             = */ ggml_backend_blas_device_get_type,
-    /* .get_props            = */ ggml_backend_blas_device_get_props,
-    /* .init_backend         = */ ggml_backend_blas_device_init_backend,
-    /* .get_buffer_type      = */ ggml_backend_blas_device_get_buffer_type,
-    /* .get_host_buffer_type = */ NULL,
-    /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_host_ptr,
-    /* .supports_op          = */ ggml_backend_blas_device_supports_op,
-    /* .supports_buft        = */ ggml_backend_blas_device_supports_buft,
-    /* .offload_op           = */ NULL,
-    /* .event_new            = */ NULL,
-    /* .event_free           = */ NULL,
-    /* .event_synchronize    = */ NULL,
-};
-
-// backend reg interface
-
-static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {
-    return "BLAS";
-
-    GGML_UNUSED(reg);
-}
-
-static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {
-    return 1;
-
-    GGML_UNUSED(reg);
-}
-
-static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {
-    GGML_ASSERT(index == 0);
-
-    static ggml_backend_device ggml_backend_blas_device = {
-        /* .iface   = */ ggml_backend_blas_device_i,
-        /* .reg     = */ reg,
-        /* .context = */ nullptr,
-    };
-
-    return &ggml_backend_blas_device;
-
-    GGML_UNUSED(reg);
-    GGML_UNUSED(index);
-}
-
-static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {
-    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
-        return (void *)ggml_backend_blas_set_n_threads;
-    }
-    return NULL;
-
-    GGML_UNUSED(reg);
-    GGML_UNUSED(name);
-}
-
-static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
-    /* .get_name         = */ ggml_backend_blas_reg_get_name,
-    /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,
-    /* .get_device       = */ ggml_backend_blas_reg_get_device,
-    /* .get_proc_address = */ ggml_backend_blas_get_proc_address,
-};
-
-ggml_backend_reg_t ggml_backend_blas_reg(void) {
-    static struct ggml_backend_reg ggml_backend_blas_reg = {
-        /* .iface   = */ ggml_backend_blas_reg_i,
-        /* .context = */ NULL,
-    };
-
-    return &ggml_backend_blas_reg;
-}
diff --git a/ggml/src/ggml-cann.cpp b/ggml/src/ggml-cann.cpp
deleted file mode 100644 (file)
index 7763408..0000000
+++ /dev/null
@@ -1,2128 +0,0 @@
-/*
- * Copyright (c) 2023-2024 The ggml authors
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
- * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
- * IN THE SOFTWARE.
- */
-
-#include "ggml-cann.h"
-
-#include <acl/acl.h>
-#include <stdarg.h>
-
-#include <cmath>
-#include <cstdio>
-#include <cstring>
-#include <mutex>
-
-#include "ggml-impl.h"
-#include "ggml-backend-impl.h"
-#include "ggml-cann/aclnn_ops.h"
-#include "ggml-cann/common.h"
-
-#define GGML_COMMON_DECL_C
-
-#include "ggml-common.h"
-
-#define GGML_CANN_NAME "CANN"
-
-/**
- * @brief Handles CANN errors by printing an error message and aborting.
- *
- * @param stmt The statement that caused the error.
- * @param func The function in which the error occurred.
- * @param file The file in which the error occurred.
- * @param line The line number where the error occurred.
- * @param msg The error message.
- */
-[[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
-                                  const char* file, int line, const char* msg) {
-    int32_t id = -1;
-    aclrtGetDevice(&id);
-
-    GGML_LOG_ERROR("CANN error: %s\n", msg);
-    GGML_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", id, func,
-            file, line);
-    GGML_LOG_ERROR("  %s\n", stmt);
-    // abort with GGML_ASSERT to get a stack trace
-    GGML_ABORT("CANN error");
-}
-
-/**
- * @brief Sets the device to be used by CANN.
- *
- * @param device The device ID to set.
- */
-void ggml_cann_set_device(const int32_t device) {
-    // TODO: uncomment these lines after empty context has fixed.
-    // int current_device;
-    // ACL_CHECK(aclrtGetDevice(&current_device));
-
-    // if (device == current_device) {
-    //   return;
-    // }
-    ACL_CHECK(aclrtSetDevice(device));
-}
-
-/**
- * @brief Retrieves the current device ID.
- *
- * @return The current device ID.
- */
-int32_t ggml_cann_get_device() {
-    int32_t id;
-    ACL_CHECK(aclrtGetDevice(&id));
-    return id;
-}
-
-/**
- * @brief Initialize the CANN device information.
- *
- * This function initializes the CANN device information by obtaining the
- * device count and setting the memory allocation granularity for each device.
- *
- * @return A structure containing the device information.
- */
-static ggml_cann_device_info ggml_cann_init() {
-    ggml_cann_device_info info = {};
-
-    aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
-
-    if (err != ACL_SUCCESS) {
-        GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n",
-                __func__, aclGetRecentErrMsg());
-        return info;
-    }
-
-    GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);
-
-    for (int id = 0; id < info.device_count; ++id) {
-        aclrtPhysicalMemProp prop = {};
-        prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
-        prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
-        prop.memAttr = ACL_HBM_MEM_HUGE;
-        prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
-        prop.location.id = id;
-        prop.reserve = 0;
-        ACL_CHECK(aclrtMemGetAllocationGranularity(
-            &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
-            &info.devices[id].vmm_granularity));
-    }
-
-    // TODO: add more device info later.
-    return info;
-}
-
-/**
- * @brief Retrieve the CANN device information.
- *
- * This function returns a reference to a structure containing the CANN device
- * information. The device information is initialized once and reused on
- * subsequent calls.
- *
- * @return A reference to the structure containing the device information.
- */
-const ggml_cann_device_info& ggml_cann_info() {
-    static ggml_cann_device_info info = ggml_cann_init();
-    return info;
-}
-
-//#define DEBUG_CANN_MALLOC
-/**
- * @brief A pool of CANN buffers(legacy).
- *
- * This class manages a pool of CANN buffers for a specific device.
- */
-struct ggml_cann_pool_leg : public ggml_cann_pool {
-    /**
-     * @brief The maximum number of buffers in the pool.
-     */
-    static const int MAX_BUFFERS = 256;
-
-    /**
-     * @brief The device ID associated with this buffer pool.
-     */
-    int device;
-
-    /**
-     * @brief Structure representing a CANN buffer.
-     */
-    struct ggml_cann_buffer {
-        void* ptr = nullptr;  ///< Pointer to the buffer memory.
-        size_t size = 0;      ///< Size of the buffer.
-    };
-
-    /**
-     * @brief Array of CANN buffers in the pool.
-     */
-    ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};
-
-    /**
-     * @brief Total size of all buffers in the pool.
-     */
-    size_t pool_size = 0;
-
-    /**
-     * @brief Constructor to initialize the buffer pool for a specific device.
-     *
-     * @param device The device ID to associate with this buffer pool.
-     */
-    explicit ggml_cann_pool_leg(int device) : device(device) {}
-
-    /**
-     * @brief Destructor to free all buffers in the pool.
-     */
-    ~ggml_cann_pool_leg() {
-        ggml_cann_set_device(device);
-        for (int i = 0; i < MAX_BUFFERS; ++i) {
-            ggml_cann_buffer& b = buffer_pool[i];
-            if (b.ptr != nullptr) {
-                ACL_CHECK(aclrtFree(b.ptr));
-                pool_size -= b.size;
-            }
-        }
-        GGML_ASSERT(pool_size == 0);
-    }
-
-    /**
-     * @brief Allocate a buffer of the given size.
-     *
-     * @param size The size of the buffer to allocate.
-     * @param actual_size A pointer to a variable to receive the actual size of
-     * the allocated buffer.
-     * @return A pointer to the allocated buffer.
-     */
-    void* alloc(size_t size, size_t* actual_size) override {
-#ifdef DEBUG_CANN_MALLOC
-        int nnz = 0;
-        size_t max_size = 0;
-#endif
-        size_t best_diff = 1ull << 36;
-        int ibest = -1;
-        for (int i = 0; i < MAX_BUFFERS; ++i) {
-            ggml_cann_buffer& b = buffer_pool[i];
-            if (b.ptr != nullptr) {
-#ifdef DEBUG_CANN_MALLOC
-                ++nnz;
-                if (b.size > max_size) max_size = b.size;
-#endif
-                if (b.size >= size) {
-                    size_t diff = b.size - size;
-                    if (diff < best_diff) {
-                        best_diff = diff;
-                        ibest = i;
-                        if (!best_diff) {
-                            void* ptr = b.ptr;
-                            *actual_size = b.size;
-                            b.ptr = nullptr;
-                            b.size = 0;
-                            return ptr;
-                        }
-                    }
-                }
-            }
-        }
-        if (ibest >= 0) {
-            ggml_cann_buffer& b = buffer_pool[ibest];
-            void* ptr = b.ptr;
-            *actual_size = b.size;
-            b.ptr = nullptr;
-            b.size = 0;
-            return ptr;
-        }
-        void* ptr;
-        size_t look_ahead_size = (size_t)(1.05 * size);
-        look_ahead_size = 256 * ((look_ahead_size + 255) / 256);
-        ggml_cann_set_device(device);
-        ACL_CHECK(
-            aclrtMalloc(&ptr, look_ahead_size, ACL_MEM_MALLOC_HUGE_FIRST));
-        *actual_size = look_ahead_size;
-        pool_size += look_ahead_size;
-#ifdef DEBUG_CANN_MALLOC
-        GGML_LOG_INFO(
-            "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
-            "requested %u MB\n",
-            __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024),
-            (uint32_t)(pool_size / 1024 / 1024),
-            (uint32_t)(size / 1024 / 1024));
-#endif
-        return ptr;
-    }
-
-    /**
-     * @brief Free a buffer and return it to the pool.
-     *
-     * @param ptr Pointer to the buffer to free.
-     * @param size Size of the buffer to free.
-     */
-    void free(void* ptr, size_t size) override {
-        for (int i = 0; i < MAX_BUFFERS; ++i) {
-            ggml_cann_buffer& b = buffer_pool[i];
-            if (b.ptr == nullptr) {
-                b.ptr = ptr;
-                b.size = size;
-                return;
-            }
-        }
-        // memory should always buffered. these memory may still needed by
-        // tasks in stream.
-        // TODO, fix me.
-        GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n");
-    }
-};
-
-/**
- * @brief A pool of CANN buffers with virtual memory.
- *
- * This class manages a pool of CANN buffers with virtual memory for a specific
- * device.
- */
-struct ggml_cann_pool_vmm : public ggml_cann_pool {
-    /**
-     * @brief The maximum size of the virtual memory pool (32 GB).
-     */
-    static const size_t CANN_POOL_VMM_MAX_SIZE = 1ull << 35;  // 32 GB
-
-    /**
-     * @brief The device ID associated with this buffer pool.
-     */
-    int device;
-
-    /**
-     * @brief Pointer to the start of the virtual memory pool.
-     */
-    void* pool_addr = 0;
-
-    /**
-     * @brief Amount of virtual memory used in the pool.
-     */
-    size_t pool_used = 0;
-
-    /**
-     * @brief Total size of the virtual memory pool.
-     */
-    size_t pool_size = 0;
-
-    /**
-     * @brief Allocation granularity for the virtual memory pool.
-     */
-    size_t granularity;
-
-    /**
-     * @brief Handles for the physical memory allocated.
-     */
-    std::vector<aclrtDrvMemHandle> handles;
-
-    /**
-     * @brief Offsets for the mapped memory regions.
-     */
-    std::vector<void*> map_offsets;
-
-    /**
-     * @brief Constructor to initialize the buffer pool with virtual memory for
-     * a specific device.
-     *
-     * @param device The device ID to associate with this buffer pool.
-     */
-    explicit ggml_cann_pool_vmm(int device)
-        : device(device),
-          granularity(ggml_cann_info().devices[device].vmm_granularity) {}
-
-    /**
-     * @brief Destructor to free all buffers in the virtual memory pool.
-     */
-    ~ggml_cann_pool_vmm() {
-        if (pool_addr != 0) {
-            for (auto& offset : map_offsets) {
-                ACL_CHECK(aclrtUnmapMem(offset));
-            }
-            for (auto& handle : handles) {
-                ACL_CHECK(aclrtFreePhysical(handle));
-            }
-            ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
-        }
-    }
-
-    /**
-     * @brief Allocate a buffer of the given size in the virtual memory pool.
-     *
-     * @param size The size of the buffer to allocate.
-     * @param actual_size A pointer to a variable to receive the actual size of
-     * the allocated buffer.
-     * @return A pointer to the allocated buffer.
-     */
-    void* alloc(size_t size, size_t* actual_size) override {
-        // round up the allocation size to the alignment to ensure that all
-        // allocations are aligned for all data types
-        const size_t alignment = 128;
-        size = alignment * ((size + alignment - 1) / alignment);
-
-        size_t avail = pool_size - pool_used;
-
-        if (size > avail) {
-            // round up to the next multiple of the granularity
-            size_t reserve_size = size - avail;
-            reserve_size =
-                granularity * ((reserve_size + granularity - 1) / granularity);
-
-            GGML_ASSERT(pool_size + reserve_size <= CANN_POOL_VMM_MAX_SIZE);
-
-            // allocate more physical memory
-            aclrtPhysicalMemProp prop = {};
-            prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
-            prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
-            prop.memAttr = ACL_HBM_MEM_HUGE;
-            prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
-            prop.location.id = device;
-            prop.reserve = 0;
-            aclrtDrvMemHandle handle;
-            ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));
-
-            // reserve virtual address space (if not already reserved)
-            if (pool_addr == 0) {
-                ACL_CHECK(aclrtReserveMemAddress(
-                    &pool_addr, CANN_POOL_VMM_MAX_SIZE, 0, NULL, 1));
-            }
-
-            // map at the end of the pool
-            ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0,
-                                  handle, 0));
-
-            handles.push_back(handle);
-            map_offsets.push_back((char*)pool_addr + pool_size);
-
-            // add to the pool
-            pool_size += reserve_size;
-
-            // GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (
-            // reserved %llu MB)\n",
-            //       device, (unsigned long long) (pool_size/1024/1024),
-            //       (unsigned long long) (reserve_size/1024/1024));
-        }
-
-        GGML_ASSERT(pool_addr != 0);
-
-        void* ptr = (void*)((char*)pool_addr + pool_used);
-        *actual_size = size;
-        pool_used += size;
-
-#ifdef DEBUG_CANN_MALLOC
-        GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device,
-               (unsigned long long)size, (unsigned long long)ptr);
-#endif
-        return ptr;
-    }
-
-    /**
-     * @brief Free a buffer and return it to the virtual memory pool.
-     *
-     * @param ptr Pointer to the buffer to free.
-     * @param size Size of the buffer to free.
-     */
-    void free(void* ptr, size_t size) override {
-#ifdef DEBUG_CANN_MALLOC
-        GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device,
-               (unsigned long long)size, (unsigned long long)ptr);
-#endif
-
-        pool_used -= size;
-
-        // all deallocations must be in reverse order of the allocations
-        GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used));
-    }
-};
-
-/**
- * @brief Create a new CANN pool for a specific device.
- *
- * Factory method to create a new CANN pool object based on the device type.
- *
- * @param device The device ID for which to create the pool.
- * @return A unique pointer to the created CANN pool.
- */
-std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
-    int device) {
-    // return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
-    return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
-}
-
-// cann buffer
-/**
- * @brief Context for managing a CANN buffer associated with a specific device.
- *
- * This structure holds information about a CANN buffer, including the device
- * ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID.
- */
-struct ggml_backend_cann_buffer_context {
-    int32_t device;  ///< The device ID associated with this buffer context.
-    void* dev_ptr =
-        nullptr;  ///< Pointer to the device memory allocated for the buffer.
-
-    /**
-     * @brief Constructor to initialize the CANN buffer context.
-     *
-     * @param device The device ID associated with this buffer context.
-     * @param dev_ptr Pointer to the device memory allocated for the buffer.
-     */
-    ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
-        : device(device),
-          dev_ptr(dev_ptr) {}
-
-    /**
-     * @brief Destructor to free the device memory allocated for the buffer.
-     */
-    ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
-};
-
-/**
- * @brief Check if a buffer is a CANN buffer.
- *
- * This function checks if a given buffer is a CANN buffer by comparing its
- * `get_name` function pointer to `ggml_backend_cann_buffer_get_name`.
- *
- * @param buffer The buffer to check.
- * @return true if the buffer is a CANN buffer, false otherwise.
- */
-static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
-static bool ggml_backend_buffer_is_cann(
-    ggml_backend_buffer_t buffer) {
-    return ggml_backend_buft_is_cann(buffer->buft);
-}
-
-/**
- * @brief Free resources associated with a CANN buffer.
- *
- * This function frees the resources associated with a CANN buffer, including
- * its context.
- *
- * @param buffer The CANN buffer to free.
- */
-static void ggml_backend_cann_buffer_free_buffer(
-    ggml_backend_buffer_t buffer) {
-    ggml_backend_cann_buffer_context* ctx =
-        (ggml_backend_cann_buffer_context*)buffer->context;
-    delete ctx;
-}
-
-/**
- * @brief Retrieve the base pointer of a CANN buffer.
- *
- * This function returns the base pointer of a CANN buffer, which points to the
- * device memory allocated for the buffer.
- *
- * @param buffer The CANN buffer whose base pointer is to be retrieved.
- * @return A pointer to the base of the device memory allocated for the buffer.
- */
-static void* ggml_backend_cann_buffer_get_base(
-    ggml_backend_buffer_t buffer) {
-    ggml_backend_cann_buffer_context* ctx =
-        (ggml_backend_cann_buffer_context*)buffer->context;
-    return ctx->dev_ptr;
-}
-
-/**
- * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN
- * processing.
- *
- * This function transforms quantized Q4.0 tensor data into a format suitable
- * for CANN processing. It extracts quantization values and scales from the
- * source data and prepares them in a format expected by CANN operations.
- *
- * @param tensor Pointer to the tensor information.
- * @param src Pointer to the source data in Q4.0 format.
- * @param dst Pointer to the destination buffer where transformed data will be
- * stored.
- */
-static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
-                                             const void* src,
-                                             void* dst) {
-
-    int64_t n_elems = ggml_nelements(tensor);
-    int64_t groups = n_elems / QK4_0;
-    size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
-
-    uint8_t* quant_offset = (uint8_t*)dst;
-    uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
-
-    for (int i = 0; i < groups; i++) {
-        const block_q4_0* group =
-            (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
-        *scale_offset = group->d;
-        scale_offset++;
-
-        // 0-15
-        for (int j = 0; j < QK4_0 / 2; j += 2) {
-            (*quant_offset) = (group->qs[j] & 0x0F);
-            (*quant_offset) |= ((group->qs[j + 1] << 4));
-            quant_offset++;
-        }
-
-        // 16-31
-        for (int j = 0; j < QK4_0 / 2; j += 2) {
-            (*quant_offset) = (group->qs[j] >> 4);
-            (*quant_offset) |= (group->qs[j + 1] & 0xF0);
-            quant_offset++;
-        }
-    }
-
-    // put (uint4b_t -8) into int4b_t
-    for (quant_offset = (uint8_t*)dst;
-         quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) {
-        (*quant_offset) ^= 0x88;
-    }
-}
-
-/**
- * @brief Transform CANN processed data back into quantized Q4.0 format.
- *
- * This function transforms CANN processed data back into quantized Q4.0 format.
- * It reverses the transformation performed by
- * ggml_backend_cann_transform_q4_0(), converting the data back into its
- * original quantized form.
- *
- * @param tensor Pointer to the tensor information.
- * @param src Pointer to the source buffer containing transformed data.
- * @param dst Pointer to the destination buffer where the Q4.0 formatted data
- * will be stored.
- */
-static void ggml_backend_cann_transform_back_q4_0(
-    const ggml_tensor* tensor, void* src, void* dst) {
-
-    int64_t n_elems = ggml_nelements(tensor);
-    int64_t groups = n_elems / QK4_0;
-    size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
-
-    uint8_t* quant_offset = (uint8_t*)src;
-    uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);
-
-    for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) {
-        (*quant_offset) ^= 0x88;
-    }
-    quant_offset = (uint8_t*)src;
-
-    for (int i = 0; i < groups; i++) {
-        block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0));
-        group->d = *scale_offset;
-        scale_offset++;
-
-        // 0-15
-        for (int j = 0; j < QK4_0 / 2; j += 2) {
-            group->qs[j] = ((*quant_offset) & 0x0F);
-            group->qs[j + 1] = ((*quant_offset) >> 4);
-            quant_offset++;
-        }
-
-        // 16-31
-        for (int j = 0; j < QK4_0 / 2; j += 2) {
-            group->qs[j] |= ((*quant_offset) << 4);
-            group->qs[j + 1] |= ((*quant_offset) & 0xF0);
-            quant_offset++;
-        }
-    }
-}
-
-/**
- * @brief Transform quantized Q8.0 tensor data into a format suitable for CANN
- * processing.
- *
- * This function transforms quantized Q8.0 tensor data into a format suitable
- * for CANN processing. It extracts quantization values and scales from the
- * source data and prepares them in a format expected by CANN operations.
- *
- * @param tensor Pointer to the tensor information.
- * @param src Pointer to the source data in Q8.0 format.
- * @param dst Pointer to the destination buffer where transformed data will be
- * stored.
- */
-static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
-                                             const void* src,
-                                             void* dst) {
-    int64_t n_elems = ggml_nelements(tensor);
-    int64_t groups = n_elems / QK8_0;
-    size_t quant_bytes = n_elems * sizeof(uint8_t);
-
-    uint8_t* quant_offset = (uint8_t*)dst;
-    uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
-
-    for (int i = 0; i < groups; i++) {
-        const block_q8_0* group =
-            (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
-        *scale_offset = group->d;
-        scale_offset++;
-        size_t group_quant_size = QK8_0 * sizeof(uint8_t);
-        memcpy(quant_offset, group->qs, group_quant_size);
-        quant_offset += group_quant_size;
-    }
-}
-
-/**
- * @brief Transform CANN processed data back into quantized Q8.0 format.
- *
- * This function transforms CANN processed data back into quantized Q8.0 format.
- * It reverses the transformation performed by
- * ggml_backend_cann_transform_q8_0(), converting the data back into its
- * original quantized form.
- *
- * @param tensor Pointer to the tensor information.
- * @param src Pointer to the source buffer containing transformed data.
- * @param dst Pointer to the destination buffer where the Q8.0 formatted data
- * will be stored.
- */
-static void ggml_backend_cann_transform_back_q8_0(
-    const ggml_tensor* tensor, const void* src, void* dst) {
-    int64_t n_elems = ggml_nelements(tensor);
-    int64_t groups = n_elems / QK8_0;
-    size_t quant_bytes = n_elems * sizeof(uint8_t);
-
-    const uint8_t* quant_offset = (const uint8_t*)src;
-    const uint16_t* scale_offset =
-        (const uint16_t*)((const char*)src + quant_bytes);
-
-    for (int i = 0; i < groups; i++) {
-        block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
-        group->d = *scale_offset;
-        scale_offset++;
-        size_t group_quant_size = QK8_0 * sizeof(uint8_t);
-        memcpy(group->qs, quant_offset, group_quant_size);
-        quant_offset += group_quant_size;
-    }
-}
-
-/**
- * @brief Transform tensor data based on its type for CANN processing.
- *
- * This function transforms tensor data based on its quantization type for CANN
- * processing. It dispatches the transformation based on the tensor's type to
- * specialized functions handling Q4.0 and Q8.0 formats.
- *
- * @param tensor Pointer to the tensor information.
- * @param src Pointer to the source data to be transformed.
- * @param dst Pointer to the destination buffer where transformed data will be
- * stored.
- */
-static void ggml_backend_cann_transform(ggml_tensor* tensor,
-                                        const void* src, void* dst) {
-    switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            ggml_backend_cann_transform_q4_0(tensor, src, dst);
-            break;
-        case GGML_TYPE_Q8_0:
-            ggml_backend_cann_transform_q8_0(tensor, src, dst);
-            break;
-        default:
-            break;
-    }
-}
-
-/**
- * @brief Transform CANN processed data back into tensor data based on its type.
- *
- * This function transforms CANN processed data back into tensor data based on
- * its quantization type for Q4.0 and Q8.0 formats. It dispatches the
- * transformation based on the tensor's type to specialized functions.
- *
- * @param tensor Pointer to the tensor information.
- * @param src Pointer to the source data containing CANN processed data.
- * @param dst Pointer to the destination buffer where transformed tensor data
- * will be stored.
- */
-static void ggml_backend_cann_transform_back(
-    const ggml_tensor* tensor, void* src, void* dst) {
-    switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
-            break;
-        case GGML_TYPE_Q8_0:
-            ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
-            break;
-        default:
-            break;
-    }
-}
-
-/**
- * @brief Check if transformation is needed for a given tensor type.
- *
- * This function checks if transformation is needed for a given tensor type
- * to prepare data for CANN processing.
- *
- * @param type The tensor type to check.
- * @return true if transformation is needed, false otherwise.
- */
-static bool need_transform(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q8_0:
-            return true;
-        default:
-            return false;
-    }
-}
-
-/**
- * @brief Initialize a tensor using data from a CANN buffer.
- *
- * This function initializes a tensor using data from a CANN buffer.
- * It handles special cases such as views and quantization.
- *
- * @param buffer The CANN buffer from which to initialize the tensor.
- * @param tensor Pointer to the tensor to be initialized.
- */
-static void ggml_backend_cann_buffer_init_tensor(
-    ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
-    if (tensor->view_src != NULL && tensor->view_offs == 0) {
-        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
-        return;
-    }
-
-    // TODO: can backend doesn't support quantized yet. Just leave the code
-    // here.
-    if (ggml_is_quantized(tensor->type)) {
-        // Initialize padding to 0 to avoid possible NaN values
-        size_t original_size = ggml_nbytes(tensor);
-        size_t padded_size =
-            ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
-
-        if (padded_size > original_size && tensor->view_src == nullptr) {
-            size_t memset_size = padded_size - original_size;
-            ACL_CHECK(aclrtMemset((char*)tensor->data + original_size,
-                                  memset_size, 0, memset_size));
-        }
-    }
-}
-
-// TODO: need handle tensor which has paddings.
-/**
- * @brief Set tensor data in a CANN buffer.
- *
- * This function sets tensor data in a CANN buffer, handling transformations
- * if needed based on the tensor's type.
- *
- * @param buffer The CANN buffer where the tensor data will be set.
- * @param tensor Pointer to the tensor whose data will be set.
- * @param data Pointer to the source data to be copied into the tensor.
- * @param offset Offset in the source data from where to start copying.
- * @param size Size of the data to be copied, in bytes.
- */
-static void ggml_backend_cann_buffer_set_tensor(
-    ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data,
-    size_t offset, size_t size) {
-    ggml_backend_cann_buffer_context *ctx =
-        (ggml_backend_cann_buffer_context *)buffer->context;
-
-    ggml_cann_set_device(ctx->device);
-    // TODO: refer to cann(#6017), it use thread's default stream.
-    // For acl, synchronous functions use this default stream.
-    // Why aclrtSynchronizeDevice?
-
-    if (!need_transform(tensor->type)) {
-        ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
-                              ACL_MEMCPY_HOST_TO_DEVICE));
-    } else {
-        void *transform_buffer = malloc(size);
-        ggml_backend_cann_transform(tensor, data, transform_buffer);
-
-        ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
-                              transform_buffer, size,
-                              ACL_MEMCPY_HOST_TO_DEVICE));
-        free(transform_buffer);
-    }
-}
-
-/**
- * @brief Get tensor data from a CANN buffer.
- *
- * This function retrieves tensor data from a CANN buffer, handling
- * transformations if needed based on the tensor's type.
- *
- * @param buffer The CANN buffer from which to retrieve tensor data.
- * @param tensor Pointer to the tensor whose data will be retrieved.
- * @param data Pointer to the destination buffer where the tensor data will be
- * copied.
- * @param offset Offset in the destination buffer where to start copying.
- * @param size Size of the data to be copied, in bytes.
- */
-static void ggml_backend_cann_buffer_get_tensor(
-    ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
-    size_t offset, size_t size) {
-    ggml_backend_cann_buffer_context* ctx =
-        (ggml_backend_cann_buffer_context*)buffer->context;
-
-    ggml_cann_set_device(ctx->device);
-
-    if (!need_transform(tensor->type)) {
-        ACL_CHECK(aclrtMemcpy(data, size, (char*)tensor->data + offset, size,
-                              ACL_MEMCPY_DEVICE_TO_HOST));
-    } else {
-        void* transform_buffer = malloc(size);
-        ACL_CHECK(aclrtMemcpy(transform_buffer, size,
-                              (char*)tensor->data + offset, size,
-                              ACL_MEMCPY_DEVICE_TO_HOST));
-        ggml_backend_cann_transform_back(tensor, transform_buffer, data);
-        free(transform_buffer);
-    }
-}
-
-/**
- * @brief Copy tensor data between CANN buffers if possible.
- *
- * This function copies tensor data between CANN buffers if the source and
- * destination buffers are CANN buffers and they meet the necessary conditions
- * (same device or devices can access each other).
- *
- * @param buffer The destination CANN buffer where the tensor data will be
- * copied.
- * @param src Pointer to the source tensor whose data will be copied.
- * @param dst Pointer to the destination tensor where the data will be copied.
- * @return true if the copy operation succeeded, false otherwise.
- */
-static bool ggml_backend_cann_buffer_cpy_tensor(
-    ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
-    if (ggml_backend_buffer_is_cann(src->buffer)) {
-        ggml_backend_cann_buffer_context* src_ctx =
-            (ggml_backend_cann_buffer_context*)src->buffer->context;
-        ggml_backend_cann_buffer_context* dst_ctx =
-            (ggml_backend_cann_buffer_context*)buffer->context;
-
-        size_t memcpy_size = ggml_nbytes(src);
-        // Same device.
-        if (src_ctx->device == dst_ctx->device) {
-            ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
-                                  (const char*)src->data, memcpy_size,
-                                  ACL_MEMCPY_DEVICE_TO_DEVICE));
-            return true;
-        } else {
-            // Different device but can access by peer.
-            int32_t canAccessPeer = 0;
-            ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
-                                               dst_ctx->device));
-            if (canAccessPeer) {
-                ggml_cann_set_device(src_ctx->device);
-                ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
-                ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
-                                      (const char*)src->data, memcpy_size,
-                                      ACL_MEMCPY_DEVICE_TO_DEVICE));
-                return true;
-            }
-        }
-    }
-    return false;
-}
-
-/**
- * @brief Clear a CANN buffer by setting all its memory to a specified value.
- *
- * This function clears a CANN buffer by setting all its memory to a specified
- * value.
- *
- * @param buffer The CANN buffer to be cleared.
- * @param value The value to which each byte in the buffer will be set.
- */
-static void ggml_backend_cann_buffer_clear(
-    ggml_backend_buffer_t buffer, uint8_t value) {
-    ggml_backend_cann_buffer_context* ctx =
-        (ggml_backend_cann_buffer_context*)buffer->context;
-
-    ggml_cann_set_device(ctx->device);
-    ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
-}
-
-/**
- * @brief Interface for a CANN buffer in the backend.
- *
- * This structure defines function pointers to operations that can be performed
- * on a CANN buffer within the backend.
- */
-static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
-    /* .free_buffer     = */ ggml_backend_cann_buffer_free_buffer,
-    /* .get_base        = */ ggml_backend_cann_buffer_get_base,
-    /* .init_tensor     = */ ggml_backend_cann_buffer_init_tensor,
-    /* .memset_tensor   = */ NULL,
-    /* .set_tensor      = */ ggml_backend_cann_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_cann_buffer_get_tensor,
-    /* .cpy_tensor      = */ ggml_backend_cann_buffer_cpy_tensor,
-    /* .clear           = */ ggml_backend_cann_buffer_clear,
-    /* .reset           = */ NULL,
-};
-
-// cann buffer type
-/**
- * @brief Structure representing context information for a specific backend
- * buffer type.
- */
-struct ggml_backend_cann_buffer_type_context {
-    int32_t
-        device; /**< Device identifier associated with the buffer context. */
-    std::string name; /**< Name associated with the buffer context. */
-};
-
-/**
- * @brief Retrieves the name associated with a CANN buffer type.
- *
- * This function returns the descriptive name associated with the specified
- * CANN buffer type context.
- *
- * @param buft Pointer to the buffer type context.
- * @return Const pointer to the C-style string containing the name.
- */
-static const char* ggml_backend_cann_buffer_type_name(
-    ggml_backend_buffer_type_t buft) {
-    ggml_backend_cann_buffer_type_context* buft_ctx =
-        (ggml_backend_cann_buffer_type_context*)buft->context;
-
-    return buft_ctx->name.c_str();
-}
-
-/**
- * @brief Allocates a new CANN buffer of the specified type and size.
- *
- * This function allocates a new CANN buffer on the specified device with the
- * given size.
- *
- * @param buft Pointer to the buffer type context.
- * @param size Size in bytes of the buffer to allocate.
- * @return Pointer to the allocated buffer, or nullptr if allocation fails.
- */
-static ggml_backend_buffer_t
-ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
-                                           size_t size) {
-    ggml_backend_cann_buffer_type_context* buft_ctx =
-        (ggml_backend_cann_buffer_type_context*)buft->context;
-
-    ggml_cann_set_device(buft_ctx->device);
-
-    size = std::max(size, (size_t)1);
-
-    void* dev_ptr;
-    aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
-    if (err != ACL_SUCCESS) {
-        GGML_LOG_ERROR(
-            "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
-            __func__, size / 1024.0 / 1024.0, buft_ctx->device,
-            aclGetRecentErrMsg());
-        return nullptr;
-    }
-
-    ggml_backend_cann_buffer_context* ctx =
-        new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
-
-    return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
-                                    ctx, size);
-}
-
-/**
- * @brief Retrieves the memory alignment requirement for CANN buffers of this
- * type.
- *
- * This function returns the alignment requirement in bytes for memory allocated
- * by the CANN buffer type.
- *
- * @param buft Pointer to the buffer type context (unused in this
- * implementation).
- * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
- * buffers).
- */
-static size_t ggml_backend_cann_buffer_type_get_alignment(
-    ggml_backend_buffer_type_t buft) {
-    return 128;
-
-    GGML_UNUSED(buft);
-}
-
-/**
- * @brief Calculates the allocation size required for a tensor in a CANN buffer.
- *
- * Computes the total allocation size needed for storing the tensor's data in a
- * CANN buffer, considering any necessary padding or adjustments for quantized
- * types.
- *
- * @param buft Pointer to the buffer type context (unused in this
- * implementation).
- * @param tensor Pointer to the tensor for which the allocation size is
- * calculated.
- * @return The total allocation size in bytes required for the tensor in the
- * CANN buffer.
- */
-static size_t ggml_backend_cann_buffer_type_get_alloc_size(
-    ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
-    size_t size = ggml_nbytes(tensor);
-    int64_t ne0 = tensor->ne[0];
-
-    // last line must bigger than 32, because every single op deal at
-    // least 32 bytes.
-    // TODO: quantized type?
-    // int64_t line_size = ne0 * ggml_element_size(tensor);
-    // int64_t line_size_align_32 = (line_size + 31) & ~31;
-    // size += (line_size_align_32 - line_size);
-
-    // TODO: not support quantized yet.
-    // TODO: consider un-continue tensor.
-    if (ggml_is_quantized(tensor->type)) {
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += ggml_row_size(
-                tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
-    }
-
-    return size;
-
-    GGML_UNUSED(buft);
-}
-
-static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
-    return false;
-
-    GGML_UNUSED(buft);
-}
-
-/**
- * @brief Interface for managing CANN buffer types in the GGML backend.
- *
- * Provides function pointers for allocating, querying properties, and managing
- * memory for CANN buffer types in the GGML backend.
- */
-static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
-    /* .get_name         = */ ggml_backend_cann_buffer_type_name,
-    /* .alloc_buffer     = */ ggml_backend_cann_buffer_type_alloc_buffer,
-    /* .get_alignment    = */ ggml_backend_cann_buffer_type_get_alignment,
-    /* .get_max_size     = */ NULL,  // defaults to SIZE_MAX
-    /* .get_alloc_size   = */ ggml_backend_cann_buffer_type_get_alloc_size,
-    /* .is_host          = */ ggml_backend_cann_buffer_type_is_host,
-};
-
-/**
- * @brief Retrieves the CANN buffer type for a specified device.
- *
- * This function initializes and returns the buffer type interface associated
- * with the given device. It ensures thread-safe access using a mutex.
- *
- * @param device The device index for which to retrieve the buffer type.
- * @return A pointer to the buffer type interface for the specified device, or
- * nullptr if the device index is out of range.
- */
-ggml_backend_buffer_type_t
-ggml_backend_cann_buffer_type(int32_t device) {
-    static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
-
-    if (device >= ggml_backend_cann_get_device_count()) {
-        return nullptr;
-    }
-
-    static ggml_backend_buffer_type
-        ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
-
-    static bool ggml_backend_cann_buffer_type_initialized = false;
-
-    if (!ggml_backend_cann_buffer_type_initialized) {
-        for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
-            ggml_backend_cann_buffer_types[i] = {
-                /* .iface    = */ ggml_backend_cann_buffer_type_interface,
-                /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
-                /* .context  = */
-                 new ggml_backend_cann_buffer_type_context{
-                    i, "CANN" + std::to_string(i)},
-            };
-        }
-        ggml_backend_cann_buffer_type_initialized = true;
-    }
-
-    return &ggml_backend_cann_buffer_types[device];
-}
-
-/**
- * @brief Retrieves the name associated with a CANN host buffer type.
- *
- * This function returns the descriptive name associated with the specified
- * CANN host buffer type context.
- *
- * @param buft Pointer to the host buffer type context.
- * @return Const pointer to the C-style string containing the name.
- */
-static const char * ggml_backend_cann_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
-    return "CANN_Host";
-
-    GGML_UNUSED(buft);
-}
-
-/**
- * @brief Retrieves the name associated with a CANN host buffer.
- *
- * This function returns the descriptive name associated with the specified
- * CANN host buffer context.
- *
- * @param buft Pointer to the host buffer context.
- * @return Const pointer to the C-style string containing the name.
- */
-static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buffer) {
-    return "CANN_Host";
-
-    GGML_UNUSED(buffer);
-}
-
-/**
- * @brief Free resources associated with a CANN host buffer.
- *
- * This function frees the resources associated with a CANN host buffer, including
- * its context.
- *
- * @param buffer The CANN host buffer to free.
- */
-static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
-    ACL_CHECK(aclrtFreeHost(buffer->context));
-}
-
-/**
- * @brief Allocates a new CANN host buffer of the specified size.
- *
- * This function allocates a new CANN host buffer with the given size.
- * @param size Size in bytes of the host buffer to allocate.
- * @return Pointer to the allocated host buffer, or nullptr if allocation fails.
- */
-static void * ggml_cann_host_malloc(size_t size) {
-    if (getenv("GGML_CANN_NO_PINNED") != nullptr) {
-        return nullptr;
-    }
-
-    void * hostPtr = nullptr;
-    aclError err = aclrtMallocHost((void **) &hostPtr, size);
-    if (err != ACL_SUCCESS) {
-
-        GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
-                           size / 1024.0 / 1024.0, aclGetRecentErrMsg());
-        return nullptr;
-    }
-    return hostPtr;
-}
-
-/**
- * @brief Allocates a new CANN host buffer of the specified type and size.
- *
- * @param buft Pointer to the host buffer type context.
- * @param size Size in bytes of the host buffer to allocate.
- * @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails.
- */
-static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    void * hostPtr = ggml_cann_host_malloc(size);
-
-    if (hostPtr == nullptr) {
-        // fallback to cpu buffer
-        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
-    }
-
-    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);
-    buffer->buft = buft;
-    buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;
-
-    return buffer;
-}
-
-/**
- * @brief Interface for managing CANN host buffer types in the GGML backend.
- *
- * Provides function pointers for allocating, querying properties, and managing
- * memory for CANN buffer types in the GGML backend.
- */
-ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
-    static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = {
-        /* .iface    = */ {
-            /* .get_name         = */ ggml_backend_cann_host_buffer_type_name,
-            /* .alloc_buffer     = */ ggml_backend_cann_host_buffer_type_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
-            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
-            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
-            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
-        },
-        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
-        /* .context  = */ nullptr,
-    };
-
-    return &ggml_backend_cann_buffer_type_host;
-}
-
-/**
- * @brief Computes the forward operation for a given tensor using CANN
- * operations.
- *
- * This function selects the appropriate CANN operation based on the type of
- * operation specified in the tensor and performs the computation.
- *
- * @param ctx The CANN context containing necessary resources and
- * configurations.
- * @param dst The destination tensor where the result of the computation will be
- * stored.
- * @return true if the computation was successful; false otherwise.
- */
-static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
-                                      struct ggml_tensor* dst) {
-    switch (dst->op) {
-        case GGML_OP_REPEAT:
-            ggml_cann_repeat(ctx, dst);
-            break;
-        case GGML_OP_GET_ROWS:
-            ggml_cann_get_rows(ctx, dst);
-            break;
-        case GGML_OP_DUP:
-            ggml_cann_dup(ctx, dst);
-            break;
-        case GGML_OP_ADD:
-            ggml_cann_add(ctx, dst);
-            break;
-        case GGML_OP_ACC:
-            ggml_cann_acc(ctx, dst);
-            break;
-        case GGML_OP_MUL:
-            ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
-            break;
-        case GGML_OP_DIV:
-            ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
-            break;
-        case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(dst)) {
-                case GGML_UNARY_OP_GELU:
-                    ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
-                        ctx, dst);
-                    break;
-                case GGML_UNARY_OP_SILU:
-                    ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(
-                        ctx, dst);
-                    break;
-                // TODO: Use faster gelu??
-                case GGML_UNARY_OP_GELU_QUICK:
-                    ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
-                        ctx, dst);
-                    break;
-                case GGML_UNARY_OP_TANH:
-                    ggml_cann_activation<aclnnTanhGetWorkspaceSize, aclnnTanh>(
-                        ctx, dst);
-                    break;
-                case GGML_UNARY_OP_RELU:
-                    ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(
-                        ctx, dst);
-                    break;
-                case GGML_UNARY_OP_HARDSIGMOID:
-                    ggml_cann_activation<aclnnHardsigmoidGetWorkspaceSize,
-                                         aclnnHardsigmoid>(ctx, dst);
-                    break;
-                case GGML_UNARY_OP_HARDSWISH:
-                    ggml_cann_activation<aclnnHardswishGetWorkspaceSize,
-                                         aclnnHardswish>(ctx, dst);
-                    break;
-                default:
-                    return false;
-            }
-            break;
-        case GGML_OP_NORM:
-            ggml_cann_norm(ctx, dst);
-            break;
-        case GGML_OP_GROUP_NORM:
-            ggml_cann_group_norm(ctx, dst);
-            break;
-        case GGML_OP_CONCAT:
-            ggml_cann_concat(ctx, dst);
-            break;
-        case GGML_OP_UPSCALE:
-            ggml_cann_upsample_nearest2d(ctx, dst);
-            break;
-        case GGML_OP_PAD:
-            ggml_cann_pad(ctx, dst);
-            break;
-        case GGML_OP_ARANGE:
-            ggml_cann_arange(ctx, dst);
-            break;
-        case GGML_OP_TIMESTEP_EMBEDDING:
-            ggml_cann_timestep_embedding(ctx, dst);
-            break;
-        case GGML_OP_LEAKY_RELU:
-            ggml_cann_leaky_relu(ctx, dst);
-            break;
-        case GGML_OP_RMS_NORM:
-            ggml_cann_rms_norm(ctx, dst);
-            break;
-        case GGML_OP_MUL_MAT:
-            ggml_cann_mul_mat(ctx, dst);
-            break;
-        case GGML_OP_MUL_MAT_ID:
-            return false;
-        case GGML_OP_SCALE:
-            ggml_cann_scale(ctx, dst);
-            break;
-        case GGML_OP_SQR:
-            ggml_cann_sqr(ctx, dst);
-            break;
-        case GGML_OP_CLAMP:
-            ggml_cann_clamp(ctx, dst);
-            break;
-        case GGML_OP_CPY:
-            ggml_cann_cpy(ctx, dst);
-            break;
-        case GGML_OP_CONT:
-            ggml_cann_dup(ctx, dst);
-            break;
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_TRANSPOSE:
-            break;
-        case GGML_OP_DIAG_MASK_INF:
-            ggml_cann_diag_mask(ctx, dst, -INFINITY);
-            break;
-        case GGML_OP_SOFT_MAX:
-            ggml_cann_softmax(ctx, dst);
-            break;
-        case GGML_OP_ROPE:
-            ggml_cann_rope(ctx, dst);
-            break;
-        case GGML_OP_IM2COL:
-            ggml_cann_im2col(ctx, dst);
-            break;
-        case GGML_OP_POOL_2D:
-            ggml_cann_pool2d(ctx, dst);
-            break;
-        case GGML_OP_SUM_ROWS:
-            ggml_cann_sum_rows(ctx, dst);
-            break;
-        case GGML_OP_ARGSORT:
-            ggml_cann_argsort(ctx, dst);
-            break;
-        default:
-            return false;
-    }
-
-    return true;
-}
-
-// backend
-/**
- * @brief Retrieves the name associated with the CANN backend.
- *
- * This function returns the name assigned to the CANN backend, which is stored
- * in the context of the provided backend structure.
- *
- * @param backend Pointer to the CANN backend structure.
- * @return A pointer to a constant string representing the backend name.
- */
-static const char* ggml_backend_cann_name(ggml_backend_t backend) {
-    ggml_backend_cann_context* cann_ctx =
-        (ggml_backend_cann_context*)backend->context;
-
-    return cann_ctx->name.c_str();
-}
-
-/**
- * @brief Frees resources associated with the CANN backend.
- *
- * This function releases resources associated with the CANN backend context
- * and resets the device associated with the backend to its initial state.
- *
- * @param backend Pointer to the CANN backend structure to be freed.
- */
-static void ggml_backend_cann_free(ggml_backend_t backend) {
-    ggml_backend_cann_context* cann_ctx =
-        (ggml_backend_cann_context*)backend->context;
-    ACL_CHECK(aclrtSynchronizeDevice());
-    ACL_CHECK(aclrtResetDevice(cann_ctx->device));
-
-    // finalize when last backend freed.
-    if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
-        ACL_CHECK(aclFinalize());
-    }
-
-    delete cann_ctx;
-    delete backend;
-}
-
-/**
- * @brief Sets tensor data asynchronously in the CANN backend.
- *
- * This function asynchronously sets tensor data in the CANN backend. Depending
- * on the tensor type, it may perform data transformations before copying data
- * to the device.
- *
- * @param backend Pointer to the CANN backend structure.
- * @param tensor Pointer to the tensor structure to set data for.
- * @param data Pointer to the host data to copy to the tensor.
- * @param offset Offset in bytes within the host data.
- * @param size Size of the data to copy in bytes.
- */
-static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
-                                               ggml_tensor *tensor,
-                                               const void *data,
-                                               size_t offset,
-                                               size_t size) {
-    ggml_backend_cann_context *cann_ctx =
-        (ggml_backend_cann_context *)backend->context;
-
-    if (!need_transform(tensor->type)) {
-        ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
-                                   size, ACL_MEMCPY_HOST_TO_DEVICE,
-                                   cann_ctx->stream()));
-    } else {
-        void *transform_buffer = malloc(size);
-        ggml_backend_cann_transform(tensor, data, transform_buffer);
-
-        ACL_CHECK(aclrtMemcpyAsync(
-            (char *)tensor->data + offset, size, transform_buffer, size,
-            ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
-        ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
-        free(transform_buffer);
-    }
-}
-
-static void ggml_backend_cann_get_tensor_async(
-    ggml_backend_t backend, const ggml_tensor *tensor, void *data,
-    size_t offset, size_t size) {
-    ggml_backend_cann_context *cann_ctx =
-        (ggml_backend_cann_context *)backend->context;
-    ggml_backend_buffer_t buf =
-        tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
-
-    GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
-                "unsupported buffer type");
-
-    if (!need_transform(tensor->type)) {
-        ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
-                                   size, ACL_MEMCPY_DEVICE_TO_HOST,
-                                   cann_ctx->stream()));
-    } else {
-        void *transform_buffer = malloc(size);
-        ACL_CHECK(aclrtMemcpyAsync(
-            transform_buffer, size, (char *)tensor->data + offset, size,
-            ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
-        ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
-        ggml_backend_cann_transform_back(tensor, transform_buffer, data);
-        free(transform_buffer);
-    }
-}
-
-/**
- * @brief Asynchronously copies tensor data between CANN backends.
- *
- * This function copies tensor data asynchronously between two CANN backends. It
- * checks if both tensors reside in CANN buffers and whether the devices support
- * peer-to-peer access for direct copying. If not, it returns false.
- *
- * @param backend_src Pointer to the source CANN backend structure.
- * @param backend_dst Pointer to the destination CANN backend structure.
- * @param src Pointer to the source tensor to copy data from.
- * @param dst Pointer to the destination tensor to copy data to.
- * @return true if the copy operation succeeds, false otherwise.
- */
-static bool ggml_backend_cann_cpy_tensor_async(
-    ggml_backend_t backend_src, ggml_backend_t backend_dst,
-    const ggml_tensor* src, ggml_tensor* dst) {
-    GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
-                ggml_backend_is_cann(backend_dst));
-
-    if (!ggml_backend_buffer_is_cann(src->buffer) ||
-        !ggml_backend_buffer_is_cann(dst->buffer)) {
-        return false;
-    }
-
-    ggml_backend_buffer_t buf_src =
-        src->view_src ? src->view_src->buffer : src->buffer;
-    ggml_backend_buffer_t buf_dst =
-        dst->view_src ? dst->view_src->buffer : dst->buffer;
-
-    ggml_backend_cann_context* cann_ctx_src =
-        (ggml_backend_cann_context*)backend_src->context;
-    ggml_backend_cann_context* cann_ctx_dst =
-        (ggml_backend_cann_context*)backend_dst->context;
-
-    size_t copy_size = ggml_nbytes(dst);
-    if (backend_src != backend_dst) {
-        ggml_backend_cann_buffer_context* buf_ctx_src =
-            (ggml_backend_cann_buffer_context*)buf_src->context;
-        ggml_backend_cann_buffer_context* buf_ctx_dst =
-            (ggml_backend_cann_buffer_context*)buf_dst->context;
-
-        GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
-        GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);
-
-        int32_t canAccessPeer = 0;
-        ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device,
-                                           cann_ctx_dst->device));
-        if (!canAccessPeer) {
-            return false;
-        }
-
-        // need open both directions for memcpyasync between devices.
-        ggml_cann_set_device(cann_ctx_dst->device);
-        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
-        ggml_cann_set_device(cann_ctx_src->device);
-        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
-
-        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
-                                   ACL_MEMCPY_DEVICE_TO_DEVICE,
-                                   cann_ctx_src->stream()));
-
-        //TODO: workaround for Event didn`t work here.
-        aclrtSynchronizeStream(cann_ctx_src->stream());
-    } else {
-        // src and dst are on the same backend
-        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
-                                   ACL_MEMCPY_DEVICE_TO_DEVICE,
-                                   cann_ctx_dst->stream()));
-    }
-
-    return true;
-}
-
-/**
- * @brief Synchronizes a CANN backend.
- *
- * This function synchronizes the specified CANN backend by waiting for all
- * operations in its associated stream to complete.
- *
- * @param backend Pointer to the CANN backend structure to synchronize.
- */
-static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
-    ggml_backend_cann_context* cann_ctx =
-        (ggml_backend_cann_context*)backend->context;
-
-    ggml_cann_set_device(cann_ctx->device);
-
-    ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
-}
-
-/**
- * @brief Computes a computational graph using a CANN backend.
- *
- * This function computes the operations defined in the computational graph
- * using the specified CANN backend.
- *
- * @param backend Pointer to the CANN backend structure to use for computation.
- * @param cgraph Pointer to the computational graph structure containing nodes
- *               representing operations to be computed.
- * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
- *         completes successfully, otherwise an appropriate error status.
- */
-static enum ggml_status ggml_backend_cann_graph_compute(
-    ggml_backend_t backend, ggml_cgraph* cgraph) {
-    ggml_backend_cann_context* cann_ctx =
-        (ggml_backend_cann_context*)backend->context;
-
-    ggml_cann_set_device(cann_ctx->device);
-
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_tensor* node = cgraph->nodes[i];
-
-        if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
-            continue;
-        }
-
-        bool ok = ggml_cann_compute_forward(*cann_ctx, node);
-
-        if (!ok) {
-            GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
-                    node->name, ggml_op_name(node->op));
-        }
-        GGML_ASSERT(ok);
-    }
-
-    return GGML_STATUS_SUCCESS;
-}
-
-/**
- * @brief Checks if the CANN backend supports a specific operation.
- *
- * This function checks whether the specified operation is supported by the
- * CANN backend.
- *
- * @param backend Pointer to the CANN backend structure to check support for
- *                the operation.
- * @param op Pointer to the tensor representing the operation to check.
- * @return bool Returns true if the operation is supported by the backend,
- *              otherwise false.
- */
-static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
-                                                    const ggml_tensor* op) {
-    switch (op->op) {
-        case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_GELU:
-                case GGML_UNARY_OP_SILU:
-                case GGML_UNARY_OP_RELU:
-                case GGML_UNARY_OP_HARDSIGMOID:
-                case GGML_UNARY_OP_HARDSWISH:
-                case GGML_UNARY_OP_GELU_QUICK:
-                case GGML_UNARY_OP_TANH:
-                    return true;
-                default:
-                    return false;
-            }
-        case GGML_OP_MUL_MAT: {
-            switch (op->src[0]->type) {
-                case GGML_TYPE_F16:
-                case GGML_TYPE_F32:
-                case GGML_TYPE_Q8_0:
-                    // TODO: fix me
-                    // Current groupsize should not be greater than k-1 in
-                    // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize().
-                case GGML_TYPE_Q4_0:
-                    return true;
-                default:
-                    return false;
-            }
-        }
-        case GGML_OP_MUL_MAT_ID:
-            return false;
-        // embedding
-        case GGML_OP_GET_ROWS: {
-            switch (op->src[0]->type) {
-                case GGML_TYPE_F32:
-                case GGML_TYPE_F16:
-                case GGML_TYPE_Q4_0:
-                case GGML_TYPE_Q8_0:
-                    return true;
-                default:
-                    return false;
-            }
-        } break;
-        case GGML_OP_CPY: {
-            switch (op->type) {
-                case GGML_TYPE_F32:
-                case GGML_TYPE_F16:
-                case GGML_TYPE_Q8_0:
-                case GGML_TYPE_Q4_0:
-                    return true;
-                default:
-                    return false;
-            }
-        }
-        case GGML_OP_DUP:
-        case GGML_OP_REPEAT:
-        case GGML_OP_CONCAT:
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_TRANSPOSE:
-        case GGML_OP_NORM:
-        case GGML_OP_ADD:
-        case GGML_OP_MUL:
-        case GGML_OP_DIV:
-        case GGML_OP_RMS_NORM:
-        case GGML_OP_SCALE:
-        case GGML_OP_SQR:
-        case GGML_OP_CLAMP:
-        case GGML_OP_CONT:
-        case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_SOFT_MAX:
-        case GGML_OP_ROPE:
-        case GGML_OP_IM2COL:
-        case GGML_OP_POOL_2D:
-        case GGML_OP_SUM_ROWS:
-        case GGML_OP_ARGSORT:
-        case GGML_OP_ACC:
-        case GGML_OP_GROUP_NORM:
-        case GGML_OP_UPSCALE:
-        case GGML_OP_PAD:
-        case GGML_OP_ARANGE:
-        case GGML_OP_TIMESTEP_EMBEDDING:
-        case GGML_OP_LEAKY_RELU:
-            return true;
-        default:
-            return false;
-    }
-
-    GGML_UNUSED(dev);
-}
-
-/**
- * @brief Checks if the backend buffer type is associated with the CANN backend.
- *
- * This function checks whether the provided backend buffer type is associated
- * with the CANN backend based on the comparison of its name retrieval function
- * pointer.
- *
- * @param buft Pointer to the backend buffer type to check.
- * @return bool Returns true if the buffer type is associated with the CANN
- * backend, otherwise false.
- */
-static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
-}
-
-/**
- * @brief Determines if a tensor operation should be offloaded to the CANN
- * backend.
- *
- * This function checks if a given tensor operation should be offloaded to the
- * CANN backend based on the operation type and the size of the tensor. It
- * returns true if the second dimension (ne[1]) of the tensor is greater than or
- * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
- *
- * @param backend Pointer to the CANN backend.
- * @param op Pointer to the tensor operation to check.
- * @return bool Returns true if the operation should be offloaded, otherwise
- * false.
- */
-static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
-                                                   const ggml_tensor* op) {
-    const int min_batch_size = 32;
-    GGML_UNUSED(dev);
-
-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
-}
-
-/**
- * @brief Records an event on the CANN backend stream.
- *
- * This function records the given event on the ACL runtime stream associated
- * with the backend context.
- *
- * @param event Pointer to the event structure to be recorded.
- */
-static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
-    ggml_backend_cann_context* cann_ctx =
-        (ggml_backend_cann_context*)backend->context;
-    ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
-}
-
-/**
- * @brief Waits for a recorded event to complete on the CANN backend stream.
- *
- * This function makes the given backend wait for the event to complete on its
- * ACL runtime stream.
- *
- * @param backend Pointer to the backend structure.
- * @param event Pointer to the event structure that the backend needs to wait
- * for.
- */
-static void ggml_backend_cann_event_wait(ggml_backend_t backend,
-                                         ggml_backend_event_t event) {
-    ggml_backend_cann_context* cann_ctx =
-        (ggml_backend_cann_context*)backend->context;
-    if (ggml_backend_is_cann(backend)) {
-        ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
-                                       (aclrtEvent)event->context));
-    } else {
-        GGML_ABORT("fatal error");
-    }
-}
-
-/**
- * @brief Structure defining the interface for the CANN backend.
- *
- * This structure contains function pointers for various operations
- * supported by the CANN backend, including name retrieval, memory
- * management, tensor operations, synchronization, and event handling.
- */
-static const ggml_backend_i ggml_backend_cann_interface = {
-    /* .get_name                = */ ggml_backend_cann_name,
-    /* .free                    = */ ggml_backend_cann_free,
-    /* .set_tensor_async        = */ ggml_backend_cann_set_tensor_async,
-    /* .get_tensor_async        = */ ggml_backend_cann_get_tensor_async,
-    /* .cpy_tensor_async        = */ ggml_backend_cann_cpy_tensor_async,
-    /* .synchronize             = */ ggml_backend_cann_synchronize,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_cann_graph_compute,
-    /* .event_record            = */ ggml_backend_cann_event_record,
-    /* .event_wait              = */ ggml_backend_cann_event_wait,
-};
-
-/**
- * @brief Return the hardcoded GUID for the CANN backend.
- *
- * This function returns a static GUID which uniquely identifies the CANN
- * backend.
- *
- * @return A pointer to the static GUID.
- */
-static ggml_guid_t ggml_backend_cann_guid() {
-    static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
-                             0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64};
-    return &guid;
-}
-
-// backend device
-struct ggml_backend_cann_device_context {
-    int device;
-    std::string name;
-    std::string description;
-};
-
-static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
-    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
-    return ctx->name.c_str();
-}
-
-static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
-    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
-    return ctx->description.c_str();
-}
-
-static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
-    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
-    ggml_backend_cann_get_device_memory(ctx->device, free, total);
-}
-
-static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
-    GGML_UNUSED(dev);
-    return GGML_BACKEND_DEVICE_TYPE_GPU;
-}
-
-static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
-    props->name        = ggml_backend_cann_device_get_name(dev);
-    props->description = ggml_backend_cann_device_get_description(dev);
-    props->type        = ggml_backend_cann_device_get_type(dev);
-    ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);
-
-    bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr;
-
-    props->caps = {
-        /* .async                 = */ false,
-        /* .host_buffer           = */ host_buffer,
-        /* .buffer_from_host_ptr  = */ false,
-        /* .events                = */ true,
-    };
-}
-
-static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
-    GGML_UNUSED(params);
-    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
-    return ggml_backend_cann_init(ctx->device);
-}
-
-/**
- * @brief Checks if the CANN backend supports a specific backend buffer type.
- *
- * This function determines whether the CANN backend supports the given backend
- * buffer type by comparing the device context of the backend and buffer type.
- * It returns true if the devices are same between the backend context and
- * buffer type context.
- *
- * @param backend Pointer to the CANN backend.
- * @param buft Pointer to the backend buffer type to check.
- * @return bool Returns true if the CANN backend supports the buffer type,
- *              otherwise false.
- */
-static bool ggml_backend_cann_supports_buft(
-    ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    if (ggml_backend_buft_is_cann(buft)) {
-        ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
-        ggml_backend_cann_buffer_type_context * buft_ctx =
-                        (ggml_backend_cann_buffer_type_context *)buft->context;
-        return buft_ctx->device == dev_ctx->device;
-    }
-    return false;
-}
-
-static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
-    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
-    return ggml_backend_cann_buffer_type(ctx->device);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
-    GGML_UNUSED(dev);
-    return ggml_backend_cann_host_buffer_type();
-}
-
-/**
- * @brief Creates a new event for the CANN backend device.
- *
- * This function initializes a new event for the CANN backend by setting the
- * device and creating an ACL runtime event. The created event is then wrapped
- * in a ggml_backend_event structure and returned.
- *
- * @param backend Pointer to the CANN backend.
- * @return ggml_backend_event_t Returns a pointer to the new event structure.
- */
-static ggml_backend_event_t ggml_backend_cann_device_event_new(
-    ggml_backend_dev_t dev) {
-    ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
-
-    ggml_cann_set_device(dev_ctx->device);
-
-    aclrtEvent event;
-    ACL_CHECK(aclrtCreateEvent(&event));
-
-    return new ggml_backend_event{
-        /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
-        /* .context = */ event,
-    };
-}
-
-/**
- * @brief Frees a CANN backend event.
- *
- * This function destroys the ACL runtime event associated with the given CANN
- * backend event and then deletes the event structure itself.
- *
- * @param event Pointer to the event structure to be freed.
- */
-static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
-    ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
-
-    delete event;
-    GGML_UNUSED(dev);
-}
-
-/**
- * @brief Synchronizes the given event on the CANN backend.
- *
- * This function waits for the specified event to complete on the ACL runtime.
- *
- * @param event Pointer to the event structure to be synchronized.
- */
-static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
-    ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
-
-    GGML_UNUSED(dev);
-}
-
-static const ggml_backend_device_i ggml_backend_cann_device_interface = {
-    /* .get_name                = */ ggml_backend_cann_device_get_name,
-    /* .get_description         = */ ggml_backend_cann_device_get_description,
-    /* .get_memory              = */ ggml_backend_cann_device_get_memory,
-    /* .get_type                = */ ggml_backend_cann_device_get_type,
-    /* .get_props               = */ ggml_backend_cann_device_get_props,
-    /* .init_backend            = */ ggml_backend_cann_device_init,    // called for every card
-    /* .get_buffer_type         = */ ggml_backend_cann_device_get_buffer_type,
-    /* .get_host_buffer_type    = */ ggml_backend_cann_device_get_host_buffer_type,
-    /* .buffer_from_host_ptr    = */ NULL, // not supported for CANN
-    /* .supports_op             = */ ggml_backend_cann_supports_op,
-    /* .supports_buft           = */ ggml_backend_cann_supports_buft,
-    /* .offload_op              = */ ggml_backend_cann_offload_op,
-    /* .event_new               = */ ggml_backend_cann_device_event_new,
-    /* .event_free              = */ ggml_backend_cann_device_event_free,
-    /* .event_synchronize       = */ ggml_backend_cann_device_event_synchronize,
-};
-
-
-// backend reg
-struct ggml_backend_cann_reg_context {
-    std::vector<ggml_backend_dev_t> devices;
-};
-
-static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
-    GGML_UNUSED(reg);
-    return GGML_CANN_NAME;
-}
-
-static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
-    ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
-    return ctx->devices.size();
-}
-
-static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
-    ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
-    GGML_ASSERT(index < ctx->devices.size());
-    return ctx->devices[index];
-}
-
-static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
-    GGML_UNUSED(reg);
-    GGML_UNUSED(name);
-    // reserved for future use
-    return nullptr;
-}
-
-static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
-    /* .get_name          = */ ggml_backend_cann_reg_get_name,
-    /* .get_device_count  = */ ggml_backend_cann_reg_get_device_count,
-    /* .get_device_get    = */ ggml_backend_cann_reg_get_device,
-    /* .get_proc_address  = */ ggml_backend_cann_reg_get_proc_address,
-};
-
-// backend registry, called only once for cann backend
-ggml_backend_reg_t ggml_backend_cann_reg() {
-    static ggml_backend_reg reg;
-    static bool initialized = false;
-
-    {
-        static std::mutex mutex;
-        std::lock_guard<std::mutex> lock(mutex);
-        if (!initialized) {
-            aclInit(nullptr);
-            ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
-
-            for (int i = 0; i < ggml_cann_info().device_count; i++) {
-                ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
-                dev_ctx->description = aclrtGetSocName();
-                dev_ctx->device = i;
-                dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
-                ggml_cann_set_device(i);
-                ggml_backend_dev_t dev = new ggml_backend_device {
-                    /* .interface = */ ggml_backend_cann_device_interface,
-                    /* .reg       = */ &reg,
-                    /* .context   = */ dev_ctx
-                };
-                ctx->devices.push_back(dev);
-            }
-
-            reg = ggml_backend_reg {
-                /* .interface = */ ggml_backend_cann_reg_interface,
-                /* .context   = */ ctx
-            };
-        }
-
-        initialized = true;
-    }
-
-    return &reg;
-}
-
-ggml_backend_t ggml_backend_cann_init(int32_t device) {
-    aclInit(nullptr);
-    if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
-        GGML_LOG_ERROR("%s: error: invalid device %d\n", __func__, device);
-        return nullptr;
-    }
-
-    ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
-    if (ctx == nullptr) {
-        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
-        return nullptr;
-    }
-    ggml_cann_set_device(ctx->device);
-    ggml_backend_t cann_backend =
-        new ggml_backend{/* .guid      = */ ggml_backend_cann_guid(),
-                         /* .interface = */ ggml_backend_cann_interface,
-                         /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
-                         /* .context   = */ ctx};
-
-    return cann_backend;
-}
-
-bool ggml_backend_is_cann(ggml_backend_t backend) {
-    return backend != NULL &&
-           ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
-}
-
-int32_t ggml_backend_cann_get_device_count() {
-    return ggml_cann_info().device_count;
-}
-
-void ggml_backend_cann_get_device_description(
-    int32_t device, char* description, size_t description_size) {
-    ggml_cann_set_device(device);
-    const char* soc_name = aclrtGetSocName();
-    snprintf(description, description_size, "%s", soc_name);
-}
-
-void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
-                                         size_t* total) {
-    ggml_cann_set_device(device);
-    ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
-}
diff --git a/ggml/src/ggml-cpu-impl.h b/ggml/src/ggml-cpu-impl.h
deleted file mode 100644 (file)
index 5b45155..0000000
+++ /dev/null
@@ -1,614 +0,0 @@
-#pragma once
-
-// GGML CPU internal header
-
-#include "ggml.h"
-#include "ggml-impl.h"
-#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
-//#include <stddef.h>
-#include <stdbool.h>
-#include <string.h> // memcpy
-#include <math.h>   // fabsf
-
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#if defined(_MSC_VER)
-
-#define m512bh(p) p
-#define m512i(p) p
-
-#else
-
-#define m512bh(p) (__m512bh)(p)
-#define m512i(p) (__m512i)(p)
-
-#endif
-
-/**
- * Converts brain16 to float32.
- *
- * The bfloat16 floating point format has the following structure:
- *
- *       ┌sign
- *       │
- *       │   ┌exponent
- *       │   │
- *       │   │      ┌mantissa
- *       │   │      │
- *       │┌──┴───┐┌─┴───┐
- *     0b0000000000000000 brain16
- *
- * Since bf16 has the same number of exponent bits as a 32bit float,
- * encoding and decoding numbers becomes relatively straightforward.
- *
- *       ┌sign
- *       │
- *       │   ┌exponent
- *       │   │
- *       │   │      ┌mantissa
- *       │   │      │
- *       │┌──┴───┐┌─┴───────────────────┐
- *     0b00000000000000000000000000000000 IEEE binary32
- *
- * For comparison, the standard fp16 format has fewer exponent bits.
- *
- *       ┌sign
- *       │
- *       │  ┌exponent
- *       │  │
- *       │  │    ┌mantissa
- *       │  │    │
- *       │┌─┴─┐┌─┴──────┐
- *     0b0000000000000000 IEEE binary16
- *
- * @see IEEE 754-2008
- */
-static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
-    union {
-        float f;
-        uint32_t i;
-    } u;
-    u.i = (uint32_t)h.bits << 16;
-    return u.f;
-}
-
-/**
- * Converts float32 to brain16.
- *
- * This is binary identical with Google Brain float conversion.
- * Floats shall round to nearest even, and NANs shall be quiet.
- * Subnormals aren't flushed to zero, except perhaps when used.
- * This code should vectorize nicely if using modern compilers.
- */
-static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
-    ggml_bf16_t h;
-    union {
-        float f;
-        uint32_t i;
-    } u;
-    u.f = s;
-    if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
-        h.bits = (u.i >> 16) | 64; /* force to quiet */
-        return h;
-    }
-    h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
-    return h;
-}
-
-#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
-#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
-
-// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
-#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
-#ifndef __FMA__
-#define __FMA__
-#endif
-#ifndef __F16C__
-#define __F16C__
-#endif
-#endif
-
-// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
-#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
-#ifndef __SSE3__
-#define __SSE3__
-#endif
-#ifndef __SSSE3__
-#define __SSSE3__
-#endif
-#endif
-
-#if defined(__ARM_FEATURE_SVE)
-#include <arm_sve.h>
-#include <sys/prctl.h>
-#endif
-
-// 16-bit float
-// on Arm, we use __fp16
-// on x86, we use uint16_t
-#if defined(__ARM_NEON)
-
-// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
-//
-//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
-//
-#include <arm_neon.h>
-
-#ifdef _MSC_VER
-
-typedef uint16_t ggml_fp16_internal_t;
-
-#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
-
-#else
-
-typedef __fp16 ggml_fp16_internal_t;
-
-#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
-
-#endif // _MSC_VER
-
-#if !defined(__aarch64__)
-
-// 32-bit ARM compatibility
-
-// vaddlvq_s16
-// vpaddq_s16
-// vpaddq_s32
-// vaddvq_s32
-// vaddvq_f32
-// vmaxvq_f32
-// vcvtnq_s32_f32
-// vzip1_u8
-// vzip2_u8
-
-inline static int32_t vaddlvq_s16(int16x8_t v) {
-    int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));
-    return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);
-}
-
-inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
-    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
-    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
-    return vcombine_s16(a0, b0);
-}
-
-inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
-    int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
-    int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
-    return vcombine_s32(a0, b0);
-}
-
-inline static int32_t vaddvq_s32(int32x4_t v) {
-    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
-}
-
-inline static float vaddvq_f32(float32x4_t v) {
-    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
-}
-
-inline static float vmaxvq_f32(float32x4_t v) {
-    return
-        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
-            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
-}
-
-inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
-    int32x4_t res;
-
-    res[0] = roundf(vgetq_lane_f32(v, 0));
-    res[1] = roundf(vgetq_lane_f32(v, 1));
-    res[2] = roundf(vgetq_lane_f32(v, 2));
-    res[3] = roundf(vgetq_lane_f32(v, 3));
-
-    return res;
-}
-
-inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
-    uint8x8_t res;
-
-    res[0] = a[0]; res[1] = b[0];
-    res[2] = a[1]; res[3] = b[1];
-    res[4] = a[2]; res[5] = b[2];
-    res[6] = a[3]; res[7] = b[3];
-
-    return res;
-}
-
-inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
-    uint8x8_t res;
-
-    res[0] = a[4]; res[1] = b[4];
-    res[2] = a[5]; res[3] = b[5];
-    res[4] = a[6]; res[5] = b[6];
-    res[6] = a[7]; res[7] = b[7];
-
-    return res;
-}
-
-// vld1q_s16_x2
-// vld1q_u8_x2
-// vld1q_u8_x4
-// vld1q_s8_x2
-// vld1q_s8_x4
-// TODO: double-check these work correctly
-
-typedef struct ggml_int16x8x2_t {
-    int16x8_t val[2];
-} ggml_int16x8x2_t;
-
-inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
-    ggml_int16x8x2_t res;
-
-    res.val[0] = vld1q_s16(ptr + 0);
-    res.val[1] = vld1q_s16(ptr + 8);
-
-    return res;
-}
-
-typedef struct ggml_uint8x16x2_t {
-    uint8x16_t val[2];
-} ggml_uint8x16x2_t;
-
-inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
-    ggml_uint8x16x2_t res;
-
-    res.val[0] = vld1q_u8(ptr + 0);
-    res.val[1] = vld1q_u8(ptr + 16);
-
-    return res;
-}
-
-typedef struct ggml_uint8x16x4_t {
-    uint8x16_t val[4];
-} ggml_uint8x16x4_t;
-
-inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
-    ggml_uint8x16x4_t res;
-
-    res.val[0] = vld1q_u8(ptr + 0);
-    res.val[1] = vld1q_u8(ptr + 16);
-    res.val[2] = vld1q_u8(ptr + 32);
-    res.val[3] = vld1q_u8(ptr + 48);
-
-    return res;
-}
-
-typedef struct ggml_int8x16x2_t {
-    int8x16_t val[2];
-} ggml_int8x16x2_t;
-
-inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
-    ggml_int8x16x2_t res;
-
-    res.val[0] = vld1q_s8(ptr + 0);
-    res.val[1] = vld1q_s8(ptr + 16);
-
-    return res;
-}
-
-typedef struct ggml_int8x16x4_t {
-    int8x16_t val[4];
-} ggml_int8x16x4_t;
-
-inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
-    ggml_int8x16x4_t res;
-
-    res.val[0] = vld1q_s8(ptr + 0);
-    res.val[1] = vld1q_s8(ptr + 16);
-    res.val[2] = vld1q_s8(ptr + 32);
-    res.val[3] = vld1q_s8(ptr + 48);
-
-    return res;
-}
-
-// NOTE: not tested
-inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
-    int8x16_t res;
-
-    res[ 0] = a[b[ 0]];
-    res[ 1] = a[b[ 1]];
-    res[ 2] = a[b[ 2]];
-    res[ 3] = a[b[ 3]];
-    res[ 4] = a[b[ 4]];
-    res[ 5] = a[b[ 5]];
-    res[ 6] = a[b[ 6]];
-    res[ 7] = a[b[ 7]];
-    res[ 8] = a[b[ 8]];
-    res[ 9] = a[b[ 9]];
-    res[10] = a[b[10]];
-    res[11] = a[b[11]];
-    res[12] = a[b[12]];
-    res[13] = a[b[13]];
-    res[14] = a[b[14]];
-    res[15] = a[b[15]];
-
-    return res;
-}
-
-// NOTE: not tested
-inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
-    uint8x16_t res;
-
-    res[ 0] = a[b[ 0]];
-    res[ 1] = a[b[ 1]];
-    res[ 2] = a[b[ 2]];
-    res[ 3] = a[b[ 3]];
-    res[ 4] = a[b[ 4]];
-    res[ 5] = a[b[ 5]];
-    res[ 6] = a[b[ 6]];
-    res[ 7] = a[b[ 7]];
-    res[ 8] = a[b[ 8]];
-    res[ 9] = a[b[ 9]];
-    res[10] = a[b[10]];
-    res[11] = a[b[11]];
-    res[12] = a[b[12]];
-    res[13] = a[b[13]];
-    res[14] = a[b[14]];
-    res[15] = a[b[15]];
-
-    return res;
-}
-
-#else
-
-#define ggml_int16x8x2_t  int16x8x2_t
-#define ggml_uint8x16x2_t uint8x16x2_t
-#define ggml_uint8x16x4_t uint8x16x4_t
-#define ggml_int8x16x2_t  int8x16x2_t
-#define ggml_int8x16x4_t  int8x16x4_t
-
-#define ggml_vld1q_s16_x2 vld1q_s16_x2
-#define ggml_vld1q_u8_x2  vld1q_u8_x2
-#define ggml_vld1q_u8_x4  vld1q_u8_x4
-#define ggml_vld1q_s8_x2  vld1q_s8_x2
-#define ggml_vld1q_s8_x4  vld1q_s8_x4
-#define ggml_vqtbl1q_s8   vqtbl1q_s8
-#define ggml_vqtbl1q_u8   vqtbl1q_u8
-
-#endif // !defined(__aarch64__)
-
-#if !defined(__ARM_FEATURE_DOTPROD)
-
-inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
-    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
-    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
-
-    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
-}
-
-#else
-
-#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
-
-#endif // !defined(__ARM_FEATURE_DOTPROD)
-
-#endif // defined(__ARM_NEON)
-
-#if defined(__ARM_NEON) && !defined(_MSC_VER)
-
-#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-
-#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-
-static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
-    ggml_fp16_internal_t tmp;
-    memcpy(&tmp, &h, sizeof(ggml_fp16_t));
-    return (float)tmp;
-}
-
-static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
-    ggml_fp16_t res;
-    ggml_fp16_internal_t tmp = f;
-    memcpy(&res, &tmp, sizeof(ggml_fp16_t));
-    return res;
-}
-
-#else
-
-#ifdef __wasm_simd128__
-#include <wasm_simd128.h>
-#else
-#ifdef __POWER9_VECTOR__
-#include <altivec.h>
-#undef bool
-#define bool _Bool
-#else
-#if defined(_MSC_VER) || defined(__MINGW32__)
-#include <intrin.h>
-#else
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
-#if !defined(__riscv)
-#include <immintrin.h>
-#endif
-#endif
-#endif
-#endif
-#endif
-
-#ifdef __riscv_v_intrinsic
-#include <riscv_vector.h>
-#endif
-
-#if defined(__loongarch64)
-#if defined(__loongarch_asx)
-#include <lasxintrin.h>
-#endif
-#if defined(__loongarch_sx)
-#include <lsxintrin.h>
-#endif
-#endif
-
-#if defined(__loongarch_asx)
-
-typedef union {
-    int32_t i;
-    float f;
-} ft_union;
-
-/* float type data load instructions */
-static __m128 __lsx_vreplfr2vr_s(float val) {
-    ft_union fi_tmpval = {.f = val};
-    return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
-}
-
-static __m256 __lasx_xvreplfr2vr_s(float val) {
-    ft_union fi_tmpval = {.f = val};
-    return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
-}
-#endif
-
-#ifdef __F16C__
-
-#ifdef _MSC_VER
-#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
-#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
-#else
-#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
-#endif
-
-#elif defined(__POWER9_VECTOR__)
-
-#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-/* the inline asm below is about 12% faster than the lookup method */
-#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
-#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
-
-static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
-    register float f;
-    register double d;
-    __asm__(
-        "mtfprd %0,%2\n"
-        "xscvhpdp %0,%0\n"
-        "frsp %1,%0\n" :
-        /* temp */ "=d"(d),
-        /* out */  "=f"(f):
-        /* in */   "r"(h));
-    return f;
-}
-
-static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
-    register double d;
-    register ggml_fp16_t r;
-    __asm__( /* xscvdphp can work on double or single precision */
-        "xscvdphp %0,%2\n"
-        "mffprd %1,%0\n" :
-        /* temp */ "=d"(d),
-        /* out */  "=r"(r):
-        /* in */   "f"(f));
-    return r;
-}
-
-#else
-
-// FP16 <-> FP32
-// ref: https://github.com/Maratyszcza/FP16
-
-static inline float fp32_from_bits(uint32_t w) {
-    union {
-        uint32_t as_bits;
-        float as_value;
-    } fp32;
-    fp32.as_bits = w;
-    return fp32.as_value;
-}
-
-static inline uint32_t fp32_to_bits(float f) {
-    union {
-        float as_value;
-        uint32_t as_bits;
-    } fp32;
-    fp32.as_value = f;
-    return fp32.as_bits;
-}
-
-static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
-    const uint32_t w = (uint32_t) h << 16;
-    const uint32_t sign = w & UINT32_C(0x80000000);
-    const uint32_t two_w = w + w;
-
-    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
-#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
-    const float exp_scale = 0x1.0p-112f;
-#else
-    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
-#endif
-    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
-
-    const uint32_t magic_mask = UINT32_C(126) << 23;
-    const float magic_bias = 0.5f;
-    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
-
-    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
-    const uint32_t result = sign |
-        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
-    return fp32_from_bits(result);
-}
-
-static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
-#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
-    const float scale_to_inf = 0x1.0p+112f;
-    const float scale_to_zero = 0x1.0p-110f;
-#else
-    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
-    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
-#endif
-    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
-
-    const uint32_t w = fp32_to_bits(f);
-    const uint32_t shl1_w = w + w;
-    const uint32_t sign = w & UINT32_C(0x80000000);
-    uint32_t bias = shl1_w & UINT32_C(0xFF000000);
-    if (bias < UINT32_C(0x71000000)) {
-        bias = UINT32_C(0x71000000);
-    }
-
-    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
-    const uint32_t bits = fp32_to_bits(base);
-    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
-    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
-    const uint32_t nonsign = exp_bits + mantissa_bits;
-    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
-}
-
-#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
-#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-
-#endif // __F16C__
-
-#endif // defined(__ARM_NEON) && (!defined(__MSC_VER)
-
-#ifdef __ARM_FEATURE_SVE
-#include <arm_sve.h>
-#endif // __ARM_FEATURE_SVE
-
-// precomputed f32 table for f16 (256 KB)
-// defined in ggml.c, initialized in ggml_init()
-extern float ggml_table_f32_f16[1 << 16];
-
-// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
-// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
-// This is also true for POWER9.
-#if !defined(GGML_FP16_TO_FP32)
-inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
-    uint16_t s;
-    memcpy(&s, &f, sizeof(uint16_t));
-    return ggml_table_f32_f16[s];
-}
-
-#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
-#endif
-
-#if !defined(GGML_FP32_TO_FP16)
-#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
-#endif
-
-#ifdef __cplusplus
-}
-#endif
diff --git a/ggml/src/ggml-cpu.c b/ggml/src/ggml-cpu.c
deleted file mode 100644 (file)
index de1de18..0000000
+++ /dev/null
@@ -1,13834 +0,0 @@
-#define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows
-#define _USE_MATH_DEFINES // For M_PI on MSVC
-
-#include "ggml-aarch64.h"
-#include "ggml-backend-impl.h"
-#include "ggml-backend.h"
-#include "ggml-cpu-impl.h"
-#include "ggml-cpu.h"
-#include "ggml-impl.h"
-#include "ggml-quants.h"
-#include "ggml.h"
-
-#if defined(_MSC_VER) || defined(__MINGW32__)
-#include <malloc.h> // using malloc.h with MSC/MINGW
-#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
-#include <alloca.h>
-#endif
-
-#include <assert.h>
-#include <errno.h>
-#include <time.h>
-#include <math.h>
-#include <stdlib.h>
-#include <string.h>
-#include <stdint.h>
-#include <inttypes.h>
-#include <stdio.h>
-#include <float.h>
-#include <limits.h>
-#include <stdarg.h>
-#include <signal.h>
-#if defined(__gnu_linux__)
-#include <syscall.h>
-#endif
-
-#ifdef GGML_USE_OPENMP
-#include <omp.h>
-#endif
-
-#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
-#undef GGML_USE_LLAMAFILE
-#endif
-
-#ifdef GGML_USE_LLAMAFILE
-#include <llamafile/sgemm.h>
-#endif
-
-#if defined(_MSC_VER)
-// disable "possible loss of data" to avoid hundreds of casts
-// we should just be careful :)
-#pragma warning(disable: 4244 4267)
-
-// disable POSIX deprecation warnings
-// these functions are never going away, anyway
-#pragma warning(disable: 4996)
-
-// unreachable code because of multiple instances of code after GGML_ABORT
-#pragma warning(disable: 4702)
-#endif
-
-// Note: once we move threading into a separate C++ file
-// will use std::hardware_destructive_interference_size instead of hardcoding it here
-// and we'll use C++ attribute syntax.
-#define GGML_CACHE_LINE  64
-
-#if defined(__clang__) || defined(__GNUC__)
-#define GGML_CACHE_ALIGN __attribute__((aligned(GGML_CACHE_LINE)))
-#endif
-
-#if defined(__has_feature)
-#if __has_feature(thread_sanitizer)
-#define GGML_TSAN_ENABLED 1
-#endif
-#else  // __has_feature
-#if defined(__SANITIZE_THREAD__)
-#define GGML_TSAN_ENABLED 1
-#endif
-#endif // __has_feature
-
-#define UNUSED GGML_UNUSED
-#define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0)
-
-#if defined(GGML_USE_ACCELERATE)
-#include <Accelerate/Accelerate.h>
-#endif
-
-// floating point type used to accumulate sums
-typedef double ggml_float;
-
-#define GGML_GELU_FP16
-#define GGML_GELU_QUICK_FP16
-
-#define GGML_SOFT_MAX_UNROLL 4
-#define GGML_VEC_DOT_UNROLL  2
-#define GGML_VEC_MAD_UNROLL  32
-
-//
-// global data
-//
-
-// precomputed gelu table for f16 (128 KB)
-static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
-
-// precomputed quick gelu table for f16 (128 KB)
-static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
-
-// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
-float ggml_table_f32_f16[1 << 16];
-
-#if defined(__ARM_ARCH)
-struct ggml_arm_arch_features_type {
-    int has_neon;
-    int has_i8mm;
-    int has_sve;
-    int sve_cnt;
-} ggml_arm_arch_features = {-1, -1, -1, 0};
-#endif
-
-
-#if defined(_WIN32)
-
-#define WIN32_LEAN_AND_MEAN
-#ifndef NOMINMAX
-    #define NOMINMAX
-#endif
-#include <windows.h>
-
-
-#if !defined(__clang__)
-#define GGML_CACHE_ALIGN __declspec(align(GGML_CACHE_LINE))
-
-typedef volatile LONG atomic_int;
-typedef atomic_int atomic_bool;
-typedef atomic_int atomic_flag;
-
-#define ATOMIC_FLAG_INIT 0
-
-typedef enum {
-    memory_order_relaxed,
-    memory_order_consume,
-    memory_order_acquire,
-    memory_order_release,
-    memory_order_acq_rel,
-    memory_order_seq_cst
-} memory_order;
-
-static void atomic_store(atomic_int * ptr, LONG val) {
-    InterlockedExchange(ptr, val);
-}
-static void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) {
-    // TODO: add support for explicit memory order
-    InterlockedExchange(ptr, val);
-}
-static LONG atomic_load(atomic_int * ptr) {
-    return InterlockedCompareExchange(ptr, 0, 0);
-}
-static LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) {
-    // TODO: add support for explicit memory order
-    return InterlockedCompareExchange(ptr, 0, 0);
-}
-static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
-    return InterlockedExchangeAdd(ptr, inc);
-}
-static LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) {
-    // TODO: add support for explicit memory order
-    return InterlockedExchangeAdd(ptr, inc);
-}
-static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
-    return InterlockedExchange(ptr, 1);
-}
-static void atomic_flag_clear(atomic_flag * ptr) {
-    InterlockedExchange(ptr, 0);
-}
-static void atomic_thread_fence(memory_order mo) {
-    MemoryBarrier();
-}
-#else // clang
-#include <stdatomic.h>
-#endif
-
-typedef HANDLE pthread_t;
-
-typedef DWORD thread_ret_t;
-static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
-    (void) unused;
-    HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
-    if (handle == NULL)
-    {
-        return EAGAIN;
-    }
-
-    *out = handle;
-    return 0;
-}
-
-static int pthread_join(pthread_t thread, void * unused) {
-    (void) unused;
-    int ret = (int) WaitForSingleObject(thread, INFINITE);
-    CloseHandle(thread);
-    return ret;
-}
-
-static int sched_yield (void) {
-    Sleep (0);
-    return 0;
-}
-#else
-
-#include <pthread.h>
-#include <stdatomic.h>
-#include <sched.h>
-#if defined(__FreeBSD__)
-#include <pthread_np.h>
-#endif
-
-typedef void * thread_ret_t;
-
-#include <sys/types.h>
-#include <sys/stat.h>
-#include <unistd.h>
-
-#endif
-
-typedef pthread_t ggml_thread_t;
-
-#ifdef GGML_USE_CPU_HBM
-#include <hbwmalloc.h>
-#endif
-
-#if defined(__APPLE__)
-#include <unistd.h>
-#include <mach/mach.h>
-#include <TargetConditionals.h>
-#endif
-
-//
-// cache line
-//
-
-#if defined(__cpp_lib_hardware_interference_size)
-#define CACHE_LINE_SIZE hardware_destructive_interference_size
-#else
-#if defined(__POWER9_VECTOR__)
-#define CACHE_LINE_SIZE 128
-#else
-#define CACHE_LINE_SIZE 64
-#endif
-#endif
-
-static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
-
-
-static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
-static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
-static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
-
-static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_F32] = {
-        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
-        .vec_dot_type             = GGML_TYPE_F32,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_F16] = {
-        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f16,
-        .vec_dot_type             = GGML_TYPE_F16,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_Q4_0] = {
-        .vec_dot                  = ggml_vec_dot_q4_0_q8_0,
-        .vec_dot_type             = GGML_TYPE_Q8_0,
-#if defined (__ARM_FEATURE_MATMUL_INT8)
-        .nrows                    = 2,
-#else
-        .nrows                    = 1,
-#endif
-    },
-    [GGML_TYPE_Q4_1] = {
-        .vec_dot                  = ggml_vec_dot_q4_1_q8_1,
-        .vec_dot_type             = GGML_TYPE_Q8_1,
-#if defined (__ARM_FEATURE_MATMUL_INT8)
-        .nrows                    = 2,
-#else
-        .nrows                    = 1,
-#endif
-    },
-    [4] = { // GGML_TYPE_Q4_2
-        .vec_dot                  = NULL,
-        .vec_dot_type             = GGML_TYPE_COUNT,
-        .nrows                    = 1,
-    },
-    [5] = { // GGML_TYPE_Q4_3
-        .vec_dot                  = NULL,
-        .vec_dot_type             = GGML_TYPE_COUNT,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_Q5_0] = {
-        .vec_dot                  = ggml_vec_dot_q5_0_q8_0,
-        .vec_dot_type             = GGML_TYPE_Q8_0,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_Q5_1] = {
-        .vec_dot                  = ggml_vec_dot_q5_1_q8_1,
-        .vec_dot_type             = GGML_TYPE_Q8_1,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_Q8_0] = {
-        .from_float_to_mat        = quantize_mat_q8_0,
-        .vec_dot                  = ggml_vec_dot_q8_0_q8_0,
-        .vec_dot_type             = GGML_TYPE_Q8_0,
-#if defined (__ARM_FEATURE_MATMUL_INT8)
-        .nrows                    = 2,
-#else
-        .nrows                    = 1,
-#endif
-    },
-    [GGML_TYPE_Q8_1] = {
-        .vec_dot_type             = GGML_TYPE_Q8_1,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_Q2_K] = {
-        .vec_dot                  = ggml_vec_dot_q2_K_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_Q3_K] = {
-        .vec_dot                  = ggml_vec_dot_q3_K_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_Q4_K] = {
-        .vec_dot                  = ggml_vec_dot_q4_K_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_Q5_K] = {
-        .vec_dot                  = ggml_vec_dot_q5_K_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_Q6_K] = {
-        .vec_dot                  = ggml_vec_dot_q6_K_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_IQ2_XXS] = {
-        .vec_dot                  = ggml_vec_dot_iq2_xxs_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_IQ2_XS] = {
-        .vec_dot                  = ggml_vec_dot_iq2_xs_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_IQ3_XXS] = {
-        .vec_dot                  = ggml_vec_dot_iq3_xxs_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_IQ3_S] = {
-        .vec_dot                  = ggml_vec_dot_iq3_s_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_IQ2_S] = {
-        .vec_dot                  = ggml_vec_dot_iq2_s_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_IQ1_S] = {
-        .vec_dot                  = ggml_vec_dot_iq1_s_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_IQ1_M] = {
-        .vec_dot                  = ggml_vec_dot_iq1_m_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_IQ4_NL] = {
-        .vec_dot                  = ggml_vec_dot_iq4_nl_q8_0,
-        .vec_dot_type             = GGML_TYPE_Q8_0,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_IQ4_XS] = {
-        .vec_dot                  = ggml_vec_dot_iq4_xs_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_BF16] = {
-        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_bf16,
-        .vec_dot_type             = GGML_TYPE_BF16,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_Q4_0_4_4] = {
-        .vec_dot                  = NULL,
-        .vec_dot_type             = GGML_TYPE_Q8_0,
-        .nrows                    = 1,
-        .ncols                    = 4,
-        .gemv                     = ggml_gemv_q4_0_4x4_q8_0,
-        .gemm                     = ggml_gemm_q4_0_4x4_q8_0,
-    },
-    [GGML_TYPE_Q4_0_4_8] = {
-        .vec_dot                  = NULL,
-        .vec_dot_type             = GGML_TYPE_Q8_0,
-        .nrows                    = 1,
-        .ncols                    = 4,
-        .gemv                     = ggml_gemv_q4_0_4x8_q8_0,
-        .gemm                     = ggml_gemm_q4_0_4x8_q8_0,
-    },
-    [GGML_TYPE_Q4_0_8_8] = {
-        .vec_dot                  = NULL,
-        .vec_dot_type             = GGML_TYPE_Q8_0,
-        .nrows                    = 1,
-        .ncols                    = 8,
-        .gemv                     = ggml_gemv_q4_0_8x8_q8_0,
-        .gemm                     = ggml_gemm_q4_0_8x8_q8_0,
-    },
-    [GGML_TYPE_TQ1_0] = {
-        .vec_dot                  = ggml_vec_dot_tq1_0_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-    [GGML_TYPE_TQ2_0] = {
-        .vec_dot                  = ggml_vec_dot_tq2_0_q8_K,
-        .vec_dot_type             = GGML_TYPE_Q8_K,
-        .nrows                    = 1,
-    },
-};
-
-const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
-    return &type_traits_cpu[type];
-}
-
-//
-// simd mappings
-//
-
-// we define a common set of C macros which map to specific intrinsics based on the current architecture
-// we then implement the fundamental computation operations below using only these macros
-// adding support for new architectures requires to define the corresponding SIMD macros
-//
-// GGML_F32_STEP / GGML_F16_STEP
-//   number of elements to process in a single step
-//
-// GGML_F32_EPR / GGML_F16_EPR
-//   number of elements to fit in a single register
-//
-
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
-
-#define GGML_SIMD
-
-// F32 NEON
-
-#define GGML_F32_STEP 16
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4              float32x4_t
-#define GGML_F32x4_ZERO         vdupq_n_f32(0.0f)
-#define GGML_F32x4_SET1(x)      vdupq_n_f32(x)
-#define GGML_F32x4_LOAD         vld1q_f32
-#define GGML_F32x4_STORE        vst1q_f32
-#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
-#define GGML_F32x4_ADD          vaddq_f32
-#define GGML_F32x4_MUL          vmulq_f32
-#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
-#define GGML_F32x4_REDUCE(res, x)                  \
-{                                                  \
-    int offset = GGML_F32_ARR >> 1;                \
-    for (int i = 0; i < offset; ++i) {             \
-        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
-    }                                              \
-    offset >>= 1;                                  \
-    for (int i = 0; i < offset; ++i) {             \
-        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
-    }                                              \
-    offset >>= 1;                                  \
-    for (int i = 0; i < offset; ++i) {             \
-        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
-    }                                              \
-    (res) = GGML_F32x4_REDUCE_ONE((x)[0]);         \
-}
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 NEON
-
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
-    #define GGML_F16_STEP 32
-    #define GGML_F16_EPR  8
-
-    #define GGML_F16x8              float16x8_t
-    #define GGML_F16x8_ZERO         vdupq_n_f16(0.0f)
-    #define GGML_F16x8_SET1(x)      vdupq_n_f16(x)
-    #define GGML_F16x8_LOAD(x)      vld1q_f16((const ggml_fp16_internal_t *)(x))
-    #define GGML_F16x8_STORE        vst1q_f16
-    #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
-    #define GGML_F16x8_ADD          vaddq_f16
-    #define GGML_F16x8_MUL          vmulq_f16
-    #define GGML_F16x8_REDUCE(res, x)                               \
-    do {                                                            \
-        int offset = GGML_F16_ARR >> 1;                             \
-        for (int i = 0; i < offset; ++i) {                          \
-            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
-        }                                                           \
-        offset >>= 1;                                               \
-        for (int i = 0; i < offset; ++i) {                          \
-            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
-        }                                                           \
-        offset >>= 1;                                               \
-        for (int i = 0; i < offset; ++i) {                          \
-            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
-        }                                                           \
-        const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \
-        const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \
-        (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \
-    } while (0)
-
-    #define GGML_F16_VEC                GGML_F16x8
-    #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
-    #define GGML_F16_VEC_SET1           GGML_F16x8_SET1
-    #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)
-    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), (r)[i])
-    #define GGML_F16_VEC_FMA            GGML_F16x8_FMA
-    #define GGML_F16_VEC_ADD            GGML_F16x8_ADD
-    #define GGML_F16_VEC_MUL            GGML_F16x8_MUL
-    #define GGML_F16_VEC_REDUCE         GGML_F16x8_REDUCE
-#else
-    // if FP16 vector arithmetic is not supported, we use FP32 instead
-    // and take advantage of the vcvt_ functions to convert to/from FP16
-
-    #define GGML_F16_STEP 16
-    #define GGML_F16_EPR  4
-
-    #define GGML_F32Cx4              float32x4_t
-    #define GGML_F32Cx4_ZERO         vdupq_n_f32(0.0f)
-    #define GGML_F32Cx4_SET1(x)      vdupq_n_f32(x)
-    #define GGML_F32Cx4_LOAD(x)      vcvt_f32_f16(vld1_f16((const ggml_fp16_internal_t *)(x)))
-    #define GGML_F32Cx4_STORE(x, y)  vst1_f16(x, vcvt_f16_f32(y))
-    #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
-    #define GGML_F32Cx4_ADD          vaddq_f32
-    #define GGML_F32Cx4_MUL          vmulq_f32
-    #define GGML_F32Cx4_REDUCE       GGML_F32x4_REDUCE
-
-    #define GGML_F16_VEC                GGML_F32Cx4
-    #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
-    #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
-    #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
-    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
-    #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
-    #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
-    #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
-    #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE
-#endif
-
-#elif defined(__AVX512F__)
-
-#define GGML_SIMD
-
-// F32 AVX512
-
-#define GGML_F32_STEP 64
-#define GGML_F32_EPR  16
-
-#define GGML_F32x16         __m512
-#define GGML_F32x16_ZERO    _mm512_setzero_ps()
-#define GGML_F32x16_SET1(x) _mm512_set1_ps(x)
-#define GGML_F32x16_LOAD    _mm512_loadu_ps
-#define GGML_F32x16_STORE   _mm512_storeu_ps
-// _mm512_fmadd_ps is defined in AVX512F so no guard is required
-#define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
-#define GGML_F32x16_ADD     _mm512_add_ps
-#define GGML_F32x16_MUL     _mm512_mul_ps
-#define GGML_F32x16_REDUCE(res, x)                                    \
-do {                                                                  \
-    int offset = GGML_F32_ARR >> 1;                                   \
-    for (int i = 0; i < offset; ++i) {                                \
-        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
-    }                                                                 \
-    offset >>= 1;                                                     \
-    for (int i = 0; i < offset; ++i) {                                \
-        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
-    }                                                                 \
-    offset >>= 1;                                                     \
-    for (int i = 0; i < offset; ++i) {                                \
-        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
-    }                                                                 \
-    res = _mm512_reduce_add_ps(x[0]);                                 \
-} while (0)
-
-// TODO: is this optimal ?
-
-#define GGML_F32_VEC        GGML_F32x16
-#define GGML_F32_VEC_ZERO   GGML_F32x16_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x16_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x16_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x16_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x16_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x16_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x16_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE
-
-// F16 AVX512
-
-// F16 AVX
-
-#define GGML_F16_STEP 64
-#define GGML_F16_EPR  16
-
-// AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead
-
-#define GGML_F32Cx16             __m512
-#define GGML_F32Cx16_ZERO        _mm512_setzero_ps()
-#define GGML_F32Cx16_SET1(x)     _mm512_set1_ps(x)
-
-// unlike  _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
-// so F16C guard isn't required
-#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
-#define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
-
-#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
-#define GGML_F32Cx16_ADD         _mm512_add_ps
-#define GGML_F32Cx16_MUL         _mm512_mul_ps
-#define GGML_F32Cx16_REDUCE(res, x)                               \
-do {                                                              \
-    int offset = GGML_F32_ARR >> 1;                               \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    res = _mm512_reduce_add_ps(x[0]);                             \
-} while (0)
-
-#define GGML_F16_VEC                GGML_F32Cx16
-#define GGML_F16_VEC_ZERO           GGML_F32Cx16_ZERO
-#define GGML_F16_VEC_SET1           GGML_F32Cx16_SET1
-#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx16_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i])
-#define GGML_F16_VEC_FMA            GGML_F32Cx16_FMA
-#define GGML_F16_VEC_ADD            GGML_F32Cx16_ADD
-#define GGML_F16_VEC_MUL            GGML_F32Cx16_MUL
-#define GGML_F16_VEC_REDUCE         GGML_F32Cx16_REDUCE
-
-#elif defined(__AVX__)
-
-#define GGML_SIMD
-
-// F32 AVX
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR  8
-
-#define GGML_F32x8         __m256
-#define GGML_F32x8_ZERO    _mm256_setzero_ps()
-#define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
-#define GGML_F32x8_LOAD    _mm256_loadu_ps
-#define GGML_F32x8_STORE   _mm256_storeu_ps
-#if defined(__FMA__)
-    #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
-#else
-    #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
-#endif
-#define GGML_F32x8_ADD     _mm256_add_ps
-#define GGML_F32x8_MUL     _mm256_mul_ps
-#define GGML_F32x8_REDUCE(res, x)                                 \
-do {                                                              \
-    int offset = GGML_F32_ARR >> 1;                               \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
-    }                                                             \
-    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]),    \
-                                 _mm256_extractf128_ps(x[0], 1)); \
-    const __m128 t1 = _mm_hadd_ps(t0, t0);                        \
-    res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t1, t1));        \
-} while (0)
-// TODO: is this optimal ?
-
-#define GGML_F32_VEC        GGML_F32x8
-#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
-
-// F16 AVX
-
-#define GGML_F16_STEP 32
-#define GGML_F16_EPR  8
-
-// F16 arithmetic is not supported by AVX, so we use F32 instead
-
-#define GGML_F32Cx8             __m256
-#define GGML_F32Cx8_ZERO        _mm256_setzero_ps()
-#define GGML_F32Cx8_SET1(x)     _mm256_set1_ps(x)
-
-#if defined(__F16C__)
-// the  _mm256_cvt intrinsics require F16C
-#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
-#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
-#else
-static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
-    float tmp[8];
-
-    for (int i = 0; i < 8; i++) {
-        tmp[i] = GGML_FP16_TO_FP32(x[i]);
-    }
-
-    return _mm256_loadu_ps(tmp);
-}
-static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
-    float arr[8];
-
-    _mm256_storeu_ps(arr, y);
-
-    for (int i = 0; i < 8; i++)
-        x[i] = GGML_FP32_TO_FP16(arr[i]);
-}
-#define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x)
-#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
-#endif
-
-#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
-#define GGML_F32Cx8_ADD         _mm256_add_ps
-#define GGML_F32Cx8_MUL         _mm256_mul_ps
-#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
-
-#define GGML_F16_VEC                GGML_F32Cx8
-#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
-#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
-#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
-#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
-#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
-#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
-#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
-
-#elif defined(__POWER9_VECTOR__)
-
-#define GGML_SIMD
-
-// F32 POWER9
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4              vector float
-#define GGML_F32x4_ZERO         0.0f
-#define GGML_F32x4_SET1         vec_splats
-#define GGML_F32x4_LOAD(p)      vec_xl(0, p)
-#define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)
-#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
-#define GGML_F32x4_ADD          vec_add
-#define GGML_F32x4_MUL          vec_mul
-#define GGML_F32x4_REDUCE(res, x)              \
-{                                              \
-    int offset = GGML_F32_ARR >> 1;            \
-    for (int i = 0; i < offset; ++i) {         \
-        x[i] = vec_add(x[i], x[offset+i]);     \
-    }                                          \
-    offset >>= 1;                              \
-    for (int i = 0; i < offset; ++i) {         \
-        x[i] = vec_add(x[i], x[offset+i]);     \
-    }                                          \
-    offset >>= 1;                              \
-    for (int i = 0; i < offset; ++i) {         \
-        x[i] = vec_add(x[i], x[offset+i]);     \
-    }                                          \
-    res = vec_extract(x[0], 0) +               \
-          vec_extract(x[0], 1) +               \
-          vec_extract(x[0], 2) +               \
-          vec_extract(x[0], 3);                \
-}
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 POWER9
-#define GGML_F16_STEP       GGML_F32_STEP
-#define GGML_F16_EPR        GGML_F32_EPR
-#define GGML_F16_VEC        GGML_F32x4
-#define GGML_F16_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F16_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F16_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F16_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F16_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
-// Use vec_xl, not vec_ld, in case the load address is not aligned.
-#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ?                   \
-  vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
-  vec_extract_fp32_from_shortl(vec_xl(0, p))
-#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
-#define GGML_F16_VEC_STORE(p, r, i)                             \
-  if (i & 0x1)                                                  \
-    vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)],  \
-                                   r[i - GGML_ENDIAN_BYTE(0)]), \
-            0, p - GGML_F16_EPR)
-
-#elif defined(__wasm_simd128__)
-
-#define GGML_SIMD
-
-// F32 WASM
-
-#define GGML_F32_STEP 16
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4              v128_t
-#define GGML_F32x4_ZERO         wasm_f32x4_splat(0.0f)
-#define GGML_F32x4_SET1(x)      wasm_f32x4_splat(x)
-#define GGML_F32x4_LOAD         wasm_v128_load
-#define GGML_F32x4_STORE        wasm_v128_store
-#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
-#define GGML_F32x4_ADD          wasm_f32x4_add
-#define GGML_F32x4_MUL          wasm_f32x4_mul
-#define GGML_F32x4_REDUCE(res, x)                  \
-{                                                  \
-    int offset = GGML_F32_ARR >> 1;                \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    offset >>= 1;                                  \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    offset >>= 1;                                  \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    res = wasm_f32x4_extract_lane(x[0], 0) +       \
-          wasm_f32x4_extract_lane(x[0], 1) +       \
-          wasm_f32x4_extract_lane(x[0], 2) +       \
-          wasm_f32x4_extract_lane(x[0], 3);        \
-}
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 WASM
-
-#define GGML_F16_STEP 16
-#define GGML_F16_EPR  4
-
-inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
-    float tmp[4];
-
-    tmp[0] = GGML_FP16_TO_FP32(p[0]);
-    tmp[1] = GGML_FP16_TO_FP32(p[1]);
-    tmp[2] = GGML_FP16_TO_FP32(p[2]);
-    tmp[3] = GGML_FP16_TO_FP32(p[3]);
-
-    return wasm_v128_load(tmp);
-}
-
-inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
-    float tmp[4];
-
-    wasm_v128_store(tmp, x);
-
-    p[0] = GGML_FP32_TO_FP16(tmp[0]);
-    p[1] = GGML_FP32_TO_FP16(tmp[1]);
-    p[2] = GGML_FP32_TO_FP16(tmp[2]);
-    p[3] = GGML_FP32_TO_FP16(tmp[3]);
-}
-
-#define GGML_F16x4             v128_t
-#define GGML_F16x4_ZERO        wasm_f32x4_splat(0.0f)
-#define GGML_F16x4_SET1(x)     wasm_f32x4_splat(x)
-#define GGML_F16x4_LOAD(x)     __wasm_f16x4_load(x)
-#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
-#define GGML_F16x4_FMA         GGML_F32x4_FMA
-#define GGML_F16x4_ADD         wasm_f32x4_add
-#define GGML_F16x4_MUL         wasm_f32x4_mul
-#define GGML_F16x4_REDUCE(res, x)                  \
-{                                                  \
-    int offset = GGML_F16_ARR >> 1;                \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    offset >>= 1;                                  \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    offset >>= 1;                                  \
-    for (int i = 0; i < offset; ++i) {             \
-        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
-    }                                              \
-    res = wasm_f32x4_extract_lane(x[0], 0) +       \
-          wasm_f32x4_extract_lane(x[0], 1) +       \
-          wasm_f32x4_extract_lane(x[0], 2) +       \
-          wasm_f32x4_extract_lane(x[0], 3);        \
-}
-
-#define GGML_F16_VEC                GGML_F16x4
-#define GGML_F16_VEC_ZERO           GGML_F16x4_ZERO
-#define GGML_F16_VEC_SET1           GGML_F16x4_SET1
-#define GGML_F16_VEC_LOAD(p, i)     GGML_F16x4_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
-#define GGML_F16_VEC_FMA            GGML_F16x4_FMA
-#define GGML_F16_VEC_ADD            GGML_F16x4_ADD
-#define GGML_F16_VEC_MUL            GGML_F16x4_MUL
-#define GGML_F16_VEC_REDUCE         GGML_F16x4_REDUCE
-
-#elif defined(__SSE3__)
-
-#define GGML_SIMD
-
-// F32 SSE
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4         __m128
-#define GGML_F32x4_ZERO    _mm_setzero_ps()
-#define GGML_F32x4_SET1(x) _mm_set1_ps(x)
-#define GGML_F32x4_LOAD    _mm_loadu_ps
-#define GGML_F32x4_STORE   _mm_storeu_ps
-#if defined(__FMA__)
-    // TODO: Does this work?
-    #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
-#else
-    #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
-#endif
-#define GGML_F32x4_ADD     _mm_add_ps
-#define GGML_F32x4_MUL     _mm_mul_ps
-#define GGML_F32x4_REDUCE(res, x)                                 \
-{                                                                 \
-    int offset = GGML_F32_ARR >> 1;                               \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
-    }                                                             \
-    const __m128 t0 = _mm_hadd_ps(x[0], x[0]);                    \
-    res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t0, t0));        \
-}
-// TODO: is this optimal ?
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 SSE
-
-#define GGML_F16_STEP 32
-#define GGML_F16_EPR  4
-
-static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
-    float tmp[4];
-
-    tmp[0] = GGML_FP16_TO_FP32(x[0]);
-    tmp[1] = GGML_FP16_TO_FP32(x[1]);
-    tmp[2] = GGML_FP16_TO_FP32(x[2]);
-    tmp[3] = GGML_FP16_TO_FP32(x[3]);
-
-    return _mm_loadu_ps(tmp);
-}
-
-static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
-    float arr[4];
-
-    _mm_storeu_ps(arr, y);
-
-    x[0] = GGML_FP32_TO_FP16(arr[0]);
-    x[1] = GGML_FP32_TO_FP16(arr[1]);
-    x[2] = GGML_FP32_TO_FP16(arr[2]);
-    x[3] = GGML_FP32_TO_FP16(arr[3]);
-}
-
-#define GGML_F32Cx4             __m128
-#define GGML_F32Cx4_ZERO        _mm_setzero_ps()
-#define GGML_F32Cx4_SET1(x)     _mm_set1_ps(x)
-#define GGML_F32Cx4_LOAD(x)     __sse_f16x4_load(x)
-#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
-#define GGML_F32Cx4_FMA         GGML_F32x4_FMA
-#define GGML_F32Cx4_ADD         _mm_add_ps
-#define GGML_F32Cx4_MUL         _mm_mul_ps
-#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE
-
-#define GGML_F16_VEC                 GGML_F32Cx4
-#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
-#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
-#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
-#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
-#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
-#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
-#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
-
-#elif defined(__loongarch_asx)
-
-#define GGML_SIMD
-
-// F32 LASX
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR  8
-
-#define GGML_F32x8         __m256
-#define GGML_F32x8_ZERO    (__m256)__lasx_xvldi(0)
-#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
-#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
-#define GGML_F32x8_STORE(x,y)   __lasx_xvst((y), (x), 0)
-#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
-#define GGML_F32x8_ADD     __lasx_xvfadd_s
-#define GGML_F32x8_MUL     __lasx_xvfmul_s
-#define GGML_F32x8_REDUCE(res, x)                                 \
-do {                                                              \
-    int offset = GGML_F32_ARR >> 1;                               \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
-    }                                                             \
-    float *tmp_p = (float *)&x[0]; \
-    res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7];  \
-} while (0)
-// TODO: is this optimal ?
-
-#define GGML_F32_VEC        GGML_F32x8
-#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
-
-// F16 LASX
-
-#define GGML_F16_STEP 32
-#define GGML_F16_EPR  8
-
-// F16 arithmetic is not supported by AVX, so we use F32 instead
-
-#define GGML_F32Cx8          __m256
-#define GGML_F32Cx8_ZERO    (__m256)__lasx_xvldi(0)
-#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
-
-static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
-    float tmp[8];
-
-    for (int i = 0; i < 8; i++) {
-        tmp[i] = GGML_FP16_TO_FP32(x[i]);
-    }
-
-    return (__m256)__lasx_xvld(tmp, 0);
-}
-static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
-    float arr[8];
-
-    __lasx_xvst(y, arr, 0);
-
-    for (int i = 0; i < 8; i++) {
-        x[i] = GGML_FP32_TO_FP16(arr[i]);
-    }
-}
-#define GGML_F32Cx8_LOAD(x)     __lasx_f32cx8_load(x)
-#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
-
-#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
-#define GGML_F32Cx8_ADD         __lasx_xvfadd_s
-#define GGML_F32Cx8_MUL         __lasx_xvfmul_s
-#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
-
-#define GGML_F16_VEC                GGML_F32Cx8
-#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
-#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
-#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
-#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
-#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
-#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
-#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
-
-#elif defined(__loongarch_sx)
-
-#define GGML_SIMD
-
-// F32 LSX
-
-#define GGML_F32_STEP 32
-#define GGML_F32_EPR  4
-
-#define GGML_F32x4         __m128
-#define GGML_F32x4_ZERO    __lsx_vldi(0)
-#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
-#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
-#define GGML_F32x4_STORE((x),(y))   __lsx_vst((y), (x), 0)
-#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
-#define GGML_F32x4_ADD     __lsx_vfadd_s
-#define GGML_F32x4_MUL     __lsx_vfmul_s
-#define GGML_F32x4_REDUCE(res, x)                                 \
-{                                                                 \
-    int offset = GGML_F32_ARR >> 1;                               \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = __lsx_vfadd_s(x[i], x[offset+i]);                     \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = __lsx_vfadd_s(x[i], x[offset+i]);                     \
-    }                                                             \
-    offset >>= 1;                                                 \
-    for (int i = 0; i < offset; ++i) {                            \
-        x[i] = __lsx_vfadd_s(x[i], x[offset+i]);                     \
-    }                                                             \
-    __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
-    tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
-    tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
-    const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
-    tmp = __lsx_vsrli_d((__m128i)t0, 32); \
-    tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
-    tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
-    res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0);        \
-}
-
-#define GGML_F32_VEC        GGML_F32x4
-#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
-#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
-#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
-#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
-#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
-#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
-#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
-#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
-
-// F16 LSX
-
-#define GGML_F16_STEP 32
-#define GGML_F16_EPR  4
-
-static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
-    float tmp[4];
-
-    tmp[0] = GGML_FP16_TO_FP32(x[0]);
-    tmp[1] = GGML_FP16_TO_FP32(x[1]);
-    tmp[2] = GGML_FP16_TO_FP32(x[2]);
-    tmp[3] = GGML_FP16_TO_FP32(x[3]);
-
-    return __lsx_vld(tmp, 0);
-}
-
-static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
-    float arr[4];
-
-    __lsx_vst(y, arr, 0);
-
-    x[0] = GGML_FP32_TO_FP16(arr[0]);
-    x[1] = GGML_FP32_TO_FP16(arr[1]);
-    x[2] = GGML_FP32_TO_FP16(arr[2]);
-    x[3] = GGML_FP32_TO_FP16(arr[3]);
-}
-
-#define GGML_F32Cx4             __m128
-#define GGML_F32Cx4_ZERO        __lsx_vldi(0)
-#define GGML_F32Cx4_SET1(x)     __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
-#define GGML_F32Cx4_LOAD(x)     __lsx_f16x4_load(x)
-#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
-#define GGML_F32Cx4_FMA         GGML_F32x4_FMA
-#define GGML_F32Cx4_ADD         __lsx_vfadd_s
-#define GGML_F32Cx4_MUL         __lsx_vfmul_s
-#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE
-
-#define GGML_F16_VEC                 GGML_F32Cx4
-#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
-#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
-#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
-#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
-#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
-#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
-#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
-
-#endif
-
-// GGML_F32_ARR / GGML_F16_ARR
-//   number of registers to use per step
-#ifdef GGML_SIMD
-#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
-#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
-#endif
-
-//
-// Threading defs
-//
-
-typedef pthread_t          ggml_thread_t;
-
-#if defined(_WIN32)
-
-typedef CONDITION_VARIABLE ggml_cond_t;
-typedef SRWLOCK            ggml_mutex_t;
-
-#define ggml_mutex_init(m)   InitializeSRWLock(m)
-#define ggml_mutex_destroy(m)
-#define ggml_mutex_lock(m)   AcquireSRWLockExclusive(m)
-#define ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m)
-#define ggml_mutex_lock_shared(m)   AcquireSRWLockShared(m)
-#define ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m)
-
-#define ggml_cond_init(c)    InitializeConditionVariable(c)
-#define ggml_cond_destroy(c)
-#define ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED)
-#define ggml_cond_broadcast(c) WakeAllConditionVariable(c)
-
-#define ggml_thread_create pthread_create
-#define ggml_thread_join   pthread_join
-
-#else
-
-typedef pthread_cond_t     ggml_cond_t;
-typedef pthread_mutex_t    ggml_mutex_t;
-
-#define ggml_mutex_init(m)          pthread_mutex_init(m, NULL)
-#define ggml_mutex_destroy(m)       pthread_mutex_destroy(m)
-#define ggml_mutex_lock(m)          pthread_mutex_lock(m)
-#define ggml_mutex_unlock(m)        pthread_mutex_unlock(m)
-#define ggml_mutex_lock_shared(m)   pthread_mutex_lock(m)
-#define ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m)
-
-#define ggml_lock_init(x)    UNUSED(x)
-#define ggml_lock_destroy(x) UNUSED(x)
-#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
-#define ggml_lock_lock(x)    _mm_pause()
-#else
-#define ggml_lock_lock(x)    UNUSED(x)
-#endif
-#define ggml_lock_unlock(x)  UNUSED(x)
-
-#define GGML_LOCK_INITIALIZER 0
-#define ggml_cond_init(c)      pthread_cond_init(c, NULL)
-#define ggml_cond_destroy(c)   pthread_cond_destroy(c)
-#define ggml_cond_wait(c, m)   pthread_cond_wait(c, m)
-#define ggml_cond_broadcast(c) pthread_cond_broadcast(c)
-
-#define ggml_thread_create pthread_create
-#define ggml_thread_join   pthread_join
-
-#endif
-
-// Threadpool def
-struct ggml_threadpool {
-    ggml_mutex_t mutex;       // mutex for cond.var
-    ggml_cond_t  cond;        // cond.var for waiting for new work
-
-    struct ggml_cgraph * cgraph;
-    struct ggml_cplan  * cplan;
-
-    // synchronization primitives
-    atomic_int n_graph;       // incremented when there is work to be done (i.e each graph)
-    atomic_int GGML_CACHE_ALIGN n_barrier;
-    atomic_int GGML_CACHE_ALIGN n_barrier_passed;
-    atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
-
-    // these are atomic as an annotation for thread-sanitizer
-    atomic_bool stop;         // Used for stopping the threadpool altogether
-    atomic_bool pause;        // Used for pausing the threadpool or individual threads
-    atomic_bool abort;        // Used for aborting processing of a graph
-
-    struct ggml_compute_state * workers;   // per thread state
-    int          n_threads_max; // number of threads in the pool
-    atomic_int   n_threads_cur; // number of threads used in the current graph
-
-    int32_t      prio;        // Scheduling priority
-    uint32_t     poll;        // Polling level (0 - no polling)
-
-    enum ggml_status ec;
-};
-
-// Per-thread state
-struct ggml_compute_state {
-#ifndef GGML_USE_OPENMP
-    ggml_thread_t thrd;
-    bool cpumask[GGML_MAX_N_THREADS];
-    int  last_graph;
-    bool pending;
-#endif
-    struct ggml_threadpool * threadpool;
-    int ith;
-};
-
-struct ggml_compute_params {
-    // ith = thread index, nth = number of threads
-    int ith, nth;
-
-    // work buffer for all threads
-    size_t wsize;
-    void * wdata;
-
-    struct ggml_threadpool * threadpool;
-};
-
-//
-// fundamental operations
-//
-
-inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-
-inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
-inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float   v) { for (int i = 0; i < n; ++i) z[i]  = x[i] + v;    }
-inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
-inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
-inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
-inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
-inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
-inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
-inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
-inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
-
-static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
-   assert(nrc == 1);
-   UNUSED(nrc);
-   UNUSED(bx);
-   UNUSED(by);
-   UNUSED(bs);
-
-#if defined(GGML_SIMD)
-    float sumf = 0.0f;
-    const int np = (n & ~(GGML_F32_STEP - 1));
-
-    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-
-    GGML_F32_VEC ax[GGML_F32_ARR];
-    GGML_F32_VEC ay[GGML_F32_ARR];
-
-    for (int i = 0; i < np; i += GGML_F32_STEP) {
-        for (int j = 0; j < GGML_F32_ARR; j++) {
-            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
-
-            sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
-        }
-    }
-
-    // reduce sum0..sum3 to sum0
-    GGML_F32_VEC_REDUCE(sumf, sum);
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        sumf += x[i]*y[i];
-    }
-#else
-    // scalar
-    ggml_float sumf = 0.0;
-    for (int i = 0; i < n; ++i) {
-        sumf += (ggml_float)(x[i]*y[i]);
-    }
-#endif
-
-    *s = sumf;
-}
-
-static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
-    UNUSED(bs);
-    int i = 0;
-    ggml_float sumf = 0;
-
-#if defined(__AVX512BF16__)
-    __m512 c1 = _mm512_setzero_ps();
-    __m512 c2 = _mm512_setzero_ps();
-    for (; i + 64 <= n; i += 64) {
-        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
-                             m512bh(_mm512_loadu_si512((y + i))));
-        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
-                             m512bh(_mm512_loadu_si512((y + i + 32))));
-    }
-    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
-    sumf += (ggml_float)_mm512_reduce_add_ps(c2);
-
-#elif defined(__AVX512F__)
-#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
-    __m512 c1 = _mm512_setzero_ps();
-    __m512 c2 = _mm512_setzero_ps();
-    for (; i + 32 <= n; i += 32) {
-        c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
-        c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
-    }
-    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
-    sumf += (ggml_float)_mm512_reduce_add_ps(c2);
-
-#undef LOAD
-#elif defined(__AVX2__)
-#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
-    __m256 c1 = _mm256_setzero_ps();
-    __m256 c2 = _mm256_setzero_ps();
-    __m256 c3 = _mm256_setzero_ps();
-    __m256 c4 = _mm256_setzero_ps();
-    for (; i + 32 <= n; i += 32) {
-        c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
-        c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
-        c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
-        c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
-    }
-    __m128 g;
-    c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
-                       _mm256_add_ps(c2, c4));
-    g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
-                   _mm256_castps256_ps128(c1));
-    g = _mm_add_ps(g, _mm_movehl_ps(g, g));
-    g = _mm_add_ss(g, _mm_movehdup_ps(g));
-    sumf += (ggml_float)_mm_cvtss_f32(g);
-
-#undef LOAD
-#endif
-
-    for (; i < n; ++i) {
-        sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
-                             GGML_BF16_TO_FP32(y[i]));
-    }
-    *s = sumf;
-}
-
-static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
-    assert(nrc == 1);
-    UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
-    UNUSED(bs);
-
-    ggml_float sumf = 0.0;
-
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
-
-    GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
-
-    GGML_F16_VEC ax[GGML_F16_ARR];
-    GGML_F16_VEC ay[GGML_F16_ARR];
-
-    for (int i = 0; i < np; i += GGML_F16_STEP) {
-        for (int j = 0; j < GGML_F16_ARR; j++) {
-            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
-            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-
-            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
-        }
-    }
-
-    // reduce sum0..sum3 to sum0
-    GGML_F16_VEC_REDUCE(sumf, sum);
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
-    }
-#else
-    for (int i = 0; i < n; ++i) {
-        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
-    }
-#endif
-
-    *s = sumf;
-}
-
-// compute GGML_VEC_DOT_UNROLL dot products at once
-// xs - x row stride in bytes
-inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
-    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
-
-    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
-
-    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
-        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
-    }
-
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
-
-    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
-
-    GGML_F16_VEC ax[GGML_F16_ARR];
-    GGML_F16_VEC ay[GGML_F16_ARR];
-
-    for (int i = 0; i < np; i += GGML_F16_STEP) {
-        for (int j = 0; j < GGML_F16_ARR; j++) {
-            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-
-            for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
-                ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
-
-                sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
-            }
-        }
-    }
-
-    // reduce sum0..sum3 to sum0
-    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
-        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
-    }
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
-        }
-    }
-#else
-    for (int i = 0; i < n; ++i) {
-        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
-        }
-    }
-#endif
-
-    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
-        s[i] = sumf[i];
-    }
-}
-
-inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F32_STEP - 1));
-
-    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
-
-    GGML_F32_VEC ax[GGML_F32_ARR];
-    GGML_F32_VEC ay[GGML_F32_ARR];
-
-    for (int i = 0; i < np; i += GGML_F32_STEP) {
-        for (int j = 0; j < GGML_F32_ARR; j++) {
-            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
-
-            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
-        }
-    }
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        y[i] += x[i]*v;
-    }
-#else
-    // scalar
-    for (int i = 0; i < n; ++i) {
-        y[i] += x[i]*v;
-    }
-#endif
-}
-
-inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
-
-    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
-
-    GGML_F16_VEC ax[GGML_F16_ARR];
-    GGML_F16_VEC ay[GGML_F16_ARR];
-
-    for (int i = 0; i < np; i += GGML_F16_STEP) {
-        for (int j = 0; j < GGML_F16_ARR; j++) {
-            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
-            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
-
-            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
-        }
-    }
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
-    }
-#else
-    // scalar
-    for (int i = 0; i < n; ++i) {
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
-    }
-#endif
-}
-
-// xs and vs are byte strides of x and v
-inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
-
-    const float * restrict x[GGML_VEC_MAD_UNROLL];
-    const float * restrict v[GGML_VEC_MAD_UNROLL];
-
-    for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
-        x[i] = (const float *) ((const char *) xv + i*xs);
-        v[i] = (const float *) ((const char *) vv + i*vs);
-    }
-
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F32_STEP - 1));
-
-    GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
-
-    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
-        vx[k] = GGML_F32_VEC_SET1(v[k][0]);
-    }
-
-    GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
-    GGML_F32_VEC ay[GGML_F32_ARR];
-
-    for (int i = 0; i < np; i += GGML_F32_STEP) {
-        for (int j = 0; j < GGML_F32_ARR; j++) {
-            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
-
-            for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
-                ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
-                ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
-            }
-
-            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
-        }
-    }
-
-    // leftovers
-    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
-        for (int i = np; i < n; ++i) {
-            y[i] += x[k][i]*v[k][0];
-        }
-    }
-#else
-    // scalar
-    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
-        for (int i = 0; i < n; ++i) {
-            y[i] += x[k][i]*v[k][0];
-        }
-    }
-#endif
-}
-
-//inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
-inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
-#if defined(GGML_USE_ACCELERATE)
-    vDSP_vsmul(y, 1, &v, y, 1, n);
-#elif defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F32_STEP - 1));
-
-    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
-
-    GGML_F32_VEC ay[GGML_F32_ARR];
-
-    for (int i = 0; i < np; i += GGML_F32_STEP) {
-        for (int j = 0; j < GGML_F32_ARR; j++) {
-            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
-            ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
-
-            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
-        }
-    }
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        y[i] *= v;
-    }
-#else
-    // scalar
-    for (int i = 0; i < n; ++i) {
-        y[i] *= v;
-    }
-#endif
-}
-
-inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
-#if defined(GGML_SIMD)
-    const int np = (n & ~(GGML_F16_STEP - 1));
-
-    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
-
-    GGML_F16_VEC ay[GGML_F16_ARR];
-
-    for (int i = 0; i < np; i += GGML_F16_STEP) {
-        for (int j = 0; j < GGML_F16_ARR; j++) {
-            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
-
-            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
-        }
-    }
-
-    // leftovers
-    for (int i = np; i < n; ++i) {
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
-    }
-#else
-    // scalar
-    for (int i = 0; i < n; ++i) {
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
-    }
-#endif
-}
-
-inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   }
-inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
-inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
-inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);  }
-inline static void ggml_vec_sin_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]);  }
-inline static void ggml_vec_cos_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]);  }
-inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
-inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
-inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
-inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]);  }
-inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
-inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
-inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
-inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
-// TODO: optimize performance
-inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
-inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
-inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
-
-static const float GELU_COEF_A     = 0.044715f;
-static const float GELU_QUICK_COEF = -1.702f;
-static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
-
-inline static float ggml_gelu_f32(float x) {
-    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
-    const uint16_t * i16 = (const uint16_t *) x;
-    for (int i = 0; i < n; ++i) {
-        y[i] = ggml_table_gelu_f16[i16[i]];
-    }
-}
-
-#ifdef GGML_GELU_FP16
-inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
-    uint16_t t;
-    for (int i = 0; i < n; ++i) {
-        if (x[i] <= -10.0f) {
-            y[i] = 0.0f;
-        } else if (x[i] >= 10.0f) {
-            y[i] = x[i];
-        } else {
-            ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
-            memcpy(&t, &fp16, sizeof(uint16_t));
-            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
-        }
-    }
-}
-#else
-inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
-    for (int i = 0; i < n; ++i) {
-        y[i] = ggml_gelu_f32(x[i]);
-    }
-}
-#endif
-
-inline static float ggml_gelu_quick_f32(float x) {
-    return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
-}
-
-//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
-//    const uint16_t * i16 = (const uint16_t *) x;
-//    for (int i = 0; i < n; ++i) {
-//        y[i] = ggml_table_gelu_quick_f16[i16[i]];
-//    }
-//}
-
-#ifdef GGML_GELU_QUICK_FP16
-inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
-    uint16_t t;
-    for (int i = 0; i < n; ++i) {
-        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
-        memcpy(&t, &fp16, sizeof(uint16_t));
-        y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);
-    }
-}
-#else
-inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
-    for (int i = 0; i < n; ++i) {
-        y[i] = ggml_gelu_quick_f32(x[i]);
-    }
-}
-#endif
-
-// Sigmoid Linear Unit (SiLU) function
-inline static float ggml_silu_f32(float x) {
-    return x/(1.0f + expf(-x));
-}
-
-#if __FINITE_MATH_ONLY__
-#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
-#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461"
-#endif
-
-#if defined(__ARM_NEON) && defined(__aarch64__)
-
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
-inline static float32x4_t ggml_v_expf(float32x4_t x) {
-    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
-    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
-    const float32x4_t n = vsubq_f32(z, r);
-    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
-                                    vdupq_n_f32(0x1.7f7d1cp-20f));
-    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
-    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
-    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
-    const float32x4_t u = vmulq_f32(b, b);
-    const float32x4_t j = vfmaq_f32(
-        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
-        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
-                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
-    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
-        return vfmaq_f32(k, j, k);
-    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
-    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
-    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
-    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
-                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
-}
-
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static float32x4_t ggml_v_silu(float32x4_t x) {
-    const float32x4_t one = vdupq_n_f32(1.0f);
-    const float32x4_t zero = vdupq_n_f32(0.0f);
-    const float32x4_t neg_x = vsubq_f32(zero, x);
-    const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
-    const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
-    return vdivq_f32(x, one_plus_exp_neg_x);
-}
-
-#elif defined(__AVX512F__) && defined(__AVX512DQ__)
-
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
-inline static __m512 ggml_v_expf(__m512 x) {
-  const __m512 r = _mm512_set1_ps(0x1.8p23f);
-  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
-  const __m512 n = _mm512_sub_ps(z, r);
-  const __m512 b =
-      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
-                       _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
-  const __mmask16 d =
-      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
-  const __m512 u = _mm512_mul_ps(b, b);
-  const __m512 j = _mm512_fmadd_ps(
-      _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
-                                      _mm512_set1_ps(0x1.573e2ep-5f)),
-                      u,
-                      _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
-                                      _mm512_set1_ps(0x1.fffdb6p-2f))),
-      u,
-      _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
-  const __m512 res = _mm512_scalef_ps(j, n);
-  if (_mm512_kortestz(d, d))
-    return res;
-  const __m512 zero = _mm512_setzero_ps();
-  const __m512 alt = _mm512_mask_blend_ps(
-      _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
-  return _mm512_mask_blend_ps(d, res, alt);
-}
-
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static __m512 ggml_v_silu(__m512 x) {
-    const __m512 one = _mm512_set1_ps(1);
-    const __m512 zero = _mm512_setzero_ps();
-    const __m512 neg_x = _mm512_sub_ps(zero, x);
-    const __m512 exp_neg_x = ggml_v_expf(neg_x);
-    const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
-    return _mm512_div_ps(x, one_plus_exp_neg_x);
-}
-
-#elif defined(__AVX2__) && defined(__FMA__)
-
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
-inline static __m256 ggml_v_expf(__m256 x) {
-  const __m256 r = _mm256_set1_ps(0x1.8p23f);
-  const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
-  const __m256 n = _mm256_sub_ps(z, r);
-  const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
-                                    _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
-  const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
-  const __m256 k = _mm256_castsi256_ps(
-      _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
-  const __m256i c = _mm256_castps_si256(
-      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
-                    _mm256_set1_ps(126), _CMP_GT_OQ));
-  const __m256 u = _mm256_mul_ps(b, b);
-  const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
-                                                                   _mm256_set1_ps(0x1.573e2ep-5f)), u,
-                                                   _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
-                                                                   _mm256_set1_ps(0x1.fffdb6p-2f))),
-                                   u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
-  if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
-    return _mm256_fmadd_ps(j, k, k);
-  const __m256i g = _mm256_and_si256(
-      _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
-      _mm256_set1_epi32(0x82000000u));
-  const __m256 s1 =
-      _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
-  const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
-  const __m256i d = _mm256_castps_si256(
-      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
-                    _mm256_set1_ps(192), _CMP_GT_OQ));
-  return _mm256_or_ps(
-      _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
-      _mm256_andnot_ps(
-          _mm256_castsi256_ps(d),
-          _mm256_or_ps(
-              _mm256_and_ps(_mm256_castsi256_ps(c),
-                            _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
-              _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
-}
-
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static __m256 ggml_v_silu(__m256 x) {
-    const __m256 one = _mm256_set1_ps(1);
-    const __m256 zero = _mm256_setzero_ps();
-    const __m256 neg_x = _mm256_sub_ps(zero, x);
-    const __m256 exp_neg_x = ggml_v_expf(neg_x);
-    const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
-    return _mm256_div_ps(x, one_plus_exp_neg_x);
-}
-
-#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
-
-#if defined(__FMA__)
-#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
-#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
-#else
-#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
-#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
-#endif
-
-// adapted from arm limited optimized routine
-// the maximum error is 1.45358 plus 0.5 ulps
-// numbers above 88.38 will flush to infinity
-// numbers beneath -103.97 will flush to zero
-inline static __m128 ggml_v_expf(__m128 x) {
-    const __m128 r = _mm_set1_ps(0x1.8p23f);
-    const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
-    const __m128 n = _mm_sub_ps(z, r);
-    const __m128 b =
-        NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
-    const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
-    const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
-    const __m128i c =
-        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
-    const __m128 u = _mm_mul_ps(b, b);
-    const __m128 j =
-        MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
-                        MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
-                u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
-    if (!_mm_movemask_epi8(c))
-        return MADD128(j, k, k);
-    const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
-                                    _mm_set1_epi32(0x82000000u));
-    const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
-    const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
-    const __m128i d =
-        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
-    return _mm_or_ps(
-        _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
-        _mm_andnot_ps(_mm_castsi128_ps(d),
-                      _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
-                                _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
-}
-
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static __m128 ggml_v_silu(__m128 x) {
-    const __m128 one = _mm_set1_ps(1);
-    const __m128 zero = _mm_setzero_ps();
-    const __m128 neg_x = _mm_sub_ps(zero, x);
-    const __m128 exp_neg_x = ggml_v_expf(neg_x);
-    const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
-    return _mm_div_ps(x, one_plus_exp_neg_x);
-}
-
-#endif // __ARM_NEON / __AVX2__ / __SSE2__
-
-static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
-    int i = 0;
-#if defined(__AVX512F__) && defined(__AVX512DQ__)
-    for (; i + 15 < n; i += 16) {
-        _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
-    }
-#elif defined(__AVX2__) && defined(__FMA__)
-    for (; i + 7 < n; i += 8) {
-        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
-    }
-#elif defined(__SSE2__)
-    for (; i + 3 < n; i += 4) {
-        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
-    }
-#elif defined(__ARM_NEON) && defined(__aarch64__)
-    for (; i + 3 < n; i += 4) {
-        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
-    }
-#endif
-    for (; i < n; ++i) {
-        y[i] = ggml_silu_f32(x[i]);
-    }
-}
-
-static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
-    int i = 0;
-    ggml_float sum = 0;
-#if defined(__AVX512F__) && defined(__AVX512DQ__)
-    for (; i + 15 < n; i += 16) {
-        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
-                                               _mm512_set1_ps(max)));
-        _mm512_storeu_ps(y + i, val);
-        sum += (ggml_float)_mm512_reduce_add_ps(val);
-    }
-#elif defined(__AVX2__) && defined(__FMA__)
-    for (; i + 7 < n; i += 8) {
-        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
-                                               _mm256_set1_ps(max)));
-        _mm256_storeu_ps(y + i, val);
-        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
-                                 _mm256_castps256_ps128(val));
-        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
-        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
-        sum += (ggml_float)_mm_cvtss_f32(val2);
-    }
-#elif defined(__SSE2__)
-    for (; i + 3 < n; i += 4) {
-        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
-                                            _mm_set1_ps(max)));
-        _mm_storeu_ps(y + i, val);
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
-        val = _mm_add_ss(val, _mm_movehdup_ps(val));
-#else
-        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
-        val = _mm_add_ps(val, tmp);
-        tmp = _mm_movehl_ps(tmp, val);
-        val = _mm_add_ss(val, tmp);
-#endif
-        sum += (ggml_float)_mm_cvtss_f32(val);
-    }
-#elif defined(__ARM_NEON) && defined(__aarch64__)
-    for (; i + 3 < n; i += 4) {
-        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
-                                                vdupq_n_f32(max)));
-        vst1q_f32(y + i, val);
-        sum += (ggml_float)vaddvq_f32(val);
-    }
-#endif
-    for (; i < n; ++i) {
-        float val = expf(x[i] - max);
-        sum += (ggml_float)val;
-        y[i] = val;
-    }
-    return sum;
-}
-
-static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
-    // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i)
-
-    int i = 0;
-    ggml_float sum = 0;
-    for (; i < n; ++i) {
-        float val = x[i] - max;
-        y[i] = val;
-        sum += (ggml_float)expf(val);
-    }
-    return sum = (ggml_float)logf(sum);
-}
-
-inline static float ggml_silu_backward_f32(float x, float dy) {
-    const float s = 1.0f/(1.0f + expf(-x));
-    return dy*s*(1.0f + x*(1.0f - s));
-}
-
-inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
-    for (int i = 0; i < n; ++i) {
-        dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
-    }
-}
-
-inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
-#ifndef GGML_USE_ACCELERATE
-    ggml_float sum = 0.0;
-    for (int i = 0; i < n; ++i) {
-        sum += (ggml_float)x[i];
-    }
-    *s = sum;
-#else
-    vDSP_sve(x, 1, s, n);
-#endif
-}
-
-inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
-    ggml_float sum = 0.0;
-    for (int i = 0; i < n; ++i) {
-        sum += (ggml_float)x[i];
-    }
-    *s = sum;
-}
-
-inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) {
-    float sum = 0.0f;
-    for (int i = 0; i < n; ++i) {
-        sum += GGML_FP16_TO_FP32(x[i]);
-    }
-    *s = sum;
-}
-
-inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
-    float sum = 0.0f;
-    for (int i = 0; i < n; ++i) {
-        sum += GGML_BF16_TO_FP32(x[i]);
-    }
-    *s = sum;
-}
-
-inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
-#ifndef GGML_USE_ACCELERATE
-    float max = -INFINITY;
-    for (int i = 0; i < n; ++i) {
-        max = MAX(max, x[i]);
-    }
-    *s = max;
-#else
-    vDSP_maxv(x, 1, s, n);
-#endif
-}
-
-inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
-    ggml_vec_norm_f32(n, s, x);
-    *s = 1.f/(*s);
-}
-
-inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
-    float max = -INFINITY;
-    int idx = 0;
-    for (int i = 0; i < n; ++i) {
-        max = MAX(max, x[i]);
-        if (max == x[i]) { idx = i; }
-    }
-    *s = idx;
-}
-
-// Helpers for polling loops
-#if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) )
-static inline void ggml_thread_cpu_relax(void) {
-    __asm__ volatile("yield" ::: "memory");
-}
-#elif defined(__x86_64__)
-static inline void ggml_thread_cpu_relax(void) {
-    _mm_pause();
-}
-#else
-static inline void ggml_thread_cpu_relax(void) {;}
-#endif
-
-//
-// NUMA support
-//
-
-#define GGML_NUMA_MAX_NODES 8
-#define GGML_NUMA_MAX_CPUS 512
-
-struct ggml_numa_node {
-    uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node
-    uint32_t n_cpus;
-};
-
-struct ggml_numa_nodes {
-    enum ggml_numa_strategy numa_strategy;
-    struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES];
-    uint32_t n_nodes;
-    uint32_t total_cpus; // hardware threads on system
-    uint32_t current_node; // node on which main process is execting
-#if defined(__gnu_linux__)
-    cpu_set_t cpuset; // cpuset from numactl
-#else
-    uint32_t cpuset; // no NUMA support outside of Linux at this time. Use a portable datatype
-#endif
-};
-
-//
-// ggml state
-//
-
-struct ggml_state {
-    struct ggml_numa_nodes numa;
-};
-
-// global state
-static struct ggml_state g_state = {0};
-static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
-
-// TODO: move to threading file
-// critical section via spin lock
-void ggml_critical_section_start(void) {
-    while (atomic_flag_test_and_set(&g_state_critical)) {
-        // spin
-        sched_yield();
-    }
-}
-
-void ggml_critical_section_end(void) {
-    atomic_flag_clear(&g_state_critical);
-}
-
-static void ggml_barrier(struct ggml_threadpool * tp) {
-    int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed);
-    if (n_threads == 1) {
-        return;
-    }
-
-#ifdef GGML_USE_OPENMP
-    #pragma omp barrier
-#else
-    int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed);
-
-    // enter barrier (full seq-cst fence)
-    int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst);
-
-    if (n_barrier == (n_threads - 1)) {
-        // last thread
-        atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed);
-
-        // exit barrier (fill seq-cst fence)
-        atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst);
-        return;
-    }
-
-    // wait for other threads
-    while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {
-        ggml_thread_cpu_relax();
-    }
-
-    // exit barrier (full seq-cst fence)
-    // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
-    #ifdef GGML_TSAN_ENABLED
-    atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst);
-    #else
-    atomic_thread_fence(memory_order_seq_cst);
-    #endif
-#endif
-}
-
-#if defined(__gnu_linux__)
-static cpu_set_t ggml_get_numa_affinity(void) {
-    cpu_set_t cpuset;
-    pthread_t thread;
-    thread = pthread_self();
-    CPU_ZERO(&cpuset);
-    pthread_getaffinity_np(thread, sizeof(cpu_set_t), &cpuset);
-    return cpuset;
-}
-#else
-static uint32_t ggml_get_numa_affinity(void) {
-    return 0; // no NUMA support
-}
-#endif
-
-void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
-    if (g_state.numa.n_nodes > 0) {
-        fprintf(stderr, "ggml_numa_init: NUMA already initialized\n");
-
-        return;
-    }
-
-#if defined(__gnu_linux__)
-    struct stat st;
-    char path[256];
-    int rv;
-
-    // set numa scheme
-    g_state.numa.numa_strategy = numa_flag;
-
-    GGML_PRINT_DEBUG("numa strategy %u\n",g_state.numa.numa_strategy);
-
-    g_state.numa.cpuset = ggml_get_numa_affinity();
-
-    // enumerate nodes
-    while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) {
-        rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes);
-        GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
-        if (stat(path, &st) != 0) { break; }
-        ++g_state.numa.n_nodes;
-    }
-
-    // enumerate CPUs
-    while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) {
-        rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus);
-        GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
-        if (stat(path, &st) != 0) { break; }
-        ++g_state.numa.total_cpus;
-    }
-
-    GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus);
-
-    // figure out which node we're on
-    uint current_cpu;
-    int getcpu_ret = 0;
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
-    getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
-#else
-    // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
-#   if !defined(SYS_getcpu) && defined(SYS_get_cpu)
-#       define SYS_getcpu SYS_get_cpu // some older glibc versions use this name
-#   endif
-    getcpu_ret = syscall(SYS_getcpu, &current_cpu, &g_state.numa.current_node);
-#endif
-
-    if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) {
-        g_state.numa.n_nodes = 0;
-        return;
-    }
-
-    GGML_PRINT_DEBUG("found our process on numa node %u, CPU %u\n", g_state.numa.current_node, current_cpu);
-
-    for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) {
-        struct ggml_numa_node * node = &g_state.numa.nodes[n];
-        GGML_PRINT_DEBUG("CPUs on node %u:", n);
-        node->n_cpus = 0;
-        for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) {
-            rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c);
-            GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
-            if (stat(path, &st) == 0) {
-                node->cpus[node->n_cpus++] = c;
-                GGML_PRINT_DEBUG(" %u", c);
-            }
-        }
-        GGML_PRINT_DEBUG("\n");
-    }
-
-    if (ggml_is_numa()) {
-        FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r");
-        if (fptr != NULL) {
-            char buf[42];
-            if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) {
-                GGML_LOG_WARN("/proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n");
-            }
-            fclose(fptr);
-        }
-    }
-#else
-    UNUSED(numa_flag);
-    // TODO
-#endif
-}
-
-bool ggml_is_numa(void) {
-    return g_state.numa.n_nodes > 1;
-}
-
-#if defined(__ARM_ARCH)
-
-#if defined(__linux__) && defined(__aarch64__)
-#include <sys/auxv.h>
-#elif defined(__APPLE__)
-#include <sys/sysctl.h>
-#endif
-
-#if !defined(HWCAP2_I8MM)
-#define HWCAP2_I8MM 0
-#endif
-
-static void ggml_init_arm_arch_features(void) {
-#if defined(__linux__) && defined(__aarch64__)
-    uint32_t hwcap = getauxval(AT_HWCAP);
-    uint32_t hwcap2 = getauxval(AT_HWCAP2);
-
-    ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
-    ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
-    ggml_arm_arch_features.has_sve  = !!(hwcap & HWCAP_SVE);
-
-#if defined(__ARM_FEATURE_SVE)
-    ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
-#endif
-#elif defined(__APPLE__)
-    int oldp = 0;
-    size_t size = sizeof(oldp);
-    if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) {
-        oldp = 0;
-    }
-    ggml_arm_arch_features.has_neon = oldp;
-
-    if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
-        oldp = 0;
-    }
-    ggml_arm_arch_features.has_i8mm = oldp;
-
-    ggml_arm_arch_features.has_sve = 0;
-    ggml_arm_arch_features.sve_cnt = 0;
-#else
-// Run-time CPU feature detection not implemented for this platform, fallback to compile time
-#if defined(__ARM_NEON)
-    ggml_arm_arch_features.has_neon = 1;
-#else
-    ggml_arm_arch_features.has_neon = 0;
-#endif
-
-#if defined(__ARM_FEATURE_MATMUL_INT8)
-    ggml_arm_arch_features.has_i8mm = 1;
-#else
-    ggml_arm_arch_features.has_i8mm = 0;
-#endif
-
-#if defined(__ARM_FEATURE_SVE)
-    ggml_arm_arch_features.has_sve = 1;
-    ggml_arm_arch_features.sve_cnt = 16;
-#else
-    ggml_arm_arch_features.has_sve = 0;
-    ggml_arm_arch_features.sve_cnt = 0;
-#endif
-#endif
-}
-#endif
-
-struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
-    GGML_ASSERT(!ggml_get_no_alloc(ctx));
-
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-
-    ggml_set_i32(result, value);
-
-    return result;
-}
-
-struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
-    GGML_ASSERT(!ggml_get_no_alloc(ctx));
-
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
-
-    ggml_set_f32(result, value);
-
-    return result;
-}
-
-struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
-    const int n     = ggml_nrows(tensor);
-    const int nc    = tensor->ne[0];
-    const size_t n1 = tensor->nb[1];
-
-    char * const data = tensor->data;
-
-    switch (tensor->type) {
-        case GGML_TYPE_I8:
-            {
-                assert(tensor->nb[0] == sizeof(int8_t));
-                for (int i = 0; i < n; i++) {
-                    ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
-                }
-            } break;
-        case GGML_TYPE_I16:
-            {
-                assert(tensor->nb[0] == sizeof(int16_t));
-                for (int i = 0; i < n; i++) {
-                    ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
-                }
-            } break;
-        case GGML_TYPE_I32:
-            {
-                assert(tensor->nb[0] == sizeof(int32_t));
-                for (int i = 0; i < n; i++) {
-                    ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
-                }
-            } break;
-        case GGML_TYPE_F16:
-            {
-                assert(tensor->nb[0] == sizeof(ggml_fp16_t));
-                for (int i = 0; i < n; i++) {
-                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
-                }
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                assert(tensor->nb[0] == sizeof(ggml_fp16_t));
-                for (int i = 0; i < n; i++) {
-                    ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
-                }
-            } break;
-        case GGML_TYPE_F32:
-            {
-                assert(tensor->nb[0] == sizeof(float));
-                for (int i = 0; i < n; i++) {
-                    ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
-                }
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-
-    return tensor;
-}
-
-struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
-    const int n     = ggml_nrows(tensor);
-    const int nc    = tensor->ne[0];
-    const size_t n1 = tensor->nb[1];
-
-    char * const data = tensor->data;
-
-    switch (tensor->type) {
-        case GGML_TYPE_I8:
-            {
-                assert(tensor->nb[0] == sizeof(int8_t));
-                for (int i = 0; i < n; i++) {
-                    ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
-                }
-            } break;
-        case GGML_TYPE_I16:
-            {
-                assert(tensor->nb[0] == sizeof(int16_t));
-                for (int i = 0; i < n; i++) {
-                    ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
-                }
-            } break;
-        case GGML_TYPE_I32:
-            {
-                assert(tensor->nb[0] == sizeof(int32_t));
-                for (int i = 0; i < n; i++) {
-                    ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
-                }
-            } break;
-        case GGML_TYPE_F16:
-            {
-                assert(tensor->nb[0] == sizeof(ggml_fp16_t));
-                for (int i = 0; i < n; i++) {
-                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
-                }
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                assert(tensor->nb[0] == sizeof(ggml_bf16_t));
-                for (int i = 0; i < n; i++) {
-                    ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
-                }
-            } break;
-        case GGML_TYPE_F32:
-            {
-                assert(tensor->nb[0] == sizeof(float));
-                for (int i = 0; i < n; i++) {
-                    ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
-                }
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-
-    return tensor;
-}
-
-int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
-    if (!ggml_is_contiguous(tensor)) {
-        int64_t id[4] = { 0, 0, 0, 0 };
-        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
-        return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
-    }
-    switch (tensor->type) {
-        case GGML_TYPE_I8:
-            {
-                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
-                return ((int8_t *)(tensor->data))[i];
-            }
-        case GGML_TYPE_I16:
-            {
-                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
-                return ((int16_t *)(tensor->data))[i];
-            }
-        case GGML_TYPE_I32:
-            {
-                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
-                return ((int32_t *)(tensor->data))[i];
-            }
-        case GGML_TYPE_F16:
-            {
-                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
-                return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
-            }
-        case GGML_TYPE_BF16:
-            {
-                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
-                return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
-            }
-        case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(tensor->nb[0] == sizeof(float));
-                return ((float *)(tensor->data))[i];
-            }
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
-    if (!ggml_is_contiguous(tensor)) {
-        int64_t id[4] = { 0, 0, 0, 0 };
-        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
-        ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
-        return;
-    }
-    switch (tensor->type) {
-        case GGML_TYPE_I8:
-            {
-                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
-                ((int8_t *)(tensor->data))[i] = value;
-            } break;
-        case GGML_TYPE_I16:
-            {
-                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
-                ((int16_t *)(tensor->data))[i] = value;
-            } break;
-        case GGML_TYPE_I32:
-            {
-                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
-                ((int32_t *)(tensor->data))[i] = value;
-            } break;
-        case GGML_TYPE_F16:
-            {
-                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
-                ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
-                ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(tensor->nb[0] == sizeof(float));
-                ((float *)(tensor->data))[i] = value;
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
-    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
-    switch (tensor->type) {
-        case GGML_TYPE_I8:
-            return ((int8_t *) data)[0];
-        case GGML_TYPE_I16:
-            return ((int16_t *) data)[0];
-        case GGML_TYPE_I32:
-            return ((int32_t *) data)[0];
-        case GGML_TYPE_F16:
-            return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
-        case GGML_TYPE_BF16:
-            return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
-        case GGML_TYPE_F32:
-            return ((float *) data)[0];
-        default:
-            GGML_ABORT("fatal error");
-    }
-}
-
-void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
-    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
-    switch (tensor->type) {
-        case GGML_TYPE_I8:
-            {
-                ((int8_t *)(data))[0] = value;
-            } break;
-        case GGML_TYPE_I16:
-            {
-                ((int16_t *)(data))[0] = value;
-            } break;
-        case GGML_TYPE_I32:
-            {
-                ((int32_t *)(data))[0] = value;
-            } break;
-        case GGML_TYPE_F16:
-            {
-                ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ((float *)(data))[0] = value;
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
-    if (!ggml_is_contiguous(tensor)) {
-        int64_t id[4] = { 0, 0, 0, 0 };
-        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
-        return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
-    }
-    switch (tensor->type) {
-        case GGML_TYPE_I8:
-            {
-                return ((int8_t *)(tensor->data))[i];
-            }
-        case GGML_TYPE_I16:
-            {
-                return ((int16_t *)(tensor->data))[i];
-            }
-        case GGML_TYPE_I32:
-            {
-                return ((int32_t *)(tensor->data))[i];
-            }
-        case GGML_TYPE_F16:
-            {
-                return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
-            }
-        case GGML_TYPE_BF16:
-            {
-                return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
-            }
-        case GGML_TYPE_F32:
-            {
-                return ((float *)(tensor->data))[i];
-            }
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
-    if (!ggml_is_contiguous(tensor)) {
-        int64_t id[4] = { 0, 0, 0, 0 };
-        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
-        ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
-        return;
-    }
-    switch (tensor->type) {
-        case GGML_TYPE_I8:
-            {
-                ((int8_t *)(tensor->data))[i] = value;
-            } break;
-        case GGML_TYPE_I16:
-            {
-                ((int16_t *)(tensor->data))[i] = value;
-            } break;
-        case GGML_TYPE_I32:
-            {
-                ((int32_t *)(tensor->data))[i] = value;
-            } break;
-        case GGML_TYPE_F16:
-            {
-                ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ((float *)(tensor->data))[i] = value;
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
-    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
-    switch (tensor->type) {
-        case GGML_TYPE_I8:
-            return ((int8_t *) data)[0];
-        case GGML_TYPE_I16:
-            return ((int16_t *) data)[0];
-        case GGML_TYPE_I32:
-            return ((int32_t *) data)[0];
-        case GGML_TYPE_F16:
-            return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
-        case GGML_TYPE_BF16:
-            return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
-        case GGML_TYPE_F32:
-            return ((float *) data)[0];
-        default:
-            GGML_ABORT("fatal error");
-    }
-}
-
-void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
-    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
-    switch (tensor->type) {
-        case GGML_TYPE_I8:
-            {
-                ((int8_t *)(data))[0] = value;
-            } break;
-        case GGML_TYPE_I16:
-            {
-                ((int16_t *)(data))[0] = value;
-            } break;
-        case GGML_TYPE_I32:
-            {
-                ((int32_t *)(data))[0] = value;
-            } break;
-        case GGML_TYPE_F16:
-            {
-                ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ((float *)(data))[0] = value;
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////
-
-// ggml_compute_forward_dup
-
-static void ggml_compute_forward_dup_same_cont(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-    GGML_ASSERT(src0->type == dst->type);
-
-    const size_t nb0 = ggml_type_size(src0->type);
-
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-    // parallelize by elements
-    const int ne = ggml_nelements(dst);
-    const int dr = (ne + nth - 1) / nth;
-    const int ie0 = dr * ith;
-    const int ie1 = MIN(ie0 + dr, ne);
-
-    if (ie0 < ie1) {
-        memcpy(
-            ((char *)  dst->data + ie0*nb0),
-            ((char *) src0->data + ie0*nb0),
-            (ie1 - ie0) * nb0);
-    }
-}
-
-static void ggml_compute_forward_dup_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-    // parallelize by rows
-    const int nr = ne01;
-    // number of rows per thread
-    const int dr = (nr + nth - 1) / nth;
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (src0->type == dst->type &&
-        ne00 == ne0 &&
-        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
-        // copy by rows
-        const size_t rs = ne00*nb00;
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    memcpy(
-                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
-                        rs);
-                }
-            }
-        }
-        return;
-    }
-
-    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
-    if (ggml_is_contiguous(dst)) {
-        if (nb00 == sizeof(ggml_fp16_t)) {
-            if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits(dst->type)->from_float;
-                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
-                size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
-                            }
-
-                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = *src0_ptr;
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        }
-        return;
-    }
-
-    // dst counters
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
-
-    if (dst->type == GGML_TYPE_F16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
-
-                        if (++i10 == ne00) {
-                            i10 = 0;
-                            if (++i11 == ne01) {
-                                i11 = 0;
-                                if (++i12 == ne02) {
-                                    i12 = 0;
-                                    if (++i13 == ne03) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else {
-        GGML_ABORT("fatal error"); // TODO: implement
-    }
-}
-
-static void ggml_compute_forward_dup_bf16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-    // parallelize by rows
-    const int nr = ne01;
-    // number of rows per thread
-    const int dr = (nr + nth - 1) / nth;
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (src0->type == dst->type &&
-        ne00 == ne0 &&
-        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
-        // copy by rows
-        const size_t rs = ne00*nb00;
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    memcpy(
-                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
-                        rs);
-                }
-            }
-        }
-        return;
-    }
-
-    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
-    if (ggml_is_contiguous(dst)) {
-        if (nb00 == sizeof(ggml_bf16_t)) {
-            if (dst->type == GGML_TYPE_BF16) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits(dst->type)->from_float;
-                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
-                size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
-                            }
-
-                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_BF16) {
-                size_t id = 0;
-                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = *src0_ptr;
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        }
-        return;
-    }
-
-    // dst counters
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
-
-    if (dst->type == GGML_TYPE_BF16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
-
-                        if (++i10 == ne00) {
-                            i10 = 0;
-                            if (++i11 == ne01) {
-                                i11 = 0;
-                                if (++i12 == ne02) {
-                                    i12 = 0;
-                                    if (++i13 == ne03) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else {
-        GGML_ABORT("fatal error"); // TODO: implement
-    }
-}
-
-static void ggml_compute_forward_dup_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-    // parallelize by rows
-    const int nr = ne01;
-    // number of rows per thread
-    const int dr = (nr + nth - 1) / nth;
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (src0->type == dst->type &&
-        ne00 == ne0 &&
-        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
-        // copy by rows
-        const size_t rs = ne00*nb00;
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    memcpy(
-                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
-                        rs);
-                }
-            }
-        }
-        return;
-    }
-
-    if (ggml_is_contiguous(dst)) {
-        // TODO: simplify
-        if (nb00 == sizeof(float)) {
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits(dst->type)->from_float;
-
-                size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = *src0_ptr;
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_BF16) {
-                size_t id = 0;
-                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        }
-
-        return;
-    }
-
-    // dst counters
-
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
-
-    if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        memcpy(dst_ptr, src0_ptr, sizeof(float));
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_BF16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else {
-        GGML_ABORT("fatal error"); // TODO: implement
-    }
-}
-
-// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
-static void ggml_compute_forward_dup_bytes(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-    GGML_ASSERT(src0->type == dst->type);
-
-    GGML_TENSOR_UNARY_OP_LOCALS;
-
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
-        ggml_compute_forward_dup_same_cont(params, dst);
-        return;
-    }
-
-    const size_t type_size = ggml_type_size(src0->type);
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-
-    // parallelize by rows
-    const int nr = ne01;
-    // number of rows per thread
-    const int dr = (nr + nth - 1) / nth;
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (src0->type == dst->type &&
-        ne00 == ne0 &&
-        nb00 == type_size && nb0 == type_size) {
-        // copy by rows
-        const size_t rs = ne00 * type_size;
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    memcpy(
-                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
-                        rs);
-                }
-            }
-        }
-        return;
-    }
-
-    if (ggml_is_contiguous(dst)) {
-        size_t id = 0;
-        char * dst_ptr = (char *) dst->data;
-        const size_t rs = ne00 * type_size;
-
-        if (nb00 == type_size) {
-            // src0 is contigous on first dimension, copy by rows
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    id += rs * ir0;
-                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                        memcpy(dst_ptr + id, src0_ptr, rs);
-                        id += rs;
-                    }
-                    id += rs * (ne01 - ir1);
-                }
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    id += rs * ir0;
-                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, type_size);
-
-                            id += type_size;
-                        }
-                    }
-                    id += rs * (ne01 - ir1);
-                }
-            }
-        }
-
-        return;
-    }
-
-    // dst counters
-
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            i10 += ne00 * ir0;
-            while (i10 >= ne0) {
-                i10 -= ne0;
-                if (++i11 == ne1) {
-                    i11 = 0;
-                    if (++i12 == ne2) {
-                        i12 = 0;
-                        if (++i13 == ne3) {
-                            i13 = 0;
-                        }
-                    }
-                }
-            }
-            for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                          char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                    memcpy(dst_ptr, src0_ptr, type_size);
-
-                    if (++i10 == ne0) {
-                        i10 = 0;
-                        if (++i11 == ne1) {
-                            i11 = 0;
-                            if (++i12 == ne2) {
-                                i12 = 0;
-                                if (++i13 == ne3) {
-                                    i13 = 0;
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-            i10 += ne00 * (ne01 - ir1);
-            while (i10 >= ne0) {
-                i10 -= ne0;
-                if (++i11 == ne1) {
-                    i11 = 0;
-                    if (++i12 == ne2) {
-                        i12 = 0;
-                        if (++i13 == ne3) {
-                            i13 = 0;
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_dup(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (src0->type == dst->type) {
-        ggml_compute_forward_dup_bytes(params, dst);
-        return;
-    }
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_dup_f16(params, dst);
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                ggml_compute_forward_dup_bf16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_dup_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_add
-
-static void ggml_compute_forward_add_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (nb10 == sizeof(float)) {
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src1 is broadcastable across src0 and dst in i1, i2, i3
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-            const int64_t nr0 = ne00 / ne10;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
-
-            for (int64_t r = 0; r < nr0; ++r) {
-#ifdef GGML_USE_ACCELERATE
-                vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
-#else
-                ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
-#endif
-            }
-        }
-    } else {
-        // src1 is not contiguous
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src1 is broadcastable across src0 and dst in i1, i2, i3
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-
-            for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                const int64_t i10 = i0 % ne10;
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
-
-                dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_add_f16_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    if (dst->type == GGML_TYPE_F32) {
-        GGML_ASSERT( nb0 == sizeof(float));
-    }
-    else {
-        GGML_ASSERT(dst->type  == GGML_TYPE_F16);
-        GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
-    }
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (nb10 == sizeof(float)) {
-        if (dst->type == GGML_TYPE_F16) {
-            for (int ir = ir0; ir < ir1; ++ir) {
-                // src0, src1 and dst are same shape => same indices
-                const int i3 = ir/(ne2*ne1);
-                const int i2 = (ir - i3*ne2*ne1)/ne1;
-                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-                ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-                for (int i = 0; i < ne0; i++) {
-                    dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
-                }
-            }
-        } else {
-            for (int ir = ir0; ir < ir1; ++ir) {
-                // src0, src1 and dst are same shape => same indices
-                const int i3 = ir/(ne2*ne1);
-                const int i2 = (ir - i3*ne2*ne1)/ne1;
-                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-                float *       dst_ptr  = (float *)       ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-                for (int i = 0; i < ne0; i++) {
-                    dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
-                }
-            }
-        }
-    }
-    else {
-        // src1 is not contiguous
-        GGML_ABORT("fatal error");
-    }
-}
-
-static void ggml_compute_forward_add_bf16_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    if (dst->type == GGML_TYPE_F32) {
-        GGML_ASSERT( nb0 == sizeof(float));
-    }
-    else {
-        GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
-        GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
-    }
-
-    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (nb10 == sizeof(float)) {
-        if (dst->type == GGML_TYPE_BF16) {
-            for (int ir = ir0; ir < ir1; ++ir) {
-                // src0, src1 and dst are same shape => same indices
-                const int i3 = ir/(ne2*ne1);
-                const int i2 = (ir - i3*ne2*ne1)/ne1;
-                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-                ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-                ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-                for (int i = 0; i < ne0; i++) {
-                    dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
-                }
-            }
-        } else {
-            for (int ir = ir0; ir < ir1; ++ir) {
-                // src0, src1 and dst are same shape => same indices
-                const int i3 = ir/(ne2*ne1);
-                const int i2 = (ir - i3*ne2*ne1)/ne1;
-                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-                float *       dst_ptr  = (float *)       ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-                ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-                for (int i = 0; i < ne0; i++) {
-                    dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
-                }
-            }
-        }
-    }
-    else {
-        // src1 is not contiguous
-        GGML_ABORT("fatal error");
-    }
-}
-
-static void ggml_compute_forward_add_f16_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
-
-    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (nb10 == sizeof(ggml_fp16_t)) {
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-            for (int i = 0; i < ne0; i++) {
-                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i]));
-            }
-        }
-    }
-    else {
-        // src1 is not contiguous
-        GGML_ABORT("fatal error");
-    }
-}
-
-static void ggml_compute_forward_add_bf16_bf16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
-    GGML_ASSERT(src1->type == GGML_TYPE_BF16);
-    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
-
-    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
-    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (nb10 == sizeof(ggml_bf16_t)) {
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-            ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-            ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-            for (int i = 0; i < ne0; i++) {
-                dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
-            }
-        }
-    }
-    else {
-        // src1 is not contiguous
-        GGML_ABORT("fatal error");
-    }
-}
-
-static void ggml_compute_forward_add_q_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const enum ggml_type type = src0->type;
-    const enum ggml_type dtype = dst->type;
-    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
-    ggml_from_float_t const quantize_row_q = ggml_get_type_traits(dtype)->from_float;
-
-    // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == ggml_type_size(type));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    GGML_ASSERT(ggml_is_quantized(src0->type));
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-        // src1 and dst are same shape as src0 => same indices
-        const int i13 = i03;
-        const int i12 = i02;
-        const int i11 = i01;
-
-        const int i3 = i03;
-        const int i2 = i02;
-        const int i1 = i01;
-
-        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
-        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
-
-        assert(ne00 % 32 == 0);
-
-        // unquantize row from src0 to temp buffer
-        dequantize_row_q(src0_row, wdata, ne00);
-        // add src1
-        ggml_vec_acc_f32(ne00, wdata, src1_row);
-        // quantize row to dst
-        if (quantize_row_q != NULL) {
-            quantize_row_q(wdata, dst_row, ne00);
-        } else {
-            memcpy(dst_row, wdata, ne0*nb0);
-        }
-    }
-}
-
-static void ggml_compute_forward_add(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                if (src1->type == GGML_TYPE_F32) {
-                    ggml_compute_forward_add_f32(params, dst);
-                }
-                else {
-                    GGML_ABORT("fatal error");
-                }
-            } break;
-        case GGML_TYPE_F16:
-            {
-                if (src1->type == GGML_TYPE_F16) {
-                    ggml_compute_forward_add_f16_f16(params, dst);
-                }
-                else if (src1->type == GGML_TYPE_F32) {
-                    ggml_compute_forward_add_f16_f32(params, dst);
-                }
-                else {
-                    GGML_ABORT("fatal error");
-                }
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                if (src1->type == GGML_TYPE_BF16) {
-                    ggml_compute_forward_add_bf16_bf16(params, dst);
-                }
-                else if (src1->type == GGML_TYPE_F32) {
-                    ggml_compute_forward_add_bf16_f32(params, dst);
-                }
-                else {
-                    GGML_ABORT("fatal error");
-                }
-            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_Q4_0_4_4:
-        case GGML_TYPE_Q4_0_4_8:
-        case GGML_TYPE_Q4_0_8_8:
-            {
-                ggml_compute_forward_add_q_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_add1
-
-static void ggml_compute_forward_add1_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_scalar(src1));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are same shape => same indices
-        const int i3 = ir/(ne2*ne1);
-        const int i2 = (ir - i3*ne2*ne1)/ne1;
-        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-#ifdef GGML_USE_ACCELERATE
-        UNUSED(ggml_vec_add1_f32);
-
-        vDSP_vadd(
-                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                (float *) ((char *) src1->data), 0,
-                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                ne0);
-#else
-        ggml_vec_add1_f32(ne0,
-                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-               *(float *) src1->data);
-#endif
-    }
-}
-
-static void ggml_compute_forward_add1_f16_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_scalar(src1));
-
-    // scalar to add
-    const float v = *(float *) src1->data;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
-
-    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are same shape => same indices
-        const int i3 = ir/(ne2*ne1);
-        const int i2 = (ir - i3*ne2*ne1)/ne1;
-        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-        for (int i = 0; i < ne0; i++) {
-            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
-        }
-    }
-}
-
-static void ggml_compute_forward_add1_f16_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_scalar(src1));
-
-    // scalar to add
-    const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
-
-    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are same shape => same indices
-        const int i3 = ir/(ne2*ne1);
-        const int i2 = (ir - i3*ne2*ne1)/ne1;
-        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-        for (int i = 0; i < ne0; i++) {
-            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
-        }
-    }
-}
-
-static void ggml_compute_forward_add1_q_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_scalar(src1));
-
-    // scalar to add
-    const float v = *(float *) src1->data;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const enum ggml_type type = src0->type;
-    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
-    ggml_from_float_t const quantize_row_q = ggml_get_type_traits(type)->from_float;
-
-    // we don't support permuted src0
-    GGML_ASSERT(nb00 == ggml_type_size(type));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    GGML_ASSERT(ggml_is_quantized(src0->type));
-    GGML_ASSERT(dst->type == src0->type);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are same shape => same indices
-        const int i3 = ir/(ne2*ne1);
-        const int i2 = (ir - i3*ne2*ne1)/ne1;
-        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-        void  * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
-        void  * dst_row  = (void *) ((char *)  dst->data + (i1*nb1  + i2*nb2  + i3*nb0 ));
-
-        assert(ne0 % 32 == 0);
-
-        // unquantize row from src0 to temp buffer
-        dequantize_row_q(src0_row, wdata, ne0);
-        // add src1
-        ggml_vec_acc1_f32(ne0, wdata, v);
-        // quantize row to dst
-        quantize_row_q(wdata, dst_row, ne0);
-    }
-}
-
-static void ggml_compute_forward_add1_bf16_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_scalar(src1));
-
-    // scalar to add
-    const float v = *(float *) src1->data;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
-
-    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
-    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are same shape => same indices
-        const int i3 = ir/(ne2*ne1);
-        const int i2 = (ir - i3*ne2*ne1)/ne1;
-        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-        for (int i = 0; i < ne0; i++) {
-            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
-        }
-    }
-}
-
-static void ggml_compute_forward_add1_bf16_bf16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_scalar(src1));
-
-    // scalar to add
-    const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
-    GGML_ASSERT(src1->type == GGML_TYPE_BF16);
-    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
-
-    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
-    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are same shape => same indices
-        const int i3 = ir/(ne2*ne1);
-        const int i2 = (ir - i3*ne2*ne1)/ne1;
-        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-        for (int i = 0; i < ne0; i++) {
-            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
-        }
-    }
-}
-
-static void ggml_compute_forward_add1(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_add1_f32(params, dst);
-            } break;
-        case GGML_TYPE_F16:
-            {
-                if (src1->type == GGML_TYPE_F16) {
-                    ggml_compute_forward_add1_f16_f16(params, dst);
-                }
-                else if (src1->type == GGML_TYPE_F32) {
-                    ggml_compute_forward_add1_f16_f32(params, dst);
-                }
-                else {
-                    GGML_ABORT("fatal error");
-                }
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                if (src1->type == GGML_TYPE_BF16) {
-                    ggml_compute_forward_add1_bf16_bf16(params, dst);
-                }
-                else if (src1->type == GGML_TYPE_F32) {
-                    ggml_compute_forward_add1_bf16_f32(params, dst);
-                }
-                else {
-                    GGML_ABORT("fatal error");
-                }
-            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_Q4_0_4_4:
-        case GGML_TYPE_Q4_0_4_8:
-        case GGML_TYPE_Q4_0_8_8:
-            {
-                ggml_compute_forward_add1_q_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_acc
-
-static void ggml_compute_forward_acc_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-
-    // view src0 and dst with these strides and data offset inbytes during acc
-    // nb0 is implicitly element_size because src0 and dst are contiguous
-    size_t nb1     = ((int32_t *) dst->op_params)[0];
-    size_t nb2     = ((int32_t *) dst->op_params)[1];
-    size_t nb3     = ((int32_t *) dst->op_params)[2];
-    size_t offset  = ((int32_t *) dst->op_params)[3];
-    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
-
-    if (!inplace) {
-        if (params->ith == 0) {
-            // memcpy needs to be synchronized across threads to avoid race conditions.
-            // => do it in INIT phase
-            memcpy(
-                ((char *)  dst->data),
-                ((char *) src0->data),
-                ggml_nbytes(dst));
-        }
-        ggml_barrier(params->threadpool);
-    }
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(src1);
-    const int nc = src1->ne[0];
-
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
-
-    // src0 and dst as viewed during acc
-    const size_t nb0 = ggml_element_size(src0);
-
-    const size_t nb00 = nb0;
-    const size_t nb01 = nb1;
-    const size_t nb02 = nb2;
-    const size_t nb03 = nb3;
-
-    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0  + (ne11 == 0 ? 0 : ne11-1)*nb1  + (ne12 == 0 ? 0 : ne12-1)*nb2  + (ne13 == 0 ? 0 : ne13-1)*nb3  < ggml_nbytes(dst));
-    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
-
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are viewed with shape of src1 and offset
-        // => same indices
-        const int i3 = ir/(ne12*ne11);
-        const int i2 = (ir - i3*ne12*ne11)/ne11;
-        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
-
-#ifdef GGML_USE_ACCELERATE
-        vDSP_vadd(
-                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
-                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + offset), 1, nc);
-#else
-        ggml_vec_add_f32(nc,
-                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
-                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
-                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
-#endif
-    }
-}
-
-static void ggml_compute_forward_acc(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_acc_f32(params, dst);
-            } break;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_BF16:
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_Q4_0_4_4:
-        case GGML_TYPE_Q4_0_4_8:
-        case GGML_TYPE_Q4_0_8_8:
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sub
-
-static void ggml_compute_forward_sub_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (nb10 == sizeof(float)) {
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src1 is broadcastable across src0 and dst in i1, i2, i3
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-            const int64_t nr0 = ne00 / ne10;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
-
-            for (int64_t r = 0; r < nr0; ++r) {
-#ifdef GGML_USE_ACCELERATE
-                vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
-#else
-                ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
-#endif
-            }
-        }
-    } else {
-        // src1 is not contiguous
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src1 is broadcastable across src0 and dst in i1, i2, i3
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-
-            for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                const int64_t i10 = i0 % ne10;
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
-
-                dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_sub(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sub_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_mul
-
-static void ggml_compute_forward_mul_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t nr = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    if (nb10 == sizeof(float)) {
-        for (int64_t ir = ith; ir < nr; ir += nth) {
-            // src0 and dst are same shape => same indices
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-            const int64_t nr0 = ne00 / ne10;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
-
-            for (int64_t r = 0 ; r < nr0; ++r) {
-#ifdef GGML_USE_ACCELERATE
-                UNUSED(ggml_vec_mul_f32);
-
-                vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
-#else
-                ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
-#endif
-            }
-        }
-    } else {
-        // src1 is not contiguous
-        for (int64_t ir = ith; ir < nr; ir += nth) {
-            // src0 and dst are same shape => same indices
-            // src1 is broadcastable across src0 and dst in i1, i2, i3
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-
-            for (int64_t i0 = 0; i0 < ne00; ++i0) {
-                const int64_t i10 = i0 % ne10;
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
-
-                dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_mul(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now");
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_mul_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_div
-
-static void ggml_compute_forward_div_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t nr = ggml_nrows(src0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    if (nb10 == sizeof(float)) {
-        for (int64_t ir = ith; ir < nr; ir += nth) {
-            // src0 and dst are same shape => same indices
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-            const int64_t nr0 = ne00 / ne10;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
-
-            for (int64_t r = 0; r < nr0; ++r) {
-#ifdef GGML_USE_ACCELERATE
-                UNUSED(ggml_vec_div_f32);
-
-                vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
-#else
-                ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
-#endif
-            }
-        }
-    } else {
-        // src1 is not contiguous
-        for (int64_t ir = ith; ir < nr; ir += nth) {
-            // src0 and dst are same shape => same indices
-            // src1 is broadcastable across src0 and dst in i1, i2, i3
-            const int64_t i03 = ir/(ne02*ne01);
-            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-            const int64_t i13 = i03 % ne13;
-            const int64_t i12 = i02 % ne12;
-            const int64_t i11 = i01 % ne11;
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
-
-            for (int64_t i0 = 0; i0 < ne00; ++i0) {
-                const int64_t i10 = i0 % ne10;
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
-
-                dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_div(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_div_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sqr
-
-static void ggml_compute_forward_sqr_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n     = ggml_nrows(src0);
-    const int nc    = src0->ne[0];
-
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_sqr_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_sqr(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sqr_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sqrt
-
-static void ggml_compute_forward_sqrt_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_sqrt_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_sqrt(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sqrt_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_log
-
-static void ggml_compute_forward_log_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_log_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_log(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_log_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sin
-
-static void ggml_compute_forward_sin_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_sin_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_sin(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sin_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_cos
-
-static void ggml_compute_forward_cos_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_cos_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_cos(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_cos_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sum
-
-static void ggml_compute_forward_sum_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_scalar(dst));
-    assert(src0->nb[0] == sizeof(float));
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
-
-    ggml_float sum     = 0;
-    ggml_float row_sum = 0;
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = 0; i01 < ne01; i01++) {
-                ggml_vec_sum_f32_ggf(ne00,
-                        &row_sum,
-                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
-                sum += row_sum;
-            }
-        }
-    }
-    ((float *) dst->data)[0] = sum;
-}
-
-static void ggml_compute_forward_sum_f16(
-    const struct ggml_compute_params * params,
-          struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_scalar(dst));
-
-    assert(src0->nb[0] == sizeof(ggml_fp16_t));
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
-
-    float sum = 0;
-    float row_sum = 0;
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = 0; i01 < ne01; i01++) {
-                ggml_vec_sum_f16_ggf(ne00,
-                    &row_sum,
-                    (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
-                sum += row_sum;
-            }
-        }
-    }
-    ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
-}
-
-static void ggml_compute_forward_sum_bf16(
-    const struct ggml_compute_params * params,
-          struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_scalar(dst));
-
-    assert(src0->nb[0] == sizeof(ggml_bf16_t));
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
-
-    float sum = 0;
-    float row_sum = 0;
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = 0; i01 < ne01; i01++) {
-                ggml_vec_sum_bf16_ggf(ne00,
-                    &row_sum,
-                    (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
-                sum += row_sum;
-            }
-        }
-    }
-    ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
-}
-
-static void ggml_compute_forward_sum(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sum_f32(params, dst);
-            } break;
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_sum_f16(params, dst);
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                ggml_compute_forward_sum_bf16(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sum_rows
-
-static void ggml_compute_forward_sum_rows_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT(dst->nb[0] == sizeof(float));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT(ne0 == 1);
-    GGML_ASSERT(ne1 == ne01);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
-
-    for (int64_t i3 = 0; i3 < ne03; i3++) {
-        for (int64_t i2 = 0; i2 < ne02; i2++) {
-            for (int64_t i1 = 0; i1 < ne01; i1++) {
-                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
-                float * dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
-                float row_sum = 0;
-                ggml_vec_sum_f32(ne00, &row_sum, src_row);
-                dst_row[0] = row_sum;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_sum_rows(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sum_rows_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_mean
-
-static void ggml_compute_forward_mean_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(src0->nb[0] == sizeof(float));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    assert(ne0 == 1);
-    assert(ne1 == ne01);
-    assert(ne2 == ne02);
-    assert(ne3 == ne03);
-
-    UNUSED(ne0);
-    UNUSED(ne1);
-    UNUSED(ne2);
-    UNUSED(ne3);
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = 0; i01 < ne01; i01++) {
-                ggml_vec_sum_f32(ne00,
-                        (float *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
-
-                *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_mean(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_mean_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_argmax
-
-static void ggml_compute_forward_argmax_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(src0->nb[0] == sizeof(float));
-    assert(dst->nb[0] == sizeof(float));
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-
-    const size_t nb01 = src0->nb[1];
-    const size_t nb0 = dst->nb[0];
-
-    for (int64_t i1 = 0; i1 < ne01; i1++) {
-        float * src = (float *) ((char *) src0->data + i1*nb01);
-        int32_t * dst_ = (int32_t *) ((char *)  dst->data + i1*nb0);
-        int v = 0;
-        ggml_vec_argmax_f32(ne00, &v, src);
-        dst_[0] = v;
-    }
-}
-
-static void ggml_compute_forward_argmax(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_argmax_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_count_equal
-
-static void ggml_compute_forward_count_equal_i32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    GGML_ASSERT(src0->type == GGML_TYPE_I32);
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_are_same_shape(src0, src1));
-    GGML_ASSERT(ggml_is_scalar(dst));
-    GGML_ASSERT(dst->type == GGML_TYPE_I64);
-
-    const int64_t nr = ggml_nrows(src0);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    int64_t * sums = (int64_t *) params->wdata;
-    int64_t sum_thread = 0;
-
-    // rows per thread
-    const int64_t dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int64_t ir0 = dr*ith;
-    const int64_t ir1 = MIN(ir0 + dr, nr);
-
-    for (int64_t ir = ir0; ir < ir1; ++ir) {
-        const int64_t i03 =  ir                        / (ne02*ne01);
-        const int64_t i02 = (ir - i03*ne03)            /       ne01;
-        const int64_t i01 =  ir - i03*ne03 - i02*ne02;
-
-        const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
-        const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
-
-        for (int64_t i00 = 0; i00 < ne00; ++i00) {
-            const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
-            const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
-
-            sum_thread += val0 == val1;
-        }
-    }
-    if (ith != 0) {
-        sums[ith] = sum_thread;
-    }
-    ggml_barrier(params->threadpool);
-
-    if (ith != 0) {
-        return;
-    }
-
-    for (int ith_other = 1; ith_other < nth; ++ith_other) {
-        sum_thread += sums[ith_other];
-    }
-    *((int64_t *) dst->data) = sum_thread;
-}
-
-static void ggml_compute_forward_count_equal(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_I32:
-            {
-                ggml_compute_forward_count_equal_i32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_repeat
-
-static void ggml_compute_forward_repeat_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_can_repeat(src0, dst));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    // guaranteed to be an integer due to the check in ggml_can_repeat
-    const int nr0 = (int)(ne0/ne00);
-    const int nr1 = (int)(ne1/ne01);
-    const int nr2 = (int)(ne2/ne02);
-    const int nr3 = (int)(ne3/ne03);
-
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // TODO: maybe this is not optimal?
-    for                         (int i3 = 0; i3 < nr3;  i3++) {
-        for                     (int k3 = 0; k3 < ne03; k3++) {
-            for                 (int i2 = 0; i2 < nr2;  i2++) {
-                for             (int k2 = 0; k2 < ne02; k2++) {
-                    for         (int i1 = 0; i1 < nr1;  i1++) {
-                        for     (int k1 = 0; k1 < ne01; k1++) {
-                            for (int i0 = 0; i0 < nr0;  i0++) {
-                                ggml_vec_cpy_f32(ne00,
-                                        (float *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0),
-                                        (float *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01));
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_repeat_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_can_repeat(src0, dst));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    // guaranteed to be an integer due to the check in ggml_can_repeat
-    const int nr0 = (int)(ne0/ne00);
-    const int nr1 = (int)(ne1/ne01);
-    const int nr2 = (int)(ne2/ne02);
-    const int nr3 = (int)(ne3/ne03);
-
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
-    // TODO: maybe this is not optimal?
-    for                         (int i3 = 0; i3 < nr3;  i3++) {
-        for                     (int k3 = 0; k3 < ne03; k3++) {
-            for                 (int i2 = 0; i2 < nr2;  i2++) {
-                for             (int k2 = 0; k2 < ne02; k2++) {
-                    for         (int i1 = 0; i1 < nr1;  i1++) {
-                        for     (int k1 = 0; k1 < ne01; k1++) {
-                            for (int i0 = 0; i0 < nr0;  i0++) {
-                                ggml_fp16_t * y = (ggml_fp16_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0);
-                                ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01);
-                                // ggml_vec_cpy_f16(ne00, y, x)
-                                for (int i = 0; i < ne00; ++i) {
-                                    y[i]  = x[i];
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_repeat(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-        case GGML_TYPE_BF16:
-        case GGML_TYPE_I16:
-            {
-                ggml_compute_forward_repeat_f16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-        case GGML_TYPE_I32:
-            {
-                ggml_compute_forward_repeat_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_repeat_back
-
-static void ggml_compute_forward_repeat_back_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_can_repeat(dst, src0));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    // guaranteed to be an integer due to the check in ggml_can_repeat
-    const int nr0 = (int)(ne00/ne0);
-    const int nr1 = (int)(ne01/ne1);
-    const int nr2 = (int)(ne02/ne2);
-    const int nr3 = (int)(ne03/ne3);
-
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    if (ggml_is_contiguous(dst)) {
-        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
-    } else {
-        for         (int k3 = 0; k3 < ne3; k3++) {
-            for     (int k2 = 0; k2 < ne2; k2++) {
-                for (int k1 = 0; k1 < ne1; k1++) {
-                    ggml_vec_set_f32(ne0,
-                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
-                        0);
-                }
-            }
-        }
-    }
-
-    // TODO: maybe this is not optimal?
-    for                         (int i3 = 0; i3 < nr3; i3++) {
-        for                     (int k3 = 0; k3 < ne3; k3++) {
-            for                 (int i2 = 0; i2 < nr2; i2++) {
-                for             (int k2 = 0; k2 < ne2; k2++) {
-                    for         (int i1 = 0; i1 < nr1; i1++) {
-                        for     (int k1 = 0; k1 < ne1; k1++) {
-                            for (int i0 = 0; i0 < nr0; i0++) {
-                                ggml_vec_acc_f32(ne0,
-                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
-                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_repeat_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_repeat_back_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_concat
-
-static void ggml_compute_forward_concat_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int32_t dim = ggml_get_op_params_i32(dst, 0);
-
-    GGML_ASSERT(dim >= 0 && dim < 4);
-
-    int64_t o[4] = {0, 0, 0, 0};
-    o[dim] = src0->ne[dim];
-
-    const float * x;
-
-    // TODO: smarter multi-theading
-    for (int i3 = 0; i3 < ne3; i3++) {
-        for (int i2 = ith; i2 < ne2; i2 += nth) {
-            for (int i1 = 0; i1 < ne1; i1++) {
-                for (int i0 = 0; i0 < ne0; i0++) {
-                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
-                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
-                    } else {
-                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
-                    }
-
-                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
-
-                    *y = *x;
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_concat(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-        case GGML_TYPE_I32:
-            {
-                ggml_compute_forward_concat_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_abs
-
-static void ggml_compute_forward_abs_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_abs_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_abs(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_abs_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sgn
-
-static void ggml_compute_forward_sgn_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_sgn_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_sgn(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sgn_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_neg
-
-static void ggml_compute_forward_neg_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_neg_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_neg(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_neg_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_step
-
-static void ggml_compute_forward_step_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_step_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_step(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_step_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_tanh
-
-static void ggml_compute_forward_tanh_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_tanh_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_tanh(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_tanh_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_elu
-
-static void ggml_compute_forward_elu_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_elu_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_elu(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_elu_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_relu
-
-static void ggml_compute_forward_relu_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_relu_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_relu(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_relu_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_sigmoid
-
-static void ggml_compute_forward_sigmoid_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_sigmoid_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_sigmoid(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_sigmoid_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_gelu
-
-static void ggml_compute_forward_gelu_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_gelu_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
-
-#ifndef NDEBUG
-        for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            UNUSED(x);
-            assert(!isnan(x));
-            assert(!isinf(x));
-        }
-#endif
-    }
-}
-
-static void ggml_compute_forward_gelu(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_gelu_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_gelu_quick
-
-static void ggml_compute_forward_gelu_quick_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_gelu_quick_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
-
-#ifndef NDEBUG
-        for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            UNUSED(x);
-            assert(!isnan(x));
-            assert(!isinf(x));
-        }
-#endif
-    }
-}
-
-static void ggml_compute_forward_gelu_quick(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_gelu_quick_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_silu
-
-static void ggml_compute_forward_silu_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_silu_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
-
-#ifndef NDEBUG
-        for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
-            UNUSED(x);
-            assert(!isnan(x));
-            assert(!isinf(x));
-        }
-#endif
-    }
-}
-
-static void ggml_compute_forward_silu(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_silu_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-// ggml_compute_forward_leaky_relu
-
-static void ggml_compute_forward_leaky_relu_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    float negative_slope;
-    memcpy(&negative_slope, dst->op_params, sizeof(float));
-
-    assert(dst->nb[0]  == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_leaky_relu_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
-    }
-}
-
-static void ggml_compute_forward_leaky_relu(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_leaky_relu_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_silu_back
-
-static void ggml_compute_forward_silu_back_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * grad = dst->src[1];
-
-    assert(ggml_is_contiguous_1(grad));
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-    assert(ggml_are_same_shape(src0, grad));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        ggml_vec_silu_backward_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])),
-                (float *) ((char *) grad->data + i1*(grad->nb[1])));
-
-#ifndef NDEBUG
-        for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
-            UNUSED(x);
-            assert(!isnan(x));
-            assert(!isinf(x));
-        }
-#endif
-    }
-}
-
-static void ggml_compute_forward_silu_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_silu_back_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-
-static void ggml_compute_forward_hardswish_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_hardswish_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-static void ggml_compute_forward_hardswish(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_hardswish_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-static void ggml_compute_forward_hardsigmoid_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_hardsigmoid_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_hardsigmoid(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_hardsigmoid_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-static void ggml_compute_forward_exp_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        ggml_vec_exp_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_exp(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_exp_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-
-// ggml_compute_forward_norm
-
-static void ggml_compute_forward_norm_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    float eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
-
-    GGML_ASSERT(eps > 0.0f);
-
-    // TODO: optimize
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
-                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
-                ggml_float sum = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    sum += (ggml_float)x[i00];
-                }
-
-                float mean = sum/ne00;
-
-                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
-
-                ggml_float sum2 = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    float v = x[i00] - mean;
-                    y[i00] = v;
-                    sum2 += (ggml_float)(v*v);
-                }
-
-                float variance = sum2/ne00;
-                const float scale = 1.0f/sqrtf(variance + eps);
-
-                ggml_vec_scale_f32(ne00, y, scale);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_norm(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_norm_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_group_rms_norm
-
-static void ggml_compute_forward_rms_norm_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    float eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
-
-    GGML_ASSERT(eps > 0.0f);
-
-    // TODO: optimize
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
-                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
-                ggml_float sum = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    sum += (ggml_float)(x[i00] * x[i00]);
-                }
-
-                const float mean = sum/ne00;
-
-                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
-
-                memcpy(y, x, ne00 * sizeof(float));
-                // for (int i00 = 0; i00 < ne00; i00++) {
-                //     y[i00] = x[i00];
-                // }
-
-                const float scale = 1.0f/sqrtf(mean + eps);
-
-                ggml_vec_scale_f32(ne00, y, scale);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_rms_norm(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_rms_norm_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-static void ggml_compute_forward_rms_norm_back_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    float eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
-
-    // TODO: optimize
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
-                // src1 is same shape as src0 => same indices
-                const int64_t i11 = i01;
-                const int64_t i12 = i02;
-                const int64_t i13 = i03;
-
-                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
-
-                ggml_float sum_xx  = 0.0;
-                ggml_float sum_xdz = 0.0;
-
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    sum_xx  += (ggml_float)(x[i00] * x[i00]);
-                    sum_xdz += (ggml_float)(x[i00] * dz[i00]);
-                }
-
-                //const float mean     = (float)(sum_xx)/ne00;
-                const float mean_eps = (float)(sum_xx)/ne00 + eps;
-                const float sum_eps  = (float)(sum_xx) + eps*ne00;
-                //const float mean_xdz = (float)(sum_xdz)/ne00;
-                // we could cache rms from forward pass to improve performance.
-                // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
-                //const float rms      = sqrtf(mean_eps);
-                const float rrms     = 1.0f / sqrtf(mean_eps);
-                //const float scale    = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
-
-                {
-                    // z = rms_norm(x)
-                    //
-                    // rms_norm(src0) =
-                    //     scale(
-                    //         src0,
-                    //         div(
-                    //             1,
-                    //             sqrt(
-                    //                 add(
-                    //                     scale(
-                    //                         sum(
-                    //                             sqr(
-                    //                                 src0)),
-                    //                         (1.0/N)),
-                    //                     eps))));
-
-                    // postorder:
-                    // ## op    args         grad
-                    // 00 param src0         grad[#00]
-                    // 01 const 1
-                    // 02 sqr   (#00)        grad[#02]
-                    // 03 sum   (#02)        grad[#03]
-                    // 04 const 1/N
-                    // 05 scale (#03, #04)   grad[#05]
-                    // 06 const eps
-                    // 07 add   (#05, #06)   grad[#07]
-                    // 08 sqrt  (#07)        grad[#08]
-                    // 09 div   (#01,#08)    grad[#09]
-                    // 10 scale (#00,#09)    grad[#10]
-                    //
-                    // backward pass, given grad[#10]
-                    // #10: scale
-                    // grad[#00] += scale(grad[#10],#09)
-                    // grad[#09] += sum(mul(grad[#10],#00))
-                    // #09: div
-                    // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
-                    // #08: sqrt
-                    // grad[#07] += mul(grad[#08], div(0.5, #08))
-                    // #07: add
-                    // grad[#05] += grad[#07]
-                    // #05: scale
-                    // grad[#03] += scale(grad[#05],#04)
-                    // #03: sum
-                    // grad[#02] += repeat(grad[#03], #02)
-                    // #02:
-                    // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
-                    //
-                    // substitute and simplify:
-                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
-                    // grad[#02] = repeat(grad[#03], #02)
-                    // grad[#02] = repeat(scale(grad[#05],#04), #02)
-                    // grad[#02] = repeat(scale(grad[#07],#04), #02)
-                    // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
-                    // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
-                    // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
-                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
-                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
-                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
-                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
-                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
-                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
-                    // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
-                    // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
-                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
-                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
-                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
-                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
-                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
-                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
-                    // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
-                    // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
-                    // a = b*c + d*e
-                    // a = b*c*f/f + d*e*f/f
-                    // a = (b*c*f + d*e*f)*(1/f)
-                    // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
-                    // a = (b + d*e/c)*c
-                    // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
-                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
-                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
-                    // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
-                    // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
-                    // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
-                    // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
-                    // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
-                    // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
-                    // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
-                }
-                // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
-                // post-order:
-                // dx := x
-                // dx := scale(dx,-mean_xdz/mean_eps)
-                // dx := add(dx, dz)
-                // dx := scale(dx, rrms)
-                float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
-
-                ggml_vec_cpy_f32  (ne00, dx, x);
-                // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
-                ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
-                ggml_vec_acc_f32  (ne00, dx, dz);
-                ggml_vec_scale_f32(ne00, dx, rrms);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_rms_norm_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_rms_norm_back_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_group_norm
-
-static void ggml_compute_forward_group_norm_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    // TODO: optimize
-
-    float eps;
-    memcpy(&eps, dst->op_params + 1, sizeof(float));
-
-    int n_channels = src0->ne[2];
-    int n_groups = dst->op_params[0];
-    int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
-    for (int i = ith; i < n_groups; i += nth) {
-        int start = i * n_channels_per_group;
-        int end = start + n_channels_per_group;
-        if (end > n_channels) {
-            end = n_channels;
-        }
-        int step = end - start;
-
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            ggml_float sum = 0.0;
-            for (int64_t i02 = start; i02 < end; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
-
-                    ggml_float sumr = 0.0;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        sumr += (ggml_float)x[i00];
-                    }
-                    sum += sumr;
-                }
-            }
-            const float mean = sum / (ne00 * ne01 * step);
-
-            ggml_float sum2 = 0.0;
-            for (int64_t i02 = start; i02 < end; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
-
-                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
-
-                    ggml_float sumr = 0.0;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        float v = x[i00] - mean;
-                        y[i00] = v;
-                        sumr += (ggml_float)(v * v);
-                    }
-                    sum2 += sumr;
-                }
-            }
-            const float variance = sum2 / (ne00 * ne01 * step);
-            const float scale = 1.0f / sqrtf(variance + eps);
-
-            for (int64_t i02 = start; i02 < end; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
-                    ggml_vec_scale_f32(ne00, y, scale);
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_group_norm(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_group_norm_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_mul_mat
-
-static void ggml_compute_forward_mul_mat_one_chunk(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst,
-    const int64_t num_rows_per_vec_dot,
-    const int64_t ir0_start,
-    const int64_t ir0_end,
-    const int64_t ir1_start,
-    const int64_t ir1_end) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const enum ggml_type type = src0->type;
-
-    const bool src1_cont = ggml_is_contiguous(src1);
-
-    ggml_vec_dot_t const vec_dot      = type_traits_cpu[type].vec_dot;
-    enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
-
-    // broadcast factors
-    const int64_t r2 = ne12 / ne02;
-    const int64_t r3 = ne13 / ne03;
-
-    //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
-
-    // threads with no work simply yield (not sure if it helps)
-    if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
-        return;
-    }
-
-    const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
-    const size_t row_size = ggml_row_size(vec_dot_type, ne10);
-
-    assert(ne12 % ne02 == 0);
-    assert(ne13 % ne03 == 0);
-
-    // block-tiling attempt
-    const int64_t blck_0 = 16;
-    const int64_t blck_1 = 16;
-
-    const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
-
-    // attempt to reduce false-sharing (does not seem to make a difference)
-    // 16 * 2, accounting for mmla kernels
-    float tmp[32];
-
-    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
-        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
-            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
-                const int64_t i13 = (ir1 / (ne12 * ne1));
-                const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
-                const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
-
-                // broadcast src0 into src1
-                const int64_t i03 = i13 / r3;
-                const int64_t i02 = i12 / r2;
-
-                const int64_t i1 = i11;
-                const int64_t i2 = i12;
-                const int64_t i3 = i13;
-
-                const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
-
-                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
-                //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
-                //       the original src1 data pointer, so we should index using the indices directly
-                // TODO: this is a bit of a hack, we should probably have a better way to handle this
-                const char * src1_col = (const char*)wdata +
-                    (src1_cont || src1->type != vec_dot_type
-                        ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
-                        : (i11 * nb11 + i12 * nb12 + i13 * nb13));
-                float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
-
-                //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
-                //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
-                //}
-
-                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
-                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
-                }
-
-                for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
-                    memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_mul_mat(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const enum ggml_type type = src0->type;
-
-    enum ggml_type           const vec_dot_type         = type_traits_cpu[type].vec_dot_type;
-    ggml_from_float_t        const from_float           = ggml_get_type_traits(vec_dot_type)->from_float;
-    ggml_from_float_to_mat_t const from_float_to_mat    = type_traits_cpu[vec_dot_type].from_float_to_mat;
-    int64_t                  const vec_dot_num_rows     = type_traits_cpu[type].nrows;
-    int64_t                  const matmul_num_cols      = type_traits_cpu[type].ncols;
-    int64_t                  const blck_size_interleave = ggml_get_type_traits(type)->blck_size_interleave;
-    ggml_gemv_t              const gemv                 = type_traits_cpu[type].gemv;
-    ggml_gemm_t              const gemm                 = type_traits_cpu[type].gemm;
-
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne12);
-    GGML_ASSERT(ne3 == ne13);
-
-    // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == ggml_type_size(type));
-    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    // nb01 >= nb00 - src0 is not transposed
-    //   compute by src0 rows
-
-#if GGML_USE_LLAMAFILE
-    // broadcast factors
-    const int64_t r2 = ne12 / ne02;
-    const int64_t r3 = ne13 / ne03;
-
-    const bool src1_cont = ggml_is_contiguous(src1);
-
-    if (src1_cont) {
-        for (int64_t i13 = 0; i13 < ne13; i13++)
-            for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
-                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
-                                     nb01/ggml_type_size(src0->type),
-                                     (const char *)src1->data + i12*nb12 + i13*nb13,
-                                     nb11/ggml_type_size(src1->type),
-                                     (char *)dst->data + i12*nb2 + i13*nb3,
-                                     nb1/ggml_type_size(dst->type),
-                                     ith, nth,
-                                     src0->type,
-                                     src1->type,
-                                     dst->type))
-                    goto UseGgmlGemm1;
-        return;
-    }
-UseGgmlGemm1:;
-#endif
-
-    if (src1->type != vec_dot_type) {
-        char * wdata = params->wdata;
-
-        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
-        const size_t nbw2 = nbw1*ne11;
-        const size_t nbw3 = nbw2*ne12;
-
-        assert(params->wsize >= ne13*nbw3);
-        GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-        for (int64_t i13 = 0; i13 < ne13; ++i13) {
-            for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                int64_t i11_processed = 0;
-                if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
-                    for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
-                        from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
-                                          (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
-                                          4, ne10, blck_size_interleave);
-                    }
-                    i11_processed = ne11 - ne11 % 4;
-                }
-                for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
-                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
-                           (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
-                           ne10);
-                }
-            }
-        }
-    }
-
-    if (ith == 0) {
-        // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
-        atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);
-    }
-
-    ggml_barrier(params->threadpool);
-
-#if GGML_USE_LLAMAFILE
-    if (src1->type != vec_dot_type) {
-        const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
-        const size_t row_size = ggml_row_size(vec_dot_type, ne10);
-
-        for (int64_t i13 = 0; i13 < ne13; i13++)
-            for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
-                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
-                                     nb01/ggml_type_size(src0->type),
-                                     (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
-                                     row_size/ggml_type_size(vec_dot_type),
-                                     (char *)dst->data + i12*nb2 + i13*nb3,
-                                     nb1/ggml_type_size(dst->type),
-                                     ith, nth,
-                                     src0->type,
-                                     vec_dot_type,
-                                     dst->type))
-                    goto UseGgmlGemm2;
-        return;
-    }
-UseGgmlGemm2:;
-#endif
-
-    // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
-    const int64_t nr0 = ne0;
-
-    // This is the size of the rest of the dimensions of the result
-    const int64_t nr1 = ne1 * ne2 * ne3;
-
-    // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
-    int64_t num_rows_per_vec_dot = vec_dot_num_rows;
-    // TODO: currently the mmla kernels support only even numbered rows/cols.
-    // this check can be removed once they are extended to support odd numbered rows/cols too
-    if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
-        num_rows_per_vec_dot = 1;
-    }
-
-    // Now select a reasonable chunk size.
-    int chunk_size = 16;
-
-    // We need to step up the size if it's small
-    if (nr0 == 1 || nr1 == 1) {
-        chunk_size = 64;
-    }
-
-    // distribute the work across the inner or outer loop based on which one is larger
-    // The number of chunks in the 0/1 dim.
-    // CEIL(nr0/chunk_size)
-    int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
-    int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
-
-    // If the chunking is poor for the number of threads on this setup, scrap the whole plan.  Re-chunk it by thread.
-    //   Also, chunking by thread was measured to have perform better on NUMA systems.  See https://github.com/ggerganov/llama.cpp/pull/6915
-    //   In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
-    if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
-        // distribute the thread work across the inner or outer loop based on which one is larger
-        nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
-        nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
-    }
-
-    // The number of elements in each chunk
-    const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
-    const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
-
-    if ((ggml_n_dims(src0) == 2) && gemv) {
-        const void * src1_wdata      = (src1->type == vec_dot_type) ? src1->data : params->wdata;
-        const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11;
-        int64_t src0_start = (ith * ne01) / nth;
-        int64_t src0_end   = ((ith + 1) * ne01) / nth;
-        src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
-        src0_end   = (src0_end   % matmul_num_cols) ? src0_end   + matmul_num_cols - (src0_end   % matmul_num_cols): src0_end;
-        if (src0_start >= src0_end) return;
-
-        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
-        if (gemm && (ne11 > 3)) {
-            gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
-                 (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
-        }
-        for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
-            gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
-                 (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
-                 src0_end - src0_start);
-        }
-        return;
-    }
-
-    // The first chunk comes from our thread_id, the rest will get auto-assigned.
-    int current_chunk = ith;
-
-    while (current_chunk < nchunk0 * nchunk1) {
-        const int64_t ith0 = current_chunk % nchunk0;
-        const int64_t ith1 = current_chunk / nchunk0;
-
-        const int64_t ir0_start = dr0 * ith0;
-        const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
-
-        const int64_t ir1_start = dr1 * ith1;
-        const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
-
-        ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
-
-        if (nth >= nchunk0 * nchunk1) {
-            break;
-        }
-
-        current_chunk = atomic_fetch_add_explicit(&params->threadpool->current_chunk, 1, memory_order_relaxed);
-    }
-}
-
-// ggml_compute_forward_mul_mat_id
-
-static void ggml_compute_forward_mul_mat_id(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * ids = dst->src[2];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const enum ggml_type type = src0->type;
-
-    const bool src1_cont = ggml_is_contiguous(src1);
-
-    ggml_vec_dot_t    const vec_dot         = type_traits_cpu[type].vec_dot;
-    enum ggml_type    const vec_dot_type    = type_traits_cpu[type].vec_dot_type;
-    ggml_from_float_t const from_float      = ggml_get_type_traits(vec_dot_type)->from_float;
-    int64_t           const matmul_num_cols = type_traits_cpu[type].ncols;
-    ggml_gemv_t       const gemv            = type_traits_cpu[type].gemv;
-
-    // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == ggml_type_size(type));
-    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    // row groups
-    const int n_ids = ids->ne[0]; // n_expert_used
-    const int n_as  = ne02;       // n_expert
-
-    char * wdata_src1_end = (src1->type == vec_dot_type) ?
-            (char *) params->wdata :
-            (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
-
-    struct mmid_row_mapping {
-        int32_t i1;
-        int32_t i2;
-    };
-
-    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
-    struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
-
-    if (src1->type != vec_dot_type) {
-        char * wdata = params->wdata;
-
-        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
-        const size_t nbw2 = nbw1*ne11;
-        const size_t nbw3 = nbw2*ne12;
-
-        assert(params->wsize >= ne13*nbw3);
-        GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-        for (int64_t i13 = 0; i13 < ne13; ++i13) {
-            for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
-                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
-                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
-                               ne10);
-                }
-            }
-        }
-    }
-
-#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
-
-    if (ith == 0) {
-        // initialize matrix_row_counts
-        memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
-
-        // group rows by src0 matrix
-        for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
-            for (int id = 0; id < n_ids; ++id) {
-                const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
-
-                assert(i02 >= 0 && i02 < n_as);
-
-                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
-                matrix_row_counts[i02] += 1;
-            }
-        }
-    }
-
-    ggml_barrier(params->threadpool);
-
-    // compute each matrix multiplication in sequence
-    for (int cur_a = 0; cur_a < n_as; ++cur_a) {
-        const int64_t cne1 = matrix_row_counts[cur_a];
-
-        if (cne1 == 0) {
-            continue;
-        }
-
-        const char * src0_cur = (const char *) src0->data + cur_a*nb02;
-
-        const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
-        const size_t row_size = ggml_row_size(vec_dot_type, ne10);
-
-        const int64_t nr0 = ne01; // src0 rows
-        const int64_t nr1 = cne1; // src1 rows
-
-        if (((ggml_n_dims(src0) - 1) == 2) && gemv) {
-            int64_t src0_cur_start = (ith * ne01) / nth;
-            int64_t src0_cur_end   = ((ith + 1) * ne01) / nth;
-            src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
-            src0_cur_end   = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
-            if (src0_cur_start >= src0_cur_end) return;
-
-            for (int ir1 = 0; ir1 < nr1; ir1++) {
-                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
-                const int id       = row_mapping.i1; // selected expert index
-
-                const int64_t  i11 = id % ne11;
-                const int64_t  i12 = row_mapping.i2; // row index in src1
-
-                const int64_t  i1 = id;  // selected expert index
-                const int64_t  i2 = i12; // row
-
-                const char * src1_col = (const char *) wdata +
-                    (src1_cont || src1->type != vec_dot_type
-                    ? (i11        + i12 * ne11) * row_size
-                    : (i11 * nb11 + i12 * nb12));
-
-                gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
-                     (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
-            }
-            continue;
-        }
-
-        // distribute the thread work across the inner or outer loop based on which one is larger
-
-        const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
-        const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
-
-        const int64_t ith0 = ith % nth0;
-        const int64_t ith1 = ith / nth0;
-
-        const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
-        const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
-
-        const int64_t ir010 = dr0*ith0;
-        const int64_t ir011 = MIN(ir010 + dr0, nr0);
-
-        const int64_t ir110 = dr1*ith1;
-        const int64_t ir111 = MIN(ir110 + dr1, nr1);
-
-        // threads with no work simply yield (not sure if it helps)
-        //if (ir010 >= ir011 || ir110 >= ir111) {
-        //    sched_yield();
-        //    continue;
-        //}
-
-        // block-tiling attempt
-        const int64_t blck_0 = 16;
-        const int64_t blck_1 = 16;
-
-        // attempt to reduce false-sharing (does not seem to make a difference)
-        float tmp[16];
-
-        for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
-            for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
-                for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                    const int64_t _i12 = ir1; // logical row index for this expert
-
-                    struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
-                    const int id       = row_mapping.i1; // selected expert index
-
-                    const int64_t  i11 = id % ne11;
-                    const int64_t  i12 = row_mapping.i2; // row index in src1
-
-                    const int64_t  i1 = id;  // selected expert index
-                    const int64_t  i2 = i12; // row
-
-                    // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
-                    //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
-                    //       the original src1 data pointer, so we should index using the indices directly
-                    // TODO: this is a bit of a hack, we should probably have a better way to handle this
-                    const char * src1_col = (const char *) wdata +
-                        (src1_cont || src1->type != vec_dot_type
-                        ? (i11      + i12*ne11)*row_size
-                        : (i11*nb11 + i12*nb12));
-
-                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
-
-                    //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                    //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
-                    //}
-
-                    for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                        vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
-                    }
-
-                    memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
-                }
-            }
-        }
-    }
-
-#undef MMID_MATRIX_ROW
-}
-
-// ggml_compute_forward_out_prod
-
-static void ggml_compute_forward_out_prod_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_ASSERT(ne0  == ne00);
-    GGML_ASSERT(ne1  == ne10);
-    GGML_ASSERT(ne2  == ne02);
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne3  == ne13);
-    GGML_ASSERT(ne03 == ne13);
-
-    // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    // GGML_ASSERT(nb0 <= nb1);
-    // GGML_ASSERT(nb1 <= nb2);
-    // GGML_ASSERT(nb2 <= nb3);
-
-    // nb01 >= nb00 - src0 is not transposed
-    //   compute by src0 rows
-
-    if (ith == 0) {
-        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
-    }
-    ggml_barrier(params->threadpool);
-
-    // dst[:,:,:,:] = 0
-    // for i2,i3:
-    //   for i1:
-    //     for i01:
-    //       for i0:
-    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
-
-    // parallelize by last three dimensions
-
-    // total rows in dst
-    const int64_t nr = ne1*ne2*ne3;
-
-    // rows per thread
-    const int64_t dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int64_t ir0 = dr*ith;
-    const int64_t ir1 = MIN(ir0 + dr, nr);
-
-    // block-tiling attempt
-    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
-    const int64_t blck_1 = 16;
-
-    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
-        const int64_t bir1 = MIN(bir + blck_1, ir1);
-        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
-            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
-            for (int64_t ir = bir; ir < bir1; ++ir) {
-                // dst indices
-                const int64_t i3 = ir/(ne2*ne1);
-                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
-                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-                const int64_t i02 = i2;
-                const int64_t i03 = i3;
-
-                //const int64_t i10 = i1;
-                const int64_t i12 = i2;
-                const int64_t i13 = i3;
-
-#if GGML_VEC_MAD_UNROLL > 2
-                const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
-                for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
-                    const int64_t i11 = i01;
-
-                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
-                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
-
-                    ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
-                }
-                for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
-                    const int64_t i11 = i01;
-
-                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
-                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
-
-                    ggml_vec_mad_f32(ne0, d, s0, *s1);
-                }
-#else
-                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
-                    const int64_t i11 = i01;
-
-                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
-                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
-
-                    ggml_vec_mad_f32(ne0, d, s0, *s1);
-                }
-#endif
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_out_prod_q_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const enum ggml_type type = src0->type;
-    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
-
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2  == ne12);
-    GGML_ASSERT(ne3  == ne13);
-
-    // we don't support permuted src0 dim0
-    GGML_ASSERT(nb00 == ggml_type_size(type));
-
-    // dst dim0 cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    // GGML_ASSERT(nb0 <= nb1);
-    // GGML_ASSERT(nb1 <= nb2);
-    // GGML_ASSERT(nb2 <= nb3);
-
-    GGML_ASSERT(ne0 == ne00);
-    GGML_ASSERT(ne1 == ne10);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
-
-    // nb01 >= nb00 - src0 is not transposed
-    //   compute by src0 rows
-
-    if (ith == 0) {
-        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
-    }
-    ggml_barrier(params->threadpool);
-
-    // parallelize by last three dimensions
-
-    // total rows in dst
-    const int64_t nr = ne1*ne2*ne3;
-
-    // rows per thread
-    const int64_t dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int64_t ir0 = dr*ith;
-    const int64_t ir1 = MIN(ir0 + dr, nr);
-
-    // dst[:,:,:,:] = 0
-    // for i2,i3:
-    //   for i1:
-    //     for i01:
-    //       for i0:
-    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
-
-    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
-
-    for (int64_t ir = ir0; ir < ir1; ++ir) {
-        // dst indices
-        const int64_t i3 = ir/(ne2*ne1);
-        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
-        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-        const int64_t i02 = i2;
-        const int64_t i03 = i3;
-
-        //const int64_t i10 = i1;
-        const int64_t i12 = i2;
-        const int64_t i13 = i3;
-
-        for (int64_t i01 = 0; i01 < ne01; ++i01) {
-            const int64_t i11 = i01;
-
-            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
-            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
-
-            dequantize_row_q(s0, wdata, ne0);
-            ggml_vec_mad_f32(ne0, d, wdata, *s1);
-        }
-    }
-}
-
-static void ggml_compute_forward_out_prod(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_Q4_0_4_4:
-        case GGML_TYPE_Q4_0_4_8:
-        case GGML_TYPE_Q4_0_8_8:
-            {
-                ggml_compute_forward_out_prod_q_f32(params, dst);
-            } break;
-        case GGML_TYPE_F16:
-            {
-                GGML_ABORT("fatal error"); // todo
-                // ggml_compute_forward_out_prod_f16_f32(params, dst);
-            }
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_out_prod_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_scale
-
-static void ggml_compute_forward_scale_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    // scale factor
-    float v;
-    memcpy(&v, dst->op_params, sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const size_t nb01 = src0->nb[1];
-
-    const size_t nb1 = dst->nb[1];
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        if (dst->data != src0->data) {
-            // src0 is same shape as dst => same indices
-            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
-        }
-        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
-    }
-}
-
-static void ggml_compute_forward_scale(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_scale_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_set
-
-static void ggml_compute_forward_set_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-
-    // view src0 and dst with these strides and data offset inbytes during set
-    // nb0 is implicitly element_size because src0 and dst are contiguous
-    size_t nb1     = ((int32_t *) dst->op_params)[0];
-    size_t nb2     = ((int32_t *) dst->op_params)[1];
-    size_t nb3     = ((int32_t *) dst->op_params)[2];
-    size_t offset  = ((int32_t *) dst->op_params)[3];
-    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
-
-    if (!inplace) {
-        if (params->ith == 0) {
-            // memcpy needs to be synchronized across threads to avoid race conditions.
-            // => do it in INIT phase
-            memcpy(
-                ((char *)  dst->data),
-                ((char *) src0->data),
-                ggml_nbytes(dst));
-        }
-        ggml_barrier(params->threadpool);
-    }
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(src1);
-    const int nc = src1->ne[0];
-
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
-
-    // src0 and dst as viewed during set
-    const size_t nb0 = ggml_element_size(src0);
-
-    const int im0 = (ne10 == 0 ? 0 : ne10-1);
-    const int im1 = (ne11 == 0 ? 0 : ne11-1);
-    const int im2 = (ne12 == 0 ? 0 : ne12-1);
-    const int im3 = (ne13 == 0 ? 0 : ne13-1);
-
-    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
-
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 and dst are viewed with shape of src1 and offset
-        // => same indices
-        const int i3 = ir/(ne12*ne11);
-        const int i2 = (ir - i3*ne12*ne11)/ne11;
-        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
-
-        ggml_vec_cpy_f32(nc,
-                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
-                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
-    }
-}
-
-static void ggml_compute_forward_set(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_set_f32(params, dst);
-            } break;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_BF16:
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_Q4_0_4_4:
-        case GGML_TYPE_Q4_0_4_8:
-        case GGML_TYPE_Q4_0_8_8:
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_cpy
-
-static void ggml_compute_forward_cpy(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    ggml_compute_forward_dup(params, dst);
-}
-
-// ggml_compute_forward_cont
-
-static void ggml_compute_forward_cont(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    ggml_compute_forward_dup(params, dst);
-}
-
-// ggml_compute_forward_reshape
-
-static void ggml_compute_forward_reshape(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    // NOP
-    UNUSED(params);
-    UNUSED(dst);
-}
-
-// ggml_compute_forward_view
-
-static void ggml_compute_forward_view(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * dst) {
-    // NOP
-    UNUSED(params);
-    UNUSED(dst);
-}
-
-// ggml_compute_forward_permute
-
-static void ggml_compute_forward_permute(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * dst) {
-    // NOP
-    UNUSED(params);
-    UNUSED(dst);
-}
-
-// ggml_compute_forward_transpose
-
-static void ggml_compute_forward_transpose(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * dst) {
-    // NOP
-    UNUSED(params);
-    UNUSED(dst);
-}
-
-// ggml_compute_forward_get_rows
-
-static void ggml_compute_forward_get_rows_q(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int64_t nc = ne00;
-    const int64_t nr = ggml_nelements(src1);
-
-    const enum ggml_type type = src0->type;
-    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
-
-    assert(ne0  == nc);
-    assert(ne02 == ne11);
-    assert(nb00 == ggml_type_size(type));
-    assert(ggml_nrows(dst) == nr);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int64_t i = ir0; i < ir1; ++i) {
-        const int64_t i12 = i/(ne11*ne10);
-        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
-        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
-        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
-
-        GGML_ASSERT(i01 >= 0 && i01 < ne01);
-
-        dequantize_row_q(
-                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
-                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
-    }
-}
-
-static void ggml_compute_forward_get_rows_f16(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int64_t nc = ne00;
-    const int64_t nr = ggml_nelements(src1);
-
-    assert(ne0  == nc);
-    assert(ne02 == ne11);
-    assert(nb00 == sizeof(ggml_fp16_t));
-    assert(ggml_nrows(dst) == nr);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int64_t i = ir0; i < ir1; ++i) {
-        const int64_t i12 = i/(ne11*ne10);
-        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
-        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
-        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
-
-        GGML_ASSERT(i01 >= 0 && i01 < ne01);
-
-        ggml_fp16_to_fp32_row(
-                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
-                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
-    }
-}
-
-static void ggml_compute_forward_get_rows_bf16(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int64_t nc = ne00;
-    const int64_t nr = ggml_nelements(src1);
-
-    assert(ne0  == nc);
-    assert(ne02 == ne11);
-    assert(nb00 == sizeof(ggml_bf16_t));
-    assert(ggml_nrows(dst) == nr);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int64_t i = ir0; i < ir1; ++i) {
-        const int64_t i12 = i/(ne11*ne10);
-        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
-        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
-        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
-
-        GGML_ASSERT(i01 >= 0 && i01 < ne01);
-
-        ggml_bf16_to_fp32_row(
-                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
-                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
-    }
-}
-
-static void ggml_compute_forward_get_rows_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int64_t nc = ne00;
-    const int64_t nr = ggml_nelements(src1);
-
-    assert(ne0  == nc);
-    assert(ne02 == ne11);
-    assert(nb00 == sizeof(float));
-    assert(ggml_nrows(dst) == nr);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int64_t i = ir0; i < ir1; ++i) {
-        const int64_t i12 = i/(ne11*ne10);
-        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
-        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
-        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
-
-        GGML_ASSERT(i01 >= 0 && i01 < ne01);
-
-        ggml_vec_cpy_f32(nc,
-                (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
-                (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
-    }
-}
-
-static void ggml_compute_forward_get_rows(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_Q4_0_4_4:
-        case GGML_TYPE_Q4_0_4_8:
-        case GGML_TYPE_Q4_0_8_8:
-            {
-                ggml_compute_forward_get_rows_q(params, dst);
-            } break;
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_get_rows_f16(params, dst);
-            } break;
-        case GGML_TYPE_BF16:
-            {
-                ggml_compute_forward_get_rows_bf16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-        case GGML_TYPE_I32:
-            {
-                ggml_compute_forward_get_rows_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-
-    //static bool first = true;
-    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-    //if (first) {
-    //    first = false;
-    //} else {
-    //    for (int k = 0; k < dst->ne[1]; ++k) {
-    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
-    //            for (int i = 0; i < 16; ++i) {
-    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
-    //            }
-    //            printf("\n");
-    //        }
-    //        printf("\n");
-    //    }
-    //    printf("\n");
-    //    exit(0);
-    //}
-}
-
-// ggml_compute_forward_get_rows_back
-
-static void ggml_compute_forward_get_rows_back_f32_f16(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_is_contiguous(dst));
-
-    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
-
-    memset(dst->data, 0, ggml_nbytes(dst));
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
-
-    GGML_ASSERT( dst->ne[0] == nc);
-    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
-
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
-
-        for (int j = 0; j < nc; ++j) {
-            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
-            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v);
-        }
-    }
-}
-
-static void ggml_compute_forward_get_rows_back_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    GGML_ASSERT(ggml_is_contiguous(dst));
-
-    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
-
-    memset(dst->data, 0, ggml_nbytes(dst));
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
-
-    GGML_ASSERT( dst->ne[0] == nc);
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
-
-        ggml_vec_add_f32(nc,
-                (float *) ((char *)  dst->data + r*dst->nb[1]),
-                (float *) ((char *)  dst->data + r*dst->nb[1]),
-                (float *) ((char *) src0->data + i*src0->nb[1]));
-    }
-}
-
-static void ggml_compute_forward_get_rows_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_get_rows_back_f32_f16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_get_rows_back_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-
-    //static bool first = true;
-    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-    //if (first) {
-    //    first = false;
-    //} else {
-    //    for (int k = 0; k < dst->ne[1]; ++k) {
-    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
-    //            for (int i = 0; i < 16; ++i) {
-    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
-    //            }
-    //            printf("\n");
-    //        }
-    //        printf("\n");
-    //    }
-    //    printf("\n");
-    //    exit(0);
-    //}
-}
-
-// ggml_compute_forward_diag
-
-static void ggml_compute_forward_diag_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    // TODO: handle transposed/permuted matrices
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT(ne00 == ne0);
-    GGML_ASSERT(ne00 == ne1);
-    GGML_ASSERT(ne01 == 1);
-    GGML_ASSERT(ne02 == ne2);
-    GGML_ASSERT(ne03 == ne3);
-
-    GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(nb0  == sizeof(float));
-
-    for (int i3 = 0; i3 < ne3; i3++) {
-        for (int i2 = 0; i2 < ne2; i2++) {
-            for (int i1 = 0; i1 < ne1; i1++) {
-                float * d = (float *)((char *)  dst->data + i3*nb3  + i2*nb2 + i1*nb1);
-                float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
-                for (int i0 = 0; i0 < i1; i0++) {
-                    d[i0] = 0;
-                }
-                d[i1] = s[i1];
-                for (int i0 = i1+1; i0 < ne0; i0++) {
-                    d[i0] = 0;
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_diag(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_diag_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_diag_mask_inf
-
-static void ggml_compute_forward_diag_mask_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const float value) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int  n_past  = ((int32_t *) dst->op_params)[0];
-    const bool inplace = src0->data == dst->data;
-
-    GGML_ASSERT(n_past >= 0);
-
-    if (!inplace) {
-        if (ith == 0) {
-            // memcpy needs to be synchronized across threads to avoid race conditions.
-            // => do it in INIT phase
-            GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-            GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
-            memcpy(
-                ((char *)  dst->data),
-                ((char *) src0->data),
-                ggml_nbytes(dst));
-        }
-        ggml_barrier(params->threadpool);
-    }
-
-    // TODO: handle transposed/permuted matrices
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-    const int nr = src0->ne[1];
-    const int nz = n/nr;
-
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    for (int k = 0; k < nz; k++) {
-        for (int j = ith; j < nr; j += nth) {
-            for (int i = n_past; i < nc; i++) {
-                if (i > n_past + j) {
-                    *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_diag_mask_inf(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-static void ggml_compute_forward_diag_mask_zero(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_diag_mask_f32(params, dst, 0);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_soft_max
-
-static void ggml_compute_forward_soft_max_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    assert(ggml_is_contiguous(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
-
-    // TODO: handle transposed/permuted matrices
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
-
-    // TODO: is this supposed to be ceil instead of floor?
-    //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
-    const uint32_t n_head      = ne02;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
-
-    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        // ALiBi
-        const uint32_t h = (i1/ne01)%ne02; // head
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
-
-        // broadcast the mask across rows
-        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
-        ggml_vec_cpy_f32  (nc, wp, sp);
-        ggml_vec_scale_f32(nc, wp, scale);
-        if (mp_f32) {
-            if (use_f16) {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
-                }
-            } else {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*mp_f32[i];
-                }
-            }
-        }
-
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(wp[i]));
-        }
-#endif
-
-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, wp);
-
-        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
-        assert(sum > 0.0);
-
-        sum = 1.0/sum;
-        ggml_vec_scale_f32(nc, dp, sum);
-
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(dp[i]));
-            assert(!isinf(dp[i]));
-        }
-#endif
-    }
-}
-
-static void ggml_compute_forward_soft_max(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_soft_max_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-
-// ggml_compute_forward_soft_max_back
-
-static void ggml_compute_forward_soft_max_back_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_are_same_shape(src1, dst));
-
-    // TODO: handle transposed/permuted matrices
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
-        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
-
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(dy[i]));
-            assert(!isnan(y[i]));
-        }
-#endif
-        // Jii = yi - yi*yi
-        // Jij = -yi*yj
-        // J = diag(y)-y.T*y
-        // dx = J * dy
-        // dxk = sum_i(Jki * dyi)
-        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
-        // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
-        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
-        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
-        // dxk = -yk * dot(y, dy) + yk*dyk
-        // dxk = yk * (- dot(y, dy) + dyk)
-        // dxk = yk * (dyk - dot(y, dy))
-        //
-        // post-order:
-        // dot_y_dy := dot(y, dy)
-        // dx := dy
-        // dx := dx - dot_y_dy
-        // dx := dx * y
-
-        // linear runtime, no additional memory
-        float dot_y_dy = 0;
-        ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
-        ggml_vec_cpy_f32 (nc, dx, dy);
-        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
-        ggml_vec_mul_f32 (nc, dx, dx, y);
-
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(dx[i]));
-            assert(!isinf(dx[i]));
-        }
-#endif
-    }
-}
-
-static void ggml_compute_forward_soft_max_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_soft_max_back_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_clamp
-
-static void ggml_compute_forward_clamp_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    float min;
-    float max;
-    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-
-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    for (int j = ith; j < n; j += nth) {
-        float * dst_ptr  = (float *) ((char *)  dst->data + j*nb1);
-        float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
-
-        for (int i = 0; i < nc; i++) {
-            dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
-        }
-    }
-}
-
-static void ggml_compute_forward_clamp(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_clamp_f32(params, dst);
-            } break;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_BF16:
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_TQ1_0:
-        case GGML_TYPE_TQ2_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_Q8_K:
-        case GGML_TYPE_Q4_0_4_4:
-        case GGML_TYPE_Q4_0_4_8:
-        case GGML_TYPE_Q4_0_8_8:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_I64:
-        case GGML_TYPE_F64:
-        case GGML_TYPE_COUNT:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_rope
-
-static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
-    return 1 - MIN(1, MAX(0, y));
-}
-
-// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
-// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-static void rope_yarn(
-    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta) {
-    // Get n-d rotational scaling corrected for extrapolation
-    float theta_interp = freq_scale * theta_extrap;
-    float theta = theta_interp;
-    if (ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
-        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
-        // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
-    }
-    *cos_theta = cosf(theta) * mscale;
-    *sin_theta = sinf(theta) * mscale;
-}
-
-// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
-// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
-    return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
-}
-
-static void ggml_rope_cache_init(
-     float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
-     float * cache, float sin_sign, float theta_scale) {
-    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
-    float theta = theta_base;
-    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
-        rope_yarn(
-            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
-        );
-        cache[i0 + 1] *= sin_sign;
-
-        theta *= theta_scale;
-    }
-}
-
-void ggml_rope_yarn_corr_dims(
-    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
-) {
-    // start and end correction dims
-    float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
-    float end   =  ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
-    dims[0] = MAX(0, start);
-    dims[1] = MIN(n_dims - 1, end);
-}
-
-static void ggml_compute_forward_rope_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const bool forward) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * src2 = dst->src[2];
-
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-
-    //const int n_past     = ((int32_t *) dst->op_params)[0];
-    const int n_dims     = ((int32_t *) dst->op_params)[1];
-    const int mode       = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
-    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
-    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(dst);
-
-    GGML_ASSERT(n_dims <= ne0);
-    GGML_ASSERT(n_dims % 2 == 0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    // row index used to determine which thread to use
-    int ir = 0;
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
-    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-
-    const float * freq_factors = NULL;
-    if (src2 != NULL) {
-        GGML_ASSERT(src2->type == GGML_TYPE_F32);
-        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-        freq_factors = (const float *) src2->data;
-    }
-
-    // backward process uses inverse rotation by cos and sin.
-    // cos and sin build a rotation matrix, where the inverse is the transpose.
-    // this essentially just switches the sign of sin.
-    const float sin_sign = forward ? 1.0f : -1.0f;
-
-    const int32_t * pos = (const int32_t *) src1->data;
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
-            const int64_t p = pos[i2];
-
-            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
-                if (ir++ < ir0) continue;
-                if (ir   > ir1) break;
-
-                if (!is_neox) {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        const float x0 = src[0];
-                        const float x1 = src[1];
-
-                        dst_data[0] = x0*cos_theta - x1*sin_theta;
-                        dst_data[1] = x0*sin_theta + x1*cos_theta;
-                    }
-                } else {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                        const int64_t ic = i0/2;
-
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-                        const float x0 = src[0];
-                        const float x1 = src[n_dims/2];
-
-                        dst_data[0]        = x0*cos_theta - x1*sin_theta;
-                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-                    }
-                }
-
-                for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                    const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                    float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                    dst_data[0] = src[0];
-                    dst_data[1] = src[1];
-                }
-            }
-        }
-    }
-}
-
-// TODO: deduplicate f16/f32 code
-static void ggml_compute_forward_rope_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const bool forward) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * src2 = dst->src[2];
-
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-
-    //const int n_past     = ((int32_t *) dst->op_params)[0];
-    const int n_dims     = ((int32_t *) dst->op_params)[1];
-    const int mode       = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
-    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
-    GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(dst);
-
-    GGML_ASSERT(n_dims <= ne0);
-    GGML_ASSERT(n_dims % 2 == 0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    // row index used to determine which thread to use
-    int ir = 0;
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
-    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-
-    const float * freq_factors = NULL;
-    if (src2 != NULL) {
-        GGML_ASSERT(src2->type == GGML_TYPE_F32);
-        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-        freq_factors = (const float *) src2->data;
-    }
-
-    // backward process uses inverse rotation by cos and sin.
-    // cos and sin build a rotation matrix, where the inverse is the transpose.
-    // this essentially just switches the sign of sin.
-    const float sin_sign = forward ? 1.0f : -1.0f;
-
-    const int32_t * pos = (const int32_t *) src1->data;
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
-            const int64_t p = pos[i2];
-
-            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
-                if (ir++ < ir0) continue;
-                if (ir   > ir1) break;
-
-                if (!is_neox) {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        const float x0 = GGML_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_FP16_TO_FP32(src[1]);
-
-                        dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
-                } else {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                        const int64_t ic = i0/2;
-
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-                        const float x0 = GGML_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
-
-                        dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
-                }
-
-                for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                    ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                    dst_data[0] = src[0];
-                    dst_data[1] = src[1];
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_rope(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_rope_f16(params, dst, true);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_rope_f32(params, dst, true);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_rope_back
-
-static void ggml_compute_forward_rope_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_rope_f16(params, dst, false);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_rope_f32(params, dst, false);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_conv_transpose_1d
-
-static void ggml_compute_forward_conv_transpose_1d_f16_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nk = ne00*ne01*ne02;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (ith == 0) {
-        memset(params->wdata, 0, params->wsize);
-
-        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
-                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        dst_data[i00*ne02 + i02] = src[i00];
-                    }
-                }
-            }
-        }
-
-        // permute source data (src1) from (L x Cin) to (Cin x L)
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
-            ggml_fp16_t * dst_data = wdata;
-
-            for (int64_t i11 = 0; i11 < ne11; i11++) {
-                const float * const src = (float *)((char *) src1->data + i11*nb11);
-                for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
-                }
-            }
-        }
-
-        // need to zero dst since we are accumulating into it
-        memset(dst->data, 0, ggml_nbytes(dst));
-    }
-    ggml_barrier(params->threadpool);
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-
-    // total rows in dst
-    const int nr = ne1;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
-    ggml_fp16_t * const wdata_src = wdata + nk;
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
-        for (int i10 = 0; i10 < ne10; i10++) {
-            const int i1n = i10*ne11;
-            for (int i00 = 0; i00 < ne00; i00++) {
-                float v = 0;
-                ggml_vec_dot_f16(ne02, &v, 0,
-                        (ggml_fp16_t *)    wdata_src + i1n, 0,
-                        (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
-                dst_data[i10*s0 + i00] += v;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_conv_transpose_1d_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nk = ne00*ne01*ne02;
-
-    GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (ith == 0) {
-        memset(params->wdata, 0, params->wsize);
-
-        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
-        {
-            float * const wdata = (float *) params->wdata + 0;
-
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
-                    float * dst_data = wdata + i01*ne00*ne02;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        dst_data[i00*ne02 + i02] = src[i00];
-                    }
-                }
-            }
-        }
-
-        // prepare source data (src1)
-        {
-            float * const wdata = (float *) params->wdata + nk;
-            float * dst_data = wdata;
-
-            for (int64_t i11 = 0; i11 < ne11; i11++) {
-                const float * const src = (float *)((char *) src1->data + i11*nb11);
-                for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[i10*ne11 + i11] = src[i10];
-                }
-            }
-        }
-
-        // need to zero dst since we are accumulating into it
-        memset(dst->data, 0, ggml_nbytes(dst));
-    }
-    ggml_barrier(params->threadpool);
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-
-    // total rows in dst
-    const int nr = ne1;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * const wdata     = (float *) params->wdata + 0;
-    float * const wdata_src = wdata + nk;
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        float * wdata_kernel = wdata + i1*ne02*ne00;
-        for (int i10 = 0; i10 < ne10; i10++) {
-            const int i1n = i10*ne11;
-            for (int i00 = 0; i00 < ne00; i00++) {
-                float v = 0;
-                ggml_vec_dot_f32(ne02, &v, 0,
-                        wdata_src + i1n, 0,
-                        wdata_kernel + i00*ne02, 0, 1);
-                dst_data[i10*s0 + i00] += v;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_conv_transpose_1d(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_conv_transpose_1d_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_im2col_f32
-// src0: kernel [OC, IC, KH, KW]
-// src1: image [N, IC, IH, IW]
-// dst:  result [N, OH, OW, IC*KH*KW]
-static void ggml_compute_forward_im2col_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
-    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t N  = is_2D ? ne13 : ne12;
-    const int64_t IC = is_2D ? ne12 : ne11;
-    const int64_t IH = is_2D ? ne11 : 1;
-    const int64_t IW = ne10;
-
-    const int64_t KH = is_2D ? ne01 : 1;
-    const int64_t KW = ne00;
-
-    const int64_t OH = is_2D ? ne2 : 1;
-    const int64_t OW = ne1;
-
-    int ofs0 = is_2D ? nb13 : nb12;
-    int ofs1 = is_2D ? nb12 : nb11;
-
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
-    {
-        float * const wdata = (float *) dst->data;
-
-        for (int64_t in = 0; in < N; in++) {
-            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
-                for (int64_t iow = 0; iow < OW; iow++) {
-                    for (int64_t iic = ith; iic < IC; iic += nth) {
-
-                        // micro kernel
-                        float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
-
-                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
-                            for (int64_t ikw = 0; ikw < KW; ikw++) {
-                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
-                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
-
-                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
-                                } else {
-                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-
-// ggml_compute_forward_im2col_f16
-// src0: kernel [OC, IC, KH, KW]
-// src1: image [N, IC, IH, IW]
-// dst:  result [N, OH, OW, IC*KH*KW]
-static void ggml_compute_forward_im2col_f16(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F16);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
-    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t N  = is_2D ? ne13 : ne12;
-    const int64_t IC = is_2D ? ne12 : ne11;
-    const int64_t IH = is_2D ? ne11 : 1;
-    const int64_t IW = ne10;
-
-    const int64_t KH = is_2D ? ne01 : 1;
-    const int64_t KW = ne00;
-
-    const int64_t OH = is_2D ? ne2 : 1;
-    const int64_t OW = ne1;
-
-    int ofs0 = is_2D ? nb13 : nb12;
-    int ofs1 = is_2D ? nb12 : nb11;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
-    {
-        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
-
-        for (int64_t in = 0; in < N; in++) {
-            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
-                for (int64_t iow = 0; iow < OW; iow++) {
-                    for (int64_t iic = ith; iic < IC; iic += nth) {
-
-                        // micro kernel
-                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
-
-                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
-                            for (int64_t ikw = 0; ikw < KW; ikw++) {
-                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
-                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
-
-                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
-                                } else {
-                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_im2col(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-    switch (dst->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_im2col_f16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_im2col_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_im2col_back_f32
-
-static void ggml_compute_forward_im2col_back_f32(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
-    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t N  = is_2D ? ne3 : ne2;
-    const int64_t IC = is_2D ? ne2 : ne1;
-    const int64_t IH = is_2D ? ne1 : 1;
-    const int64_t IW = ne0;
-
-    const int64_t KH = is_2D ? ne01 : 1;
-    const int64_t KW = ne00;
-
-    const int64_t OH = is_2D ? ne12 : 1;
-    const int64_t OW = ne11;
-
-    int ofs0 = is_2D ? nb3 : nb2;
-    int ofs1 = is_2D ? nb2 : nb1;
-
-    GGML_ASSERT(nb0  == sizeof(float));
-
-    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
-    {
-        float * const wdata = (float *) dst->data;
-
-        for (int64_t in = 0; in < N; in++) {
-            for (int64_t iic = ith; iic < IC; iic += nth) {
-                for (int64_t iih = 0; iih < IH; iih++) {
-                    for (int64_t iiw = 0; iiw < IW; iiw++) {
-
-                        // micro kernel
-                        float grad = 0.0f;
-                        for (int64_t ikh = 0; ikh < KH; ikh++) {
-                            for (int64_t ikw = 0; ikw < KW; ikw++) {
-                                // For s0 > 1 some values were skipped over in the forward pass.
-                                // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
-                                const int64_t tmpw = (iiw + p0 - ikw*d0);
-                                if (tmpw % s0 != 0) {
-                                    continue;
-                                }
-                                const int64_t iow = tmpw / s0;
-
-                                // Equivalent logic as above except for s1.
-                                int64_t ioh;
-                                if (is_2D) {
-                                    const int64_t tmph = iih + p1 - ikh*d1;
-
-                                    if (tmph % s1 != 0) {
-                                        continue;
-                                    }
-
-                                    ioh = tmph / s1;
-                                } else {
-                                    ioh = 0;
-                                }
-
-                                if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
-                                    continue;
-                                }
-
-                                const float * const src_data = (const float *) src1->data
-                                    + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                                grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
-                            }
-                        }
-                        float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
-                        dst_data[iih*IW + iiw] = grad;
-                    }
-                }
-            }
-        }
-    }
-}
-
-// ggml_compute_forward_conv_transpose_2d
-
-static void ggml_compute_forward_conv_transpose_2d(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nk = ne00*ne01*ne02*ne03;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (ith == 0) {
-        memset(params->wdata, 0, params->wsize);
-
-        // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
-                    ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
-                        }
-                    }
-                }
-            }
-        }
-
-        // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
-            for (int i12 = 0; i12 < ne12; i12++) {
-                for (int i11 = 0; i11 < ne11; i11++) {
-                    const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
-                    ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
-                    for (int i10 = 0; i10 < ne10; i10++) {
-                        dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]);
-                    }
-                }
-            }
-        }
-
-        memset(dst->data, 0, ggml_nbytes(dst));
-    }
-    ggml_barrier(params->threadpool);
-
-    const int32_t stride = ggml_get_op_params_i32(dst, 0);
-
-    // total patches in dst
-    const int np = ne2;
-
-    // patches per thread
-    const int dp = (np + nth - 1)/nth;
-
-    // patch range for this thread
-    const int ip0 = dp*ith;
-    const int ip1 = MIN(ip0 + dp, np);
-
-    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-    ggml_fp16_t * const wdata_src = wdata + nk;
-
-    for (int i2 = ip0; i2 < ip1; i2++) { // Cout
-        float * dst_data = (float *)((char *) dst->data + i2*nb2);
-        ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
-        for (int i11 = 0; i11 < ne11; i11++) {
-            for (int i10 = 0; i10 < ne10; i10++) {
-                const int i1n = i11*ne10*ne12 + i10*ne12;
-                for (int i01 = 0; i01 < ne01; i01++) {
-                    for (int i00 = 0; i00 < ne00; i00++) {
-                        float v = 0;
-                        ggml_vec_dot_f16(ne03, &v, 0,
-                                wdata_src + i1n, 0,
-                                wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
-                        dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
-                    }
-                }
-            }
-        }
-    }
-}
-
-// ggml_compute_forward_pool_1d_sk_p0
-
-static void ggml_compute_forward_pool_1d_sk_p0(
-        const struct ggml_compute_params * params,
-        const enum ggml_op_pool op,
-        const int k,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src = dst->src[0];
-
-    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    const char * cdata = (const char *)src->data;
-    const char * const data_end = cdata + ggml_nbytes(src);
-    float * drow = (float *)dst->data;
-
-    const int64_t rs = dst->ne[0];
-
-    while (cdata < data_end) {
-        const void * srow = (const void *)cdata;
-        int j = 0;
-        for (int64_t i = 0; i < rs; ++i) {
-            switch (op) {
-                case GGML_OP_POOL_AVG:   drow[i] = 0;        break;
-                case GGML_OP_POOL_MAX:   drow[i] = -FLT_MAX; break;
-                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
-            }
-            for (int ki = 0; ki < k; ++ki) {
-                const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
-                switch (op) {
-                    case GGML_OP_POOL_AVG:                         drow[i] += srow_j; break;
-                    case GGML_OP_POOL_MAX:   if (srow_j > drow[i]) drow[i]  = srow_j; break;
-                    case GGML_OP_POOL_COUNT:                       GGML_ABORT("fatal error");
-                }
-                ++j;
-            }
-            switch (op) {
-                case GGML_OP_POOL_AVG:         drow[i] /= k; break;
-                case GGML_OP_POOL_MAX:                       break;
-                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
-            }
-        }
-
-        cdata += src->nb[1];
-        drow  += rs;
-    }
-}
-
-// ggml_compute_forward_pool_1d
-
-static void ggml_compute_forward_pool_1d(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const int32_t * opts = (const int32_t *)dst->op_params;
-    enum ggml_op_pool op = opts[0];
-    const int k0 = opts[1];
-    const int s0 = opts[2];
-    const int p0 = opts[3];
-    GGML_ASSERT(p0 == 0); // padding not supported
-    GGML_ASSERT(k0 == s0); // only s = k supported
-
-    ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
-}
-
-// ggml_compute_forward_pool_2d
-
-static void ggml_compute_forward_pool_2d(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src = dst->src[0];
-
-    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    const int32_t * opts = (const int32_t *)dst->op_params;
-    enum ggml_op_pool op = opts[0];
-    const int k0 = opts[1];
-    const int k1 = opts[2];
-    const int s0 = opts[3];
-    const int s1 = opts[4];
-    const int p0 = opts[5];
-    const int p1 = opts[6];
-    const char * cdata = (const char*)src->data;
-    const char * const data_end = cdata + ggml_nbytes(src);
-
-    const int64_t px = dst->ne[0];
-    const int64_t py = dst->ne[1];
-    const int64_t pa = px * py;
-
-    float * dplane = (float *)dst->data;
-
-    const int ka = k0 * k1;
-    const int offset0 = -p0;
-    const int offset1 = -p1;
-
-    while (cdata < data_end) {
-        for (int oy = 0; oy < py; ++oy) {
-            float * const drow = dplane + oy * px;
-            for (int ox = 0; ox < px; ++ox) {
-                float * const out =  drow + ox;
-                switch (op) {
-                    case GGML_OP_POOL_AVG:     *out = 0;        break;
-                    case GGML_OP_POOL_MAX:     *out = -FLT_MAX; break;
-                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
-                }
-
-                const int ix = offset0 + ox * s0;
-                const int iy = offset1 + oy * s1;
-
-                for (int ky = 0; ky < k1; ++ky) {
-                    if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
-                    const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
-                    for (int kx = 0; kx < k0; ++kx) {
-                        int j = ix + kx;
-                        if (j < 0 || j >= src->ne[0]) continue;
-                        const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
-                        switch (op) {
-                            case GGML_OP_POOL_AVG:                     *out += srow_j; break;
-                            case GGML_OP_POOL_MAX: if (srow_j > *out)  *out  = srow_j; break;
-                            case GGML_OP_POOL_COUNT:               GGML_ABORT("fatal error");
-                        }
-                    }
-                }
-                switch (op) {
-                    case GGML_OP_POOL_AVG:           *out /= ka; break;
-                    case GGML_OP_POOL_MAX:                       break;
-                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
-                }
-            }
-        }
-
-        cdata  += src->nb[2];
-        dplane += pa;
-    }
-}
-
-// ggml_compute_forward_pool_2d_back
-
-static void ggml_compute_forward_pool_2d_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src  = dst->src[0];
-    const struct ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
-
-    assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    const int32_t * opts = (const int32_t *)dst->op_params;
-    enum ggml_op_pool op = opts[0];
-    const int k0 = opts[1];
-    const int k1 = opts[2];
-    const int s0 = opts[3];
-    const int s1 = opts[4];
-    const int p0 = opts[5];
-    const int p1 = opts[6];
-
-    char       * cdata  = (char       *) dst->data;
-    const char * cdataf = (const char *) dstf->data;
-    const char * const data_end = cdata + ggml_nbytes(dst);
-
-    GGML_ASSERT(params->ith == 0);
-    memset(cdata, 0, ggml_nbytes(dst));
-
-    const int64_t px = src->ne[0];
-    const int64_t py = src->ne[1];
-    const int64_t pa = px * py;
-
-    const float * splane = (const float *) src->data;
-
-    const int ka = k0 * k1;
-    const int offset0 = -p0;
-    const int offset1 = -p1;
-
-    while (cdata < data_end) {
-        for (int oy = 0; oy < py; ++oy) {
-            const float * const srow = splane + oy * px;
-            for (int ox = 0; ox < px; ++ox) {
-                const float grad0 = srow[ox];
-
-                const int ix = offset0 + ox * s0;
-                const int iy = offset1 + oy * s1;
-
-                if (op == GGML_OP_POOL_MAX) {
-                    float maxval = -FLT_MAX;
-                    int kxmax = -1;
-                    int kymax = -1;
-
-                    for (int ky = 0; ky < k1; ++ky) {
-                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
-                            continue;
-                        }
-                        const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
-                        for (int kx = 0; kx < k0; ++kx) {
-                            int j = ix + kx;
-                            if (j < 0 || j >= dst->ne[0]) {
-                                continue;
-                            }
-
-                            const float val = dst->type == GGML_TYPE_F32 ?
-                                ((const float *) drowf)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
-                            if (val <= maxval) {
-                                continue;
-                            }
-
-                            maxval = val;
-                            kxmax = kx;
-                            kymax = ky;
-                        }
-                    }
-
-                    if (kxmax == -1 || kymax == -1) {
-                        continue;
-                    }
-
-                    void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
-                    const int j = ix + kxmax;
-                    if (dst->type == GGML_TYPE_F32) {
-                        ((float *) drow)[j] += grad0;
-                    } else {
-                        ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad0 + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
-                    }
-                } else if (op == GGML_OP_POOL_AVG) {
-                    const float grad = grad0 / ka;
-
-                    for (int ky = 0; ky < k1; ++ky) {
-                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
-                            continue;
-                        }
-                        void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
-                        for (int kx = 0; kx < k0; ++kx) {
-                            int j = ix + kx;
-                            if (j < 0 || j >= dst->ne[0]) {
-                                continue;
-                            }
-
-                            if (dst->type == GGML_TYPE_F32) {
-                                ((float *) drow)[j] += grad;
-                            } else {
-                                ((ggml_fp16_t *) drow)[j] += GGML_FP32_TO_FP16(grad);
-                            }
-                        }
-                    }
-                } else {
-                    GGML_ASSERT(false);
-                }
-            }
-        }
-
-        cdata  += dst->nb[2];
-        cdataf += dst->nb[2];
-        splane += pa;
-    }
-}
-
-// ggml_compute_forward_upscale
-
-static void ggml_compute_forward_upscale_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const float sf0 = (float)ne0/src0->ne[0];
-    const float sf1 = (float)ne1/src0->ne[1];
-    const float sf2 = (float)ne2/src0->ne[2];
-    const float sf3 = (float)ne3/src0->ne[3];
-
-    // TODO: optimize
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        const int64_t i03 = i3 / sf3;
-        for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
-            const int64_t i02 = i2 / sf2;
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
-                const int64_t i01 = i1 / sf1;
-                for (int64_t i0 = 0; i0 < ne0; i0++) {
-                    const int64_t i00 = i0 / sf0;
-
-                    const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                          float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);
-
-                    *y = *x;
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_upscale(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_upscale_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-
-// ggml_compute_forward_pad
-
-static void ggml_compute_forward_pad_f32(
-    const struct ggml_compute_params * params,
-          struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    float * dst_ptr = (float *) dst->data;
-
-    // TODO: optimize
-
-    for (int64_t i2 = 0; i2 < ne2; ++i2) {
-        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
-            for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                for (int64_t i3 = 0; i3 < ne3; ++i3) {
-                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
-                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-
-                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
-                        dst_ptr[dst_idx] = *src_ptr;
-                    } else {
-                        dst_ptr[dst_idx] = 0;
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_pad(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_pad_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-
-// ggml_compute_forward_arange
-
-static void ggml_compute_forward_arange_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    GGML_ASSERT(dst->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const float start = ggml_get_op_params_f32(dst, 0);
-    const float stop  = ggml_get_op_params_f32(dst, 1);
-    const float step  = ggml_get_op_params_f32(dst, 2);
-
-    const int64_t steps = (int64_t) ceilf((stop - start) / step);
-
-    GGML_ASSERT(ggml_nelements(dst) == steps);
-
-    for (int64_t i = ith; i < steps; i+= nth) {
-        float value = start + step * i;
-        ((float *)dst->data)[i] = value;
-    }
-}
-
-static void ggml_compute_forward_arange(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-    switch (dst->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_arange_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-static void ggml_compute_forward_timestep_embedding_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int dim = ggml_get_op_params_i32(dst, 0);
-    const int max_period = ggml_get_op_params_i32(dst, 1);
-
-    int half = dim / 2;
-
-    for (int64_t i = 0; i < ne00; i++) {
-        float * embed_data = (float *)((char *)  dst->data +  i*nb1);
-        for (int64_t j = ith; j < half; j += nth) {
-            float timestep = ((float *)src0->data)[i];
-            float freq = (float)expf(-logf(max_period) * j / half);
-            float arg = timestep * freq;
-            embed_data[j] = cosf(arg);
-            embed_data[j + half] = sinf(arg);
-        }
-        if (dim % 2 != 0 && ith == 0) {
-            embed_data[dim] = 0.f;
-        }
-    }
-}
-
-static void ggml_compute_forward_timestep_embedding(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_timestep_embedding_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_argsort
-
-static void ggml_compute_forward_argsort_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    GGML_ASSERT(nb0 == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t nr = ggml_nrows(src0);
-
-    enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
-
-    for (int64_t i = ith; i < nr; i += nth) {
-        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
-        const float * src_data = (float *)((char *) src0->data + i*nb01);
-
-        for (int64_t j = 0; j < ne0; j++) {
-            dst_data[j] = j;
-        }
-
-        // C doesn't have a functional sort, so we do a bubble sort instead
-        for (int64_t j = 0; j < ne0; j++) {
-            for (int64_t k = j + 1; k < ne0; k++) {
-                if ((order == GGML_SORT_ORDER_ASC  && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
-                    (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
-                    int32_t tmp = dst_data[j];
-                    dst_data[j] = dst_data[k];
-                    dst_data[k] = tmp;
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_argsort(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_argsort_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_flash_attn_ext
-
-static void ggml_compute_forward_flash_attn_ext_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * q,
-        const struct ggml_tensor * k,
-        const struct ggml_tensor * v,
-        const struct ggml_tensor * mask,
-        struct ggml_tensor * dst) {
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-
-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne2 == N);
-
-    // input tensor rows must be contiguous
-    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
-    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
-    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nev0 == D);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    // broadcast factors
-    const int64_t rk2 = neq2/nek2;
-    const int64_t rk3 = neq3/nek3;
-
-    const int64_t rv2 = neq2/nev2;
-    const int64_t rv3 = neq3/nev3;
-
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float scale         = 1.0f;
-    float max_bias      = 0.0f;
-    float logit_softcap = 0.0f;
-
-    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
-
-    if (logit_softcap != 0) {
-        scale /= logit_softcap;
-    }
-
-    const uint32_t n_head      = neq2;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    enum ggml_type    const k_vec_dot_type = type_traits_cpu[k->type].vec_dot_type;
-    ggml_from_float_t const q_to_vec_dot   = ggml_get_type_traits(k_vec_dot_type)->from_float;
-    ggml_vec_dot_t    const kq_vec_dot     = type_traits_cpu[k->type].vec_dot;
-    ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;
-
-    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
-    GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
-
-    // loop over n_batch and n_head
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        const uint32_t h = iq2; // head index
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float S = 0.0f;      // sum
-        float M = -INFINITY; // maximum KQ value
-
-        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
-
-        if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
-        } else {
-            memset(VKQ32, 0, D*sizeof(float));
-        }
-
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
-
-        // k indices
-        const int ik3 = iq3 / rk3;
-        const int ik2 = iq2 / rk2;
-
-        // v indices
-        const int iv3 = iq3 / rv3;
-        const int iv2 = iq2 / rv2;
-
-        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, D);
-
-        // online softmax / attention
-        // loop over n_kv and n_head_kv
-        // ref: https://arxiv.org/pdf/2112.05682.pdf
-        for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
-            if (mv == -INFINITY) {
-                continue;
-            }
-
-            float s; // KQ value
-
-            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
-
-            s = s*scale; // scale KQ value
-
-            if (logit_softcap != 0.0f) {
-                s = logit_softcap*tanhf(s);
-            }
-
-            s += mv; // apply mask
-
-            const float Mold = M;
-
-            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
-            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
-
-            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
-
-            if (v->type == GGML_TYPE_F16) {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(D, VKQ16, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
-                }
-
-                // V += v*expf(s - M)
-                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
-            } else {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(D, VKQ32, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
-                }
-
-                v_to_float(v_data, V32, D);
-
-                // V += v*expf(s - M)
-                ggml_vec_mad_f32(D, VKQ32, V32, vs);
-            }
-
-            S = S*ms + vs; // scale and increment sum with partial sum
-        }
-
-        if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < D; ++d) {
-                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
-            }
-        }
-
-        // V /= S
-        const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(D, VKQ32, S_inv);
-
-        // dst indices
-        const int i1 = iq1;
-        const int i2 = iq2;
-        const int i3 = iq3;
-
-        // original
-        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
-
-        // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
-    }
-}
-
-static void ggml_compute_forward_flash_attn_ext(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * q,
-        const struct ggml_tensor * k,
-        const struct ggml_tensor * v,
-        const struct ggml_tensor * mask,
-        struct ggml_tensor * dst) {
-    switch (dst->op_params[3]) {
-        case GGML_PREC_DEFAULT:
-        case GGML_PREC_F32:
-            {
-                // uses F32 accumulators
-                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_flash_attn_back
-
-static void ggml_compute_forward_flash_attn_back_f32(
-        const struct ggml_compute_params * params,
-        const bool masked,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-    const struct ggml_tensor * k = dst->src[1];
-    const struct ggml_tensor * v = dst->src[2];
-    const struct ggml_tensor * d = dst->src[3];
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
-    GGML_TENSOR_LOCALS(int64_t, ned, d,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbd, d,   nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-    const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
-    const int mxDM = MAX(D, Mup);
-
-    // GGML_ASSERT(ne0 == D);
-    // GGML_ASSERT(ne1 == N);
-    GGML_ASSERT(P >= 0);
-
-    GGML_ASSERT(nbq0 == sizeof(float));
-    GGML_ASSERT(nbk0 == sizeof(float));
-    GGML_ASSERT(nbv0 == sizeof(float));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev1 == D);
-    GGML_ASSERT(ned0 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
-    GGML_ASSERT(ned1 == N);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (ith == 0) {
-        memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
-    }
-    ggml_barrier(params->threadpool);
-
-    const int64_t elem_q = ggml_nelements(q);
-    const int64_t elem_k = ggml_nelements(k);
-
-    enum ggml_type result_type = dst->type;
-    GGML_ASSERT(ggml_blck_size(result_type) == 1);
-    const size_t tsize = ggml_type_size(result_type);
-
-    const size_t offs_q = 0;
-    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
-    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
-
-    void * grad_q = (char *) dst->data;
-    void * grad_k = (char *) dst->data + offs_k;
-    void * grad_v = (char *) dst->data + offs_v;
-
-    const size_t nbgq1 = nb0*neq0;
-    const size_t nbgq2 = nb0*neq0*neq1;
-    const size_t nbgq3 = nb0*neq0*neq1*neq2;
-
-    const size_t nbgk1 = nb0*nek0;
-    const size_t nbgk2 = nb0*nek0*nek1;
-    const size_t nbgk3 = nb0*nek0*nek1*neq2;
-
-    const size_t nbgv1 = nb0*nev0;
-    const size_t nbgv2 = nb0*nev0*nev1;
-    const size_t nbgv3 = nb0*nev0*nev1*neq2;
-
-    // parallelize by k rows using ggml_vec_dot_f32
-
-    // total rows in k
-    const int nr = nek2*nek3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float scale = 1.0f/sqrtf(D);
-
-    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
-    // how often k2 (and v2) is repeated in q2
-    int nrep = neq2/nek2;
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int ik3 = ir/(nek2);
-        const int ik2 = ir - ik3*nek2;
-
-        const int iq3 = ik3;
-        const int id3 = ik3;
-        const int iv3 = ik3;
-        const int iv2 = ik2;
-
-        for (int irep = 0; irep < nrep; ++irep) {
-            const int iq2 = ik2 + irep*nek2;
-            const int id2 = iq2;
-
-            // (ik2 + irep*nek2) % nek2 == ik2
-            for (int iq1 = 0; iq1 < neq1; ++iq1) {
-                const int id1 = iq1;
-
-                // not sure about CACHE_LINE_SIZE_F32..
-                // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
-                float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
-                float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
-
-                for (int i = M; i < Mup; ++i) {
-                    S[i] = -INFINITY;
-                }
-
-                const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
-                for (int64_t ic = 0; ic < masked_begin; ++ic) {
-                    // k indices
-                    const int ik1 = ic;
-
-                    // S indices
-                    const int i1 = ik1;
-
-                    ggml_vec_dot_f32(neq0,
-                            S + i1, 0,
-                            (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                            (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
-                }
-
-                // scale
-                ggml_vec_scale_f32(masked_begin, S, scale);
-
-                for (int64_t i = masked_begin; i < M; i++) {
-                    S[i] = -INFINITY;
-                }
-
-                // softmax
-                // exclude known -INF S[..] values from max and loop
-                // dont forget to set their SM values to zero
-                {
-                    float max = -INFINITY;
-                    ggml_vec_max_f32(masked_begin, &max, S);
-
-                    ggml_float sum = 0.0;
-                    {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                        max = -max;
-                        vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
-                        vvexpf(SM, SM, &Mup);
-                        ggml_vec_sum_f32(Mup, &sum, SM);
-#else
-                        sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
-#endif
-                    }
-
-                    assert(sum > 0.0);
-
-                    sum = 1.0/sum;
-                    ggml_vec_scale_f32(masked_begin, SM, sum);
-
-                }
-
-                // step-by-step explanation
-                {
-                    // forward-process                    shape      grads from backward process
-                    // parallel_for ik2,ik3:
-                    //  for irep:
-                    //   iq2 = ik2 + irep*nek2
-                    //   k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,ik2,ik3]  += grad[kcur]
-                    //   q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
-                    //   v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iv2,iv3]  += grad[vcur]
-                    //   for iq1:
-                    //    kcur   = k[:D,:M,ik2,ik3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
-                    //    qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
-                    //    vcur   = v[:M,:D,iv2,iv3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
-                    //    S0     = -Inf                   [D,1,1,1]
-                    //   ~S1[i]  = dot(kcur[:D,i], qcur)
-                    //    S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
-                    //    S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
-                    //    S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
-                    //    S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
-                    //   ~S5[i]  = dot(vcur[:,i], S4)
-                    //    S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,id1,id2,id3]
-                    //   ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
-                    //    dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
-                    // dst                               backward-/ grad[dst]                 = d
-                    //
-                    // output gradients with their dependencies:
-                    //
-                    // grad[kcur] = grad[S1].T @ qcur
-                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
-                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
-                    // grad[S4]   = grad[S5] @ vcur
-                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
-                    // grad[qcur] = grad[S1]   @ kcur
-                    // grad[vcur] = grad[S5].T @ S4
-                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
-                    //
-                    // in post-order:
-                    //
-                    // S1         = qcur @ kcur.T
-                    // S2         = S1 * scale
-                    // S3         = diag_mask_inf(S2, P)
-                    // S4         = softmax(S3)
-                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
-                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
-                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
-                    // grad[qcur] = grad[S1]   @ kcur
-                    // grad[kcur] = grad[S1].T @ qcur
-                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
-                    //
-                    // using less variables (SM=S4):
-                    //
-                    // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
-                    // SM            = softmax(S)
-                    // S             = d[:D,iq1,iq2,iq3] @ vcur
-                    // dot_SM_gradSM = dot(SM, S)
-                    // S             = SM * (S - dot(SM, S))
-                    // S             = diag_mask_zero(S, P) * scale
-                    //
-                    // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
-                    // grad[k][:D,:M,ik2,ik3]  += S.T @ qcur
-                    // grad[v][:M,:D,iv2,iv3]  += d[:D,id1,id2,id3].T @ SM
-                }
-
-                // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
-                // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
-                // for ic:
-                //   S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
-                // exclude known future zero S[..] values from operation
-                ggml_vec_set_f32(masked_begin, S, 0);
-                for (int64_t ic = 0; ic < D; ++ic) {
-                    ggml_vec_mad_f32(masked_begin,
-                            S,
-                             (float *) ((char *) v->data + (          ic*nbv1  + iv2*nbv2 + iv3*nbv3)),
-                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
-                }
-
-                // S = SM * (S - dot(SM, S))
-                float dot_SM_gradSM = 0;
-                ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
-                ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
-                ggml_vec_mul_f32 (masked_begin, S, S, SM);
-
-                // S = diag_mask_zero(S, P) * scale
-                // already done by above ggml_vec_set_f32
-
-                // exclude known zero S[..] values from operation
-                ggml_vec_scale_f32(masked_begin, S, scale);
-
-                // S    shape [M,1]
-                // SM   shape [M,1]
-                // kcur shape [D,M]
-                // qcur shape [D,1]
-                // vcur shape [M,D]
-
-                // grad[q][:D,iq1,iq2,iq3] += S @ kcur
-                // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
-                // for ic:
-                //  grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
-                // exclude known zero S[..] values from loop
-                for (int64_t ic = 0; ic < masked_begin; ++ic) {
-                    ggml_vec_mad_f32(D,
-                            (float *) ((char *) grad_q  + (iq1*nbgq1 + iq2*nbgq2  + iq3*nbgq3)),
-                            (float *) ((char *) k->data + (ic*nbk1   + ik2*nbk2   + ik3*nbk3)),
-                            S[ic]);
-                }
-
-                // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
-                // for ic:
-                //  grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
-                //  grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
-                // exclude known zero S[..] values from loop
-                for (int64_t ic = 0; ic < masked_begin; ++ic) {
-                    ggml_vec_mad_f32(D,
-                            (float *) ((char *) grad_k  + (ic*nbgk1  + ik2*nbgk2  + ik3*nbgk3)),
-                            (float *) ((char *) q->data + (iq1*nbq1  + iq2*nbq2   + iq3*nbq3)),
-                            S[ic]);
-                }
-
-                // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T       @ SM
-                // for ic:
-                //  grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
-                //  grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3]         * SM[:M]
-                // exclude known zero SM[..] values from mad
-                for (int64_t ic = 0; ic < D; ++ic) {
-                    ggml_vec_mad_f32(masked_begin,
-                            (float *) ((char *) grad_v   + (          ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
-                            SM,
-                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2  + id3*nbd3)));
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_attn_back(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-
-    switch (q->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_ssm_conv
-
-static void ggml_compute_forward_ssm_conv_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // conv_x
-    const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nc  = src1->ne[0]; // d_conv
-    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
-    const int nr  = src0->ne[1]; // d_inner
-    const int n_t =  dst->ne[1]; // tokens per sequence
-    const int n_s =  dst->ne[2]; // number of sequences in the batch
-
-    GGML_ASSERT( dst->ne[0] == nr);
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-    const int ir  = ir1 - ir0;
-
-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            // {d_conv - 1 + n_t, d_inner, n_seqs}
-            // sliding window
-            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
-            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
-            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
-
-            // TODO: transpose the output for smaller strides for big batches?
-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // rowwise dot product
-                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
-                float sumf = 0.0f;
-
-                // d_conv
-                for (int i0 = 0; i0 < nc; ++i0) {
-                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
-                }
-                x[i1] = sumf;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_ssm_conv(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    switch (dst->src[0]->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_ssm_conv_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_ssm_scan
-
-static void ggml_compute_forward_ssm_scan_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // s
-    const struct ggml_tensor * src1 = dst->src[1]; // x
-    const struct ggml_tensor * src2 = dst->src[2]; // dt
-    const struct ggml_tensor * src3 = dst->src[3]; // A
-    const struct ggml_tensor * src4 = dst->src[4]; // B
-    const struct ggml_tensor * src5 = dst->src[5]; // C
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t nc  = src0->ne[0]; // d_state
-    const int64_t nr  = src0->ne[1]; // d_inner
-    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
-    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
-
-    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src2->nb[0] == sizeof(float));
-    GGML_ASSERT(src3->nb[0] == sizeof(float));
-    GGML_ASSERT(src4->nb[0] == sizeof(float));
-    GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C
-    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    // required for per-sequence offsets for states
-    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[3])
-    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-    const int ir  = ir1 - ir0;
-
-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-            const float * B  = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-            const float * C  = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-                  float * y  = (      float *) ((      char *) dst->data  + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-                  float * s  = (      float *) ((      char *) dst->data  + ir0*(src0->nb[1]) + i3*(src0->nb[2]) +     src1->nb[3]);  // {d_state, d_inner, n_s}
-
-            // use the output as the source for the next token-wise iterations
-            if (i2 > 0) { s0 = s; }
-
-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                float x_dt = x[i1] * dt_soft_plus;
-                float sumf = 0.0f;
-                // d_state
-                for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    // state = prev_state * dA + dB * x
-                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                    // y = rowwise_dotprod(state, C)
-                    sumf += state * C[i0];
-                    s[i] = state;
-                }
-                y[i1] = sumf;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_ssm_scan(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    switch (dst->src[0]->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_ssm_scan_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_win_part
-
-static void ggml_compute_forward_win_part_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    UNUSED(params);
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
-
-    const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
-    const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
-    const int32_t w    = ((const int32_t *)(dst->op_params))[2];
-
-    assert(ne00 == ne0);
-    assert(ne3  == nep0*nep1);
-
-    // TODO: optimize / multi-thread
-    for (int py = 0; py < nep1; ++py) {
-        for (int px = 0; px < nep0; ++px) {
-            const int64_t i3 = py*nep0 + px;
-            for (int64_t i2 = 0; i2 < ne2; ++i2) {
-                for (int64_t i1 = 0; i1 < ne1; ++i1) {
-                    for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                        const int64_t i02 = py*w + i2;
-                        const int64_t i01 = px*w + i1;
-                        const int64_t i00 = i0;
-
-                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0    + i1*ne0   + i0;
-                        const int64_t j =                  i02*ne01*ne00 + i01*ne00 + i00;
-
-                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
-                            ((float *) dst->data)[i] = 0.0f;
-                        } else {
-                            ((float *) dst->data)[i] = ((float *) src0->data)[j];
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_win_part(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_win_part_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_win_unpart
-
-static void ggml_compute_forward_win_unpart_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    UNUSED(params);
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
-
-    const int32_t w = ((const int32_t *)(dst->op_params))[0];
-
-    // padding
-    const int px = (w - ne1%w)%w;
-    //const int py = (w - ne2%w)%w;
-
-    const int npx = (px + ne1)/w;
-    //const int npy = (py + ne2)/w;
-
-    assert(ne0 == ne00);
-
-    // TODO: optimize / multi-thread
-    for (int64_t i2 = 0; i2 < ne2; ++i2) {
-        for (int64_t i1 = 0; i1 < ne1; ++i1) {
-            for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                const int ip2 = i2/w;
-                const int ip1 = i1/w;
-
-                const int64_t i02 = i2%w;
-                const int64_t i01 = i1%w;
-                const int64_t i00 = i0;
-
-                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
-                const int64_t j =                                  i2*ne1*ne0    + i1*ne0   + i0;
-
-                ((float *) dst->data)[j] = ((float *) src0->data)[i];
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_win_unpart(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_win_unpart_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-//gmml_compute_forward_unary
-
-static void ggml_compute_forward_unary(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const enum ggml_unary_op op = ggml_get_unary_op(dst);
-
-    switch (op) {
-        case GGML_UNARY_OP_ABS:
-            {
-                ggml_compute_forward_abs(params, dst);
-            } break;
-        case GGML_UNARY_OP_SGN:
-            {
-                ggml_compute_forward_sgn(params, dst);
-            } break;
-        case GGML_UNARY_OP_NEG:
-            {
-                ggml_compute_forward_neg(params, dst);
-            } break;
-        case GGML_UNARY_OP_STEP:
-            {
-                ggml_compute_forward_step(params, dst);
-            } break;
-        case GGML_UNARY_OP_TANH:
-            {
-                ggml_compute_forward_tanh(params, dst);
-            } break;
-        case GGML_UNARY_OP_ELU:
-            {
-                ggml_compute_forward_elu(params, dst);
-            } break;
-        case GGML_UNARY_OP_RELU:
-            {
-                ggml_compute_forward_relu(params, dst);
-            } break;
-        case GGML_UNARY_OP_SIGMOID:
-            {
-                ggml_compute_forward_sigmoid(params, dst);
-            } break;
-        case GGML_UNARY_OP_GELU:
-            {
-                ggml_compute_forward_gelu(params, dst);
-            } break;
-        case GGML_UNARY_OP_GELU_QUICK:
-            {
-                ggml_compute_forward_gelu_quick(params, dst);
-            } break;
-        case GGML_UNARY_OP_SILU:
-            {
-                ggml_compute_forward_silu(params, dst);
-            } break;
-        case GGML_UNARY_OP_HARDSWISH:
-            {
-                ggml_compute_forward_hardswish(params, dst);
-            } break;
-        case GGML_UNARY_OP_HARDSIGMOID:
-            {
-                ggml_compute_forward_hardsigmoid(params, dst);
-            } break;
-        case GGML_UNARY_OP_EXP:
-            {
-                ggml_compute_forward_exp(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_get_rel_pos
-
-static void ggml_compute_forward_get_rel_pos_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    UNUSED(params);
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int64_t w = ne1;
-
-    ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
-    ggml_fp16_t * dst_data  = (ggml_fp16_t *) dst->data;
-
-    for (int64_t i2 = 0; i2 < ne2; ++i2) {
-        for (int64_t i1 = 0; i1 < ne1; ++i1) {
-            const int64_t pos = (w - i1 - 1) + i2;
-            for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_get_rel_pos(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-        case GGML_TYPE_BF16:
-            {
-                ggml_compute_forward_get_rel_pos_f16(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_add_rel_pos
-
-static void ggml_compute_forward_add_rel_pos_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * src2 = dst->src[2];
-
-    const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
-    if (!inplace) {
-        if (params->ith == 0) {
-            memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
-        }
-        ggml_barrier(params->threadpool);
-    }
-    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
-
-    float * src1_data = (float *) src1->data;
-    float * src2_data = (float *) src2->data;
-    float * dst_data  = (float *) dst->data;
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    // total patches in dst
-    const int np = ne13;
-
-    // patches per thread
-    const int dp = (np + nth - 1)/nth;
-
-    // patch range for this thread
-    const int ip0 = dp*ith;
-    const int ip1 = MIN(ip0 + dp, np);
-
-    for (int64_t i13 = ip0; i13 < ip1; ++i13) {
-        for (int64_t i12 = 0; i12 < ne12; ++i12) {
-            for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
-                for (int64_t i10 = 0; i10 < ne10; ++i10) {
-                    const int64_t jp0  = jp1 + i10;
-                    const float src1_e = src1_data[jp0];
-                    const float src2_e = src2_data[jp0];
-
-                    const int64_t jdh = jp0 * ne10;
-                    const int64_t jdw = jdh - (ne10 - 1) * i10;
-
-                    for (int64_t j = 0; j < ne10; ++j) {
-                        dst_data[jdh + j     ] += src2_e;
-                        dst_data[jdw + j*ne10] += src1_e;
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_add_rel_pos(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_add_rel_pos_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_rwkv_wkv6
-
-static void ggml_compute_forward_rwkv_wkv6_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[3];
-    const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[2];
-    const int64_t n_seqs = dst->src[5]->ne[1];
-    const int64_t head_size = C / HEADS;
-
-    float * dst_data = (float *) dst->data;
-    float * state = ((float *) dst->data) + C * T;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    if (ith >= HEADS) {
-        return;
-    }
-
-    const int h_start = (HEADS * ith) / nth;
-    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
-                (HEADS * (ith + 1)) / nth : HEADS;
-
-    float * k =          (float *) dst->src[0]->data;
-    float * v =          (float *) dst->src[1]->data;
-    float * r =          (float *) dst->src[2]->data;
-    float * time_faaaa = (float *) dst->src[3]->data;
-    float * time_decay = (float *) dst->src[4]->data;
-
-    size_t t_stride = HEADS * head_size; // Same to C
-
-    size_t h_stride = C / HEADS;
-    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
-    size_t h_stride_2d = head_size * head_size;
-
-    if (ith == 0) {
-        memset(dst_data, 0, T * C * sizeof(float));
-    }
-    ggml_barrier(params->threadpool);
-
-
-    #if defined(__AVX__) && !defined(__AVX512F__)
-        #define GGML_F32X GGML_F32x8
-        #define GGML_F32X_SET1 GGML_F32x8_SET1
-        #define GGML_F32X_LOAD GGML_F32x8_LOAD
-        #define GGML_F32X_STORE GGML_F32x8_STORE
-        #define GGML_F32X_MUL GGML_F32x8_MUL
-        #define GGML_F32X_FMA GGML_F32x8_FMA
-        #define WKV_VECTOR_SIZE 8
-    #elif defined(__AVX512F__)
-        #define GGML_F32X GGML_F32x16
-        #define GGML_F32X_SET1 GGML_F32x16_SET1
-        #define GGML_F32X_LOAD GGML_F32x16_LOAD
-        #define GGML_F32X_STORE GGML_F32x16_STORE
-        #define GGML_F32X_MUL GGML_F32x16_MUL
-        #define GGML_F32X_FMA GGML_F32x16_FMA
-        #define WKV_VECTOR_SIZE 16
-    #elif defined(__ARM_NEON) && defined(__aarch64__)
-        #define GGML_F32X GGML_F32x4
-        #define GGML_F32X_SET1 GGML_F32x4_SET1
-        #define GGML_F32X_LOAD GGML_F32x4_LOAD
-        #define GGML_F32X_STORE GGML_F32x4_STORE
-        #define GGML_F32X_MUL GGML_F32x4_MUL
-        #define GGML_F32X_FMA GGML_F32x4_FMA
-        #define WKV_VECTOR_SIZE 4
-    #endif
-
-    #ifdef WKV_VECTOR_SIZE
-        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
-
-        for (int64_t t = 0; t < T; t++) {
-            size_t t_offset = t * t_stride;
-            size_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                size_t h_offset = h * h_stride;
-                size_t t_h_offset = t_offset + h_offset;
-                size_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t i = 0; i < head_size; i++) {
-                    size_t t_h_i_offset = t_h_offset + i;
-                    size_t h_i_offset = h_offset + i;
-                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                    float k_val = k[t_h_i_offset];
-                    float r_val = r[t_h_i_offset];
-                    float time_faaaa_val = time_faaaa[h_i_offset];
-                    float time_decay_val = time_decay[t_h_i_offset];
-
-                    // Broadcast scalar values to vectors
-                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
-                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
-                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
-                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
-
-                    for (int64_t j = 0; j < vec_count; j++) {
-                        size_t base_j = j * WKV_VECTOR_SIZE;
-                        size_t t_h_j_offset = t_h_offset + base_j;
-                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
-
-                        // Load x elements at once
-                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
-                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
-                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
-
-                        // Compute kv = v * k
-                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
-
-                        // Compute temp = kv * time_faaaa + prev_state
-                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
-
-                        // Update dst: dst += temp * r
-                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
-                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
-
-                        // Update state: state = prev_state * time_decay + kv
-                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
-                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
-                    }
-
-                    // Handle remaining elements, this will not be used.
-                    for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
-                        size_t t_h_j_offset = t_h_offset + j;
-                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
-                        float v_val = v[t_h_j_offset];
-                        float kv_val = v_val * k_val;
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                        dst_data[t_h_j_offset] += temp_val * r_val;
-                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
-                    }
-                }
-            }
-        }
-
-    #else
-        // basically fused operations:
-        // dst = r @ (time_faaaa * (k @ v) + state),
-        // state = time_decay * state + (k @ v),
-        // recursive through each token
-        for (int64_t t = 0; t < T; t++) {
-            size_t t_offset = t * t_stride;
-            size_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                size_t h_offset = h * h_stride;
-                size_t t_h_offset = t_offset + h_offset;
-                size_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t i = 0; i < head_size; i++) {
-                    size_t t_h_i_offset = t_h_offset + i;
-                    size_t h_i_offset = h_offset + i;
-                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                    float k_val = k[t_h_i_offset];
-                    float r_val = r[t_h_i_offset];
-                    float time_faaaa_val = time_faaaa[h_i_offset];
-                    // RWKV v6: different time_decay for each token.
-                    float time_decay_val = time_decay[t_h_i_offset];
-
-                    for (int64_t j = 0; j < head_size; j++) {
-                        size_t t_h_j_offset = t_h_offset + j;
-                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                        float v_val = v[t_h_j_offset];
-                        float kv_val = v_val * k_val;
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                        dst_data[t_h_j_offset] += temp_val * r_val;
-                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
-                    }
-                }
-            }
-        }
-    #endif
-}
-
-
-static void ggml_compute_forward_rwkv_wkv6(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_map_unary
-
-static void ggml_compute_forward_map_unary_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_map_unary(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_unary_f32(params, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_map_binary
-
-static void ggml_compute_forward_map_binary_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(src1));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])),
-                (float *) ((char *) src1->data + i*(src1->nb[1])));
-    }
-}
-
-static void ggml_compute_forward_map_binary(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_binary_f32(params, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_map_custom1
-
-static void ggml_compute_forward_map_custom1_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_custom1_op_f32_t fun) {
-
-    const struct ggml_tensor * a = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a);
-}
-
-// ggml_compute_forward_map_custom2
-
-static void ggml_compute_forward_map_custom2_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_custom2_op_f32_t fun) {
-
-    const struct ggml_tensor * a = dst->src[0];
-    const struct ggml_tensor * b = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a, b);
-}
-
-// ggml_compute_forward_map_custom3
-
-static void ggml_compute_forward_map_custom3_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_custom3_op_f32_t fun) {
-
-    const struct ggml_tensor * a = dst->src[0];
-    const struct ggml_tensor * b = dst->src[1];
-    const struct ggml_tensor * c = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a, b, c);
-}
-
-// ggml_compute_forward_map_custom1
-
-static void ggml_compute_forward_map_custom1(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * a = dst->src[0];
-
-    struct ggml_map_custom1_op_params p;
-    memcpy(&p, dst->op_params, sizeof(p));
-
-    p.fun(dst, a, params->ith, params->nth, p.userdata);
-}
-
-// ggml_compute_forward_map_custom2
-
-static void ggml_compute_forward_map_custom2(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * a = dst->src[0];
-    const struct ggml_tensor * b = dst->src[1];
-
-    struct ggml_map_custom2_op_params p;
-    memcpy(&p, dst->op_params, sizeof(p));
-
-    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
-}
-
-// ggml_compute_forward_map_custom3
-
-static void ggml_compute_forward_map_custom3(
-        const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * a = dst->src[0];
-    const struct ggml_tensor * b = dst->src[1];
-    const struct ggml_tensor * c = dst->src[2];
-
-    struct ggml_map_custom3_op_params p;
-    memcpy(&p, dst->op_params, sizeof(p));
-
-    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
-}
-
-// ggml_compute_forward_cross_entropy_loss
-
-static void ggml_compute_forward_cross_entropy_loss_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(ggml_are_same_shape(src0, src1));
-    GGML_ASSERT(ggml_is_scalar(dst));
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    // TODO: handle transposed/permuted matrices
-    const int64_t nc = src0->ne[0];
-    const int64_t nr = ggml_nrows(src0);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    float * sums =  (float *) params->wdata;
-    float * st   = ((float *) params->wdata) + nth + ith*nc;
-    float sum_thread = 0.0f;
-
-    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
-
-    // rows per thread
-    const int64_t dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int64_t ir0 = dr*ith;
-    const int64_t ir1 = MIN(ir0 + dr, nr);
-
-    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
-        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
-        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
-
-#ifndef NDEBUG
-        for (int64_t i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(s0[i]));
-            assert(!isnan(s1[i]));
-        }
-#endif
-
-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, s0);
-        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
-        assert(sum_softmax >= 0.0);
-
-        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
-        ggml_vec_mul_f32(nc, st, st, s1);
-
-        float sum_st = 0.0f;
-        ggml_vec_sum_f32(nc, &sum_st, st);
-        sum_thread += sum_st;
-
-#ifndef NDEBUG
-        for (int64_t i = 0; i < nc; ++i) {
-            assert(!isnan(st[i]));
-            assert(!isinf(st[i]));
-        }
-#endif
-    }
-    sums[ith] = sum_thread;
-    ggml_barrier(params->threadpool);
-
-    if (ith == 0) {
-        float * dp = (float *) dst->data;
-        ggml_vec_sum_f32(nth, dp, sums);
-        dp[0] *= -1.0f / (float) nr;
-    }
-}
-
-static void ggml_compute_forward_cross_entropy_loss(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_cross_entropy_loss_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_cross_entropy_loss_back
-
-static void ggml_compute_forward_cross_entropy_loss_back_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * opt0 = dst->src[2];
-
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int64_t ith = params->ith;
-    const int64_t nth = params->nth;
-
-    // TODO: handle transposed/permuted matrices
-    const int64_t nc = src0->ne[0];
-    const int64_t nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int64_t dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int64_t ir0 = dr*ith;
-    const int64_t ir1 = MIN(ir0 + dr, nr);
-
-    const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
-
-    for (int64_t i1 = ir0; i1 < ir1; i1++) {
-        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
-        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
-
-#ifndef NDEBUG
-        for (int64_t i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(s0[i]));
-            assert(!isnan(s1[i]));
-        }
-#endif
-
-        // soft_max
-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, s0);
-        ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
-        assert(sum > 0.0);
-        ggml_vec_scale_f32(nc, ds0, 1.0/sum);
-
-        // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
-        ggml_vec_sub_f32(nc, ds0, ds0, s1);
-        ggml_vec_scale_f32(nc, ds0, d_by_nr);
-
-#ifndef NDEBUG
-        for (int64_t i = 0; i < nc; ++i) {
-            assert(!isnan(ds0[i]));
-            assert(!isinf(ds0[i]));
-        }
-#endif
-    }
-}
-
-static void ggml_compute_forward_cross_entropy_loss_back(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-static void ggml_compute_forward_opt_step_adamw_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0        = dst->src[0];
-    const struct ggml_tensor * src0_grad   = dst->src[1];
-    const struct ggml_tensor * src0_grad_m = dst->src[2];
-    const struct ggml_tensor * src0_grad_v = dst->src[3];
-    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr  = ggml_nrows(src0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    /* const float   gnorm = 1.0f; */
-    int64_t       iter;   memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
-    const float   alpha = ggml_get_op_params_f32(dst, 2);
-    const float   beta1 = ggml_get_op_params_f32(dst, 3);
-    const float   beta2 = ggml_get_op_params_f32(dst, 4);
-    const float   eps   = ggml_get_op_params_f32(dst, 5);
-    const float   wd    = ggml_get_op_params_f32(dst, 6);
-
-    const float beta1h  = alpha/(1.0f - powf(beta1, iter));
-    const float beta2h  =  1.0f/(1.0f - powf(beta2, iter));
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        const int64_t i03 = ir/(ne02*ne01);
-        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
-        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
-
-        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
-        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
-        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
-        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);
-
-        for (int i00 = 0; i00 < ne00; ++i00) {
-            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
-            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
-
-            const float mh =       m[i00]*beta1h;
-            const float vh = sqrtf(v[i00]*beta2h) + eps;
-
-            // The weight decay is applied independently of the Adam momenta m and v.
-            // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
-            // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh;
-        }
-    }
-
-    ggml_barrier(params->threadpool);
-    if (ith != 0) {
-        return;
-    }
-
-    iter++;
-    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
-}
-
-static void ggml_compute_forward_opt_step_adamw(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_opt_step_adamw_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-/////////////////////////////////
-
-static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
-    GGML_ASSERT(params);
-
-    if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
-        return;
-    }
-
-    switch (tensor->op) {
-        case GGML_OP_DUP:
-            {
-                ggml_compute_forward_dup(params, tensor);
-            } break;
-        case GGML_OP_ADD:
-            {
-                ggml_compute_forward_add(params, tensor);
-            } break;
-        case GGML_OP_ADD1:
-            {
-                ggml_compute_forward_add1(params, tensor);
-            } break;
-        case GGML_OP_ACC:
-            {
-                ggml_compute_forward_acc(params, tensor);
-            } break;
-        case GGML_OP_SUB:
-            {
-                ggml_compute_forward_sub(params, tensor);
-            } break;
-        case GGML_OP_MUL:
-            {
-                ggml_compute_forward_mul(params, tensor);
-            } break;
-        case GGML_OP_DIV:
-            {
-                ggml_compute_forward_div(params, tensor);
-            } break;
-        case GGML_OP_SQR:
-            {
-                ggml_compute_forward_sqr(params, tensor);
-            } break;
-        case GGML_OP_SQRT:
-            {
-                ggml_compute_forward_sqrt(params, tensor);
-            } break;
-        case GGML_OP_LOG:
-            {
-                ggml_compute_forward_log(params, tensor);
-            } break;
-        case GGML_OP_SIN:
-            {
-                ggml_compute_forward_sin(params, tensor);
-            } break;
-        case GGML_OP_COS:
-            {
-                ggml_compute_forward_cos(params, tensor);
-            } break;
-        case GGML_OP_SUM:
-            {
-                ggml_compute_forward_sum(params, tensor);
-            } break;
-        case GGML_OP_SUM_ROWS:
-            {
-                ggml_compute_forward_sum_rows(params, tensor);
-            } break;
-        case GGML_OP_MEAN:
-            {
-                ggml_compute_forward_mean(params, tensor);
-            } break;
-        case GGML_OP_ARGMAX:
-            {
-                ggml_compute_forward_argmax(params, tensor);
-            } break;
-        case GGML_OP_COUNT_EQUAL:
-            {
-                ggml_compute_forward_count_equal(params, tensor);
-            } break;
-        case GGML_OP_REPEAT:
-            {
-                ggml_compute_forward_repeat(params, tensor);
-            } break;
-        case GGML_OP_REPEAT_BACK:
-            {
-                ggml_compute_forward_repeat_back(params, tensor);
-            } break;
-        case GGML_OP_CONCAT:
-            {
-                ggml_compute_forward_concat(params, tensor);
-            } break;
-        case GGML_OP_SILU_BACK:
-            {
-                ggml_compute_forward_silu_back(params, tensor);
-            } break;
-        case GGML_OP_NORM:
-            {
-                ggml_compute_forward_norm(params, tensor);
-            } break;
-        case GGML_OP_RMS_NORM:
-            {
-                ggml_compute_forward_rms_norm(params, tensor);
-            } break;
-        case GGML_OP_RMS_NORM_BACK:
-            {
-                ggml_compute_forward_rms_norm_back(params, tensor);
-            } break;
-        case GGML_OP_GROUP_NORM:
-            {
-                ggml_compute_forward_group_norm(params, tensor);
-            } break;
-        case GGML_OP_MUL_MAT:
-            {
-                ggml_compute_forward_mul_mat(params, tensor);
-            } break;
-        case GGML_OP_MUL_MAT_ID:
-            {
-                ggml_compute_forward_mul_mat_id(params, tensor);
-            } break;
-        case GGML_OP_OUT_PROD:
-            {
-                ggml_compute_forward_out_prod(params, tensor);
-            } break;
-        case GGML_OP_SCALE:
-            {
-                ggml_compute_forward_scale(params, tensor);
-            } break;
-        case GGML_OP_SET:
-            {
-                ggml_compute_forward_set(params, tensor);
-            } break;
-        case GGML_OP_CPY:
-            {
-                ggml_compute_forward_cpy(params, tensor);
-            } break;
-        case GGML_OP_CONT:
-            {
-                ggml_compute_forward_cont(params, tensor);
-            } break;
-        case GGML_OP_RESHAPE:
-            {
-                ggml_compute_forward_reshape(params, tensor);
-            } break;
-        case GGML_OP_VIEW:
-            {
-                ggml_compute_forward_view(params, tensor);
-            } break;
-        case GGML_OP_PERMUTE:
-            {
-                ggml_compute_forward_permute(params, tensor);
-            } break;
-        case GGML_OP_TRANSPOSE:
-            {
-                ggml_compute_forward_transpose(params, tensor);
-            } break;
-        case GGML_OP_GET_ROWS:
-            {
-                ggml_compute_forward_get_rows(params, tensor);
-            } break;
-        case GGML_OP_GET_ROWS_BACK:
-            {
-                ggml_compute_forward_get_rows_back(params, tensor);
-            } break;
-        case GGML_OP_DIAG:
-            {
-                ggml_compute_forward_diag(params, tensor);
-            } break;
-        case GGML_OP_DIAG_MASK_INF:
-            {
-                ggml_compute_forward_diag_mask_inf(params, tensor);
-            } break;
-        case GGML_OP_DIAG_MASK_ZERO:
-            {
-                ggml_compute_forward_diag_mask_zero(params, tensor);
-            } break;
-        case GGML_OP_SOFT_MAX:
-            {
-                ggml_compute_forward_soft_max(params, tensor);
-            } break;
-        case GGML_OP_SOFT_MAX_BACK:
-            {
-                ggml_compute_forward_soft_max_back(params, tensor);
-            } break;
-        case GGML_OP_ROPE:
-            {
-                ggml_compute_forward_rope(params, tensor);
-            } break;
-        case GGML_OP_ROPE_BACK:
-            {
-                ggml_compute_forward_rope_back(params, tensor);
-            } break;
-        case GGML_OP_CLAMP:
-            {
-                ggml_compute_forward_clamp(params, tensor);
-            } break;
-        case GGML_OP_CONV_TRANSPOSE_1D:
-            {
-                ggml_compute_forward_conv_transpose_1d(params, tensor);
-            } break;
-        case GGML_OP_IM2COL:
-            {
-                ggml_compute_forward_im2col(params, tensor);
-            } break;
-        case GGML_OP_IM2COL_BACK:
-            {
-                ggml_compute_forward_im2col_back_f32(params, tensor);
-            } break;
-        case GGML_OP_CONV_TRANSPOSE_2D:
-            {
-                ggml_compute_forward_conv_transpose_2d(params, tensor);
-            } break;
-        case GGML_OP_POOL_1D:
-            {
-                ggml_compute_forward_pool_1d(params, tensor);
-            } break;
-        case GGML_OP_POOL_2D:
-            {
-                ggml_compute_forward_pool_2d(params, tensor);
-            } break;
-        case GGML_OP_POOL_2D_BACK:
-            {
-                ggml_compute_forward_pool_2d_back(params, tensor);
-            } break;
-        case GGML_OP_UPSCALE:
-            {
-                ggml_compute_forward_upscale(params, tensor);
-            } break;
-        case GGML_OP_PAD:
-            {
-                ggml_compute_forward_pad(params, tensor);
-            } break;
-        case GGML_OP_ARANGE:
-            {
-                ggml_compute_forward_arange(params, tensor);
-            } break;
-        case GGML_OP_TIMESTEP_EMBEDDING:
-            {
-                ggml_compute_forward_timestep_embedding(params, tensor);
-            } break;
-        case GGML_OP_ARGSORT:
-            {
-                ggml_compute_forward_argsort(params, tensor);
-            } break;
-        case GGML_OP_LEAKY_RELU:
-            {
-                ggml_compute_forward_leaky_relu(params, tensor);
-            } break;
-        case GGML_OP_FLASH_ATTN_EXT:
-            {
-                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
-            } break;
-        case GGML_OP_FLASH_ATTN_BACK:
-            {
-                int32_t t = ggml_get_op_params_i32(tensor, 0);
-                GGML_ASSERT(t == 0 || t == 1);
-                bool masked = t != 0;
-                ggml_compute_forward_flash_attn_back(params, masked, tensor);
-            } break;
-        case GGML_OP_SSM_CONV:
-            {
-                ggml_compute_forward_ssm_conv(params, tensor);
-            } break;
-        case GGML_OP_SSM_SCAN:
-            {
-                ggml_compute_forward_ssm_scan(params, tensor);
-            } break;
-        case GGML_OP_WIN_PART:
-            {
-                ggml_compute_forward_win_part(params, tensor);
-            } break;
-        case GGML_OP_WIN_UNPART:
-            {
-                ggml_compute_forward_win_unpart(params, tensor);
-            } break;
-        case GGML_OP_UNARY:
-            {
-                ggml_compute_forward_unary(params, tensor);
-            } break;
-        case GGML_OP_GET_REL_POS:
-            {
-                ggml_compute_forward_get_rel_pos(params, tensor);
-            } break;
-        case GGML_OP_ADD_REL_POS:
-            {
-                ggml_compute_forward_add_rel_pos(params, tensor);
-            } break;
-        case GGML_OP_RWKV_WKV6:
-            {
-                ggml_compute_forward_rwkv_wkv6(params, tensor);
-            } break;
-        case GGML_OP_MAP_UNARY:
-            {
-                ggml_unary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_unary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_BINARY:
-            {
-                ggml_binary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_binary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM1_F32:
-            {
-                ggml_custom1_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom1_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM2_F32:
-            {
-                ggml_custom2_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom2_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM3_F32:
-            {
-                ggml_custom3_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom3_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM1:
-            {
-                ggml_compute_forward_map_custom1(params, tensor);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM2:
-            {
-                ggml_compute_forward_map_custom2(params, tensor);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM3:
-            {
-                ggml_compute_forward_map_custom3(params, tensor);
-            }
-            break;
-        case GGML_OP_CROSS_ENTROPY_LOSS:
-            {
-                ggml_compute_forward_cross_entropy_loss(params, tensor);
-            }
-            break;
-        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
-            {
-                ggml_compute_forward_cross_entropy_loss_back(params, tensor);
-            }
-            break;
-        case GGML_OP_OPT_STEP_ADAMW:
-            {
-                ggml_compute_forward_opt_step_adamw(params, tensor);
-            }
-            break;
-        case GGML_OP_NONE:
-            {
-                // nop
-            } break;
-        case GGML_OP_COUNT:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// Android's libc implementation "bionic" does not support setting affinity
-#if defined(__gnu_linux__)
-static void set_numa_thread_affinity(int thread_n) {
-    if (!ggml_is_numa()) {
-        return;
-    }
-
-    int node_num;
-    int rv;
-    size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
-
-    switch(g_state.numa.numa_strategy) {
-        case GGML_NUMA_STRATEGY_DISTRIBUTE:
-            // run thread on node_num thread_n / (threads per node)
-            node_num = thread_n % g_state.numa.n_nodes;
-            break;
-        case GGML_NUMA_STRATEGY_ISOLATE:
-            // run thread on current_node
-            node_num = g_state.numa.current_node;
-            break;
-        case GGML_NUMA_STRATEGY_NUMACTL:
-            // use the cpuset that numactl gave us
-            rv = pthread_setaffinity_np(pthread_self(), setsize, &g_state.numa.cpuset);
-            if (rv) {
-                fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",strerror(rv));
-            }
-            return;
-        default:
-            return;
-    }
-
-    struct ggml_numa_node * node = &g_state.numa.nodes[node_num];
-
-    cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
-    CPU_ZERO_S(setsize, cpus);
-    for (size_t i = 0; i < node->n_cpus; ++i) {
-        CPU_SET_S(node->cpus[i], setsize, cpus);
-    }
-
-    rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
-    if (rv) {
-            fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
-    }
-
-    CPU_FREE(cpus);
-}
-
-static void clear_numa_thread_affinity(void) {
-    if (!ggml_is_numa()) {
-        return;
-    }
-
-    size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
-
-    cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
-    CPU_ZERO_S(setsize, cpus);
-    for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) {
-        CPU_SET_S(i, setsize, cpus);
-    }
-
-    int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
-    if (rv) {
-        fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
-    }
-
-    CPU_FREE(cpus);
-}
-#else
-// TODO: Windows etc.
-// (the linux implementation may also work on BSD, someone should test)
-static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n);  }
-static void clear_numa_thread_affinity(void) {}
-#endif
-
-static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
-    int n_tasks = 0;
-
-    if (ggml_is_empty(node)) {
-        // no need to multi-thread a no-op
-        n_tasks = 1;
-        return n_tasks;
-    }
-
-    switch (node->op) {
-        case GGML_OP_CPY:
-        case GGML_OP_DUP:
-        case GGML_OP_CONT:
-        case GGML_OP_ADD:
-        case GGML_OP_ADD1:
-        case GGML_OP_ACC:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_SUB:
-        case GGML_OP_SQR:
-        case GGML_OP_SQRT:
-        case GGML_OP_LOG:
-        case GGML_OP_SIN:
-        case GGML_OP_COS:
-        case GGML_OP_SUM:
-        case GGML_OP_SUM_ROWS:
-        case GGML_OP_MEAN:
-        case GGML_OP_ARGMAX:
-            {
-                n_tasks = 1;
-            } break;
-        case GGML_OP_COUNT_EQUAL:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_REPEAT:
-        case GGML_OP_REPEAT_BACK:
-        case GGML_OP_LEAKY_RELU:
-            {
-                n_tasks = 1;
-            } break;
-        case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(node)) {
-                case GGML_UNARY_OP_ABS:
-                case GGML_UNARY_OP_SGN:
-                case GGML_UNARY_OP_NEG:
-                case GGML_UNARY_OP_STEP:
-                case GGML_UNARY_OP_TANH:
-                case GGML_UNARY_OP_ELU:
-                case GGML_UNARY_OP_RELU:
-                case GGML_UNARY_OP_SIGMOID:
-                case GGML_UNARY_OP_HARDSWISH:
-                case GGML_UNARY_OP_HARDSIGMOID:
-                case GGML_UNARY_OP_EXP:
-                    {
-                        n_tasks = 1;
-                    } break;
-
-                case GGML_UNARY_OP_GELU:
-                case GGML_UNARY_OP_GELU_QUICK:
-                case GGML_UNARY_OP_SILU:
-                    {
-                        n_tasks = n_threads;
-                    } break;
-                default:
-                    GGML_ABORT("fatal error");
-            }
-            break;
-        case GGML_OP_SILU_BACK:
-        case GGML_OP_MUL:
-        case GGML_OP_DIV:
-        case GGML_OP_NORM:
-        case GGML_OP_RMS_NORM:
-        case GGML_OP_RMS_NORM_BACK:
-        case GGML_OP_GROUP_NORM:
-        case GGML_OP_CONCAT:
-        case GGML_OP_MUL_MAT:
-        case GGML_OP_MUL_MAT_ID:
-        case GGML_OP_OUT_PROD:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_GET_ROWS:
-            {
-                // FIXME: get_rows can use additional threads, but the cost of launching additional threads
-                // decreases performance with GPU offloading
-                //n_tasks = n_threads;
-                n_tasks = 1;
-            } break;
-        case GGML_OP_SCALE:
-        case GGML_OP_SET:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_TRANSPOSE:
-        case GGML_OP_GET_ROWS_BACK:
-        case GGML_OP_DIAG:
-            {
-                n_tasks = 1;
-            } break;
-        case GGML_OP_DIAG_MASK_ZERO:
-        case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_SOFT_MAX_BACK:
-        case GGML_OP_ROPE:
-        case GGML_OP_ROPE_BACK:
-        case GGML_OP_ADD_REL_POS:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CLAMP:
-            {
-                n_tasks = 1; //TODO
-            } break;
-        case GGML_OP_SOFT_MAX:
-            {
-                n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
-            } break;
-        case GGML_OP_IM2COL:
-        case GGML_OP_IM2COL_BACK:
-        case GGML_OP_CONV_TRANSPOSE_1D:
-        case GGML_OP_CONV_TRANSPOSE_2D:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_POOL_1D:
-        case GGML_OP_POOL_2D:
-        case GGML_OP_POOL_2D_BACK:
-            {
-                n_tasks = 1;
-            } break;
-        case GGML_OP_UPSCALE:
-        case GGML_OP_PAD:
-        case GGML_OP_ARANGE:
-        case GGML_OP_TIMESTEP_EMBEDDING:
-        case GGML_OP_ARGSORT:
-        case GGML_OP_FLASH_ATTN_EXT:
-        case GGML_OP_FLASH_ATTN_BACK:
-        case GGML_OP_SSM_CONV:
-        case GGML_OP_SSM_SCAN:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_WIN_PART:
-        case GGML_OP_WIN_UNPART:
-        case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV6:
-        case GGML_OP_MAP_UNARY:
-        case GGML_OP_MAP_BINARY:
-        case GGML_OP_MAP_CUSTOM1_F32:
-        case GGML_OP_MAP_CUSTOM2_F32:
-        case GGML_OP_MAP_CUSTOM3_F32:
-            {
-                n_tasks = 1;
-            } break;
-        case GGML_OP_MAP_CUSTOM1:
-            {
-                struct ggml_map_custom1_op_params p;
-                memcpy(&p, node->op_params, sizeof(p));
-                if (p.n_tasks == GGML_N_TASKS_MAX) {
-                    n_tasks = n_threads;
-                } else {
-                    n_tasks = MIN(p.n_tasks, n_threads);
-                }
-            } break;
-        case GGML_OP_MAP_CUSTOM2:
-            {
-                struct ggml_map_custom2_op_params p;
-                memcpy(&p, node->op_params, sizeof(p));
-                if (p.n_tasks == GGML_N_TASKS_MAX) {
-                    n_tasks = n_threads;
-                } else {
-                    n_tasks = MIN(p.n_tasks, n_threads);
-                }
-            } break;
-        case GGML_OP_MAP_CUSTOM3:
-            {
-                struct ggml_map_custom3_op_params p;
-                memcpy(&p, node->op_params, sizeof(p));
-                if (p.n_tasks == GGML_N_TASKS_MAX) {
-                    n_tasks = n_threads;
-                } else {
-                    n_tasks = MIN(p.n_tasks, n_threads);
-                }
-            } break;
-        case GGML_OP_CROSS_ENTROPY_LOSS:
-        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
-        case GGML_OP_OPT_STEP_ADAMW:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_NONE:
-            {
-                n_tasks = 1;
-            } break;
-        case GGML_OP_COUNT:
-            {
-                GGML_ABORT("fatal error");
-            }
-        default:
-            {
-                fprintf(stderr, "%s: op not implemented: ", __func__);
-                if (node->op < GGML_OP_COUNT) {
-                    fprintf(stderr, "%s\n", ggml_op_name(node->op));
-                } else {
-                    fprintf(stderr, "%d\n", node->op);
-                }
-                GGML_ABORT("fatal error");
-            }
-    }
-
-    assert(n_tasks > 0);
-
-    return n_tasks;
-}
-
-static thread_ret_t ggml_graph_compute_secondary_thread(void* data);
-
-#if defined(_WIN32)
-#include "windows.h"
-
-// TODO: support > 64 CPUs
-bool ggml_thread_apply_affinity(bool * mask) {
-    HANDLE    h = GetCurrentThread();
-    uint64_t  bitmask = 0ULL;
-
-    assert(GGML_MAX_N_THREADS >= 64);
-
-    for (int32_t i = 0; i < 8; i++) {
-        int32_t idx = i * 8;
-        uint8_t val = 0;
-        val |= mask[idx + 0] << 0;
-        val |= mask[idx + 1] << 1;
-        val |= mask[idx + 2] << 2;
-        val |= mask[idx + 3] << 3;
-        val |= mask[idx + 4] << 4;
-        val |= mask[idx + 5] << 5;
-        val |= mask[idx + 6] << 6;
-        val |= mask[idx + 7] << 7;
-        bitmask |= (uint64_t)val << idx;
-    }
-
-    for (int32_t i = 64; i < GGML_MAX_N_THREADS; i++) {
-        if (mask[i]) {
-            fprintf(stderr, "warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\n");
-            break;
-        }
-    }
-
-    DWORD_PTR m = (DWORD_PTR)bitmask;
-
-    m = SetThreadAffinityMask(h, m);
-
-    return m != 0;
-}
-
-static bool ggml_thread_apply_priority(int32_t prio) {
-    // Note that on Windows the Process Priority Class must be updated in order to set Thread priority.
-    // This is up to the applications.
-    DWORD p = THREAD_PRIORITY_NORMAL;
-    switch (prio) {
-        case GGML_SCHED_PRIO_NORMAL:   p = THREAD_PRIORITY_NORMAL;        break;
-        case GGML_SCHED_PRIO_MEDIUM:   p = THREAD_PRIORITY_ABOVE_NORMAL;  break;
-        case GGML_SCHED_PRIO_HIGH:     p = THREAD_PRIORITY_HIGHEST;       break;
-        case GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break;
-    }
-
-    if (prio == GGML_SCHED_PRIO_NORMAL) {
-        // Keep inherited policy/priority
-        return true;
-    }
-
-    if (!SetThreadPriority(GetCurrentThread(), p)) {
-        fprintf(stderr, "warn: failed to set thread priority %d : (%d)\n", prio, (int) GetLastError());
-        return false;
-    }
-
-    return true;
-}
-
-#elif defined(__APPLE__)
-#include <sys/types.h>
-#include <sys/resource.h>
-
-static bool ggml_thread_apply_affinity(const bool * mask) {
-    // Not supported on Apple platforms
-    UNUSED(mask);
-    return true;
-}
-
-static bool ggml_thread_apply_priority(int32_t prio) {
-    struct sched_param p;
-    int32_t policy = SCHED_OTHER;
-    switch (prio) {
-        case GGML_SCHED_PRIO_NORMAL:   policy = SCHED_OTHER; p.sched_priority = 0;  break;
-        case GGML_SCHED_PRIO_MEDIUM:   policy = SCHED_FIFO;  p.sched_priority = 40; break;
-        case GGML_SCHED_PRIO_HIGH:     policy = SCHED_FIFO;  p.sched_priority = 80; break;
-        case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO;  p.sched_priority = 90; break;
-    }
-
-    if (prio == GGML_SCHED_PRIO_NORMAL) {
-        // Keep inherited policy/priority
-        return true;
-    }
-
-    int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
-    if (err != 0) {
-        fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
-        return false;
-    }
-
-    return true;
-}
-
-#elif defined(__gnu_linux__)
-// TODO: this may not work on BSD, to be verified
-
-static bool ggml_thread_apply_affinity(const bool * mask) {
-    cpu_set_t cpuset;
-    int err;
-
-    CPU_ZERO(&cpuset);
-
-    for (uint32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
-        if (mask[i]) {
-            GGML_PRINT_DEBUG("Thread %lx: adding %d to cpuset\n", pthread_self(), i);
-            CPU_SET(i, &cpuset);
-        }
-    }
-
-#ifdef __ANDROID__
-    err = sched_setaffinity(0, sizeof(cpuset), &cpuset);
-    if (err < 0) {
-        err = errno;
-    }
-#else
-    err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset);
-#endif
-    if (err != 0) {
-        fprintf(stderr, "warn: failed to set affinity mask 0x%llx : %s (%d)\n", (unsigned long long)mask, strerror(err), err);
-        return false;
-    }
-
-    return true;
-}
-
-static bool ggml_thread_apply_priority(int32_t prio) {
-    struct sched_param p;
-    int32_t policy = SCHED_OTHER;
-    switch (prio) {
-        case GGML_SCHED_PRIO_NORMAL:   policy = SCHED_OTHER; p.sched_priority = 0;  break;
-        case GGML_SCHED_PRIO_MEDIUM:   policy = SCHED_FIFO;  p.sched_priority = 40; break;
-        case GGML_SCHED_PRIO_HIGH:     policy = SCHED_FIFO;  p.sched_priority = 80; break;
-        case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO;  p.sched_priority = 90; break;
-    }
-
-    if (prio == GGML_SCHED_PRIO_NORMAL) {
-        // Keep inherited policy/priority
-        return true;
-    }
-
-    int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
-    if (err != 0) {
-        fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
-        return false;
-    }
-
-    return true;
-}
-
-#else // unsupported platforms
-
-static bool ggml_thread_apply_affinity(const bool * mask) {
-    UNUSED(mask);
-    return true;
-}
-
-static bool ggml_thread_apply_priority(int32_t prio) {
-    UNUSED(prio);
-    return true;
-}
-
-#endif
-
-static bool ggml_thread_cpumask_is_valid(const bool * mask) {
-    for (int i = 0; i < GGML_MAX_N_THREADS; i++) {
-        if (mask[i]) { return true; }
-    }
-    return false;
-}
-
-static void ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) {
-    if (!strict) {
-        memcpy(local_mask, global_mask, GGML_MAX_N_THREADS);
-        return;
-    } else {
-        memset(local_mask, 0, GGML_MAX_N_THREADS);
-        int32_t base_idx = *iter;
-        for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
-            int32_t idx = base_idx + i;
-            if (idx >= GGML_MAX_N_THREADS) {
-                // Just a cheaper modulo
-                idx -= GGML_MAX_N_THREADS;
-            }
-            if (global_mask[idx]) {
-                local_mask[idx] = 1;
-                *iter = idx + 1;
-                return;
-            }
-        }
-    }
-}
-
-void ggml_threadpool_free(struct ggml_threadpool* threadpool) {
-    if (!threadpool) return;
-
-    const int n_threads = threadpool->n_threads_max;
-
-#ifndef GGML_USE_OPENMP
-    struct ggml_compute_state* workers = threadpool->workers;
-
-    ggml_mutex_lock(&threadpool->mutex);
-
-    threadpool->stop = true;
-    threadpool->pause = false;
-
-    ggml_cond_broadcast(&threadpool->cond);
-    ggml_mutex_unlock(&threadpool->mutex);
-
-    for (int j = 1; j < n_threads; j++) {
-        int32_t rc = ggml_thread_join(workers[j].thrd, NULL);
-        GGML_ASSERT(rc == GGML_EXIT_SUCCESS || rc == GGML_EXIT_ABORTED);
-        UNUSED(rc);
-    }
-
-    ggml_mutex_destroy(&threadpool->mutex);
-    ggml_cond_destroy(&threadpool->cond);
-#endif // GGML_USE_OPENMP
-
-    const size_t workers_size = sizeof(struct ggml_compute_state) * n_threads;
-    ggml_aligned_free(threadpool->workers, workers_size);
-    ggml_aligned_free(threadpool, sizeof(struct ggml_threadpool));
-}
-
-#ifndef GGML_USE_OPENMP
-// pause/resume must be called under mutex
-static void ggml_threadpool_pause_locked(struct ggml_threadpool * threadpool) {
-    GGML_PRINT_DEBUG("Pausing threadpool\n");
-    threadpool->pause = true;
-    ggml_cond_broadcast(&threadpool->cond);
-}
-
-static void ggml_threadpool_resume_locked(struct ggml_threadpool * threadpool) {
-    GGML_PRINT_DEBUG("Resuming threadpool\n");
-    threadpool->pause = false;
-    ggml_cond_broadcast(&threadpool->cond);
-}
-#endif
-
-void ggml_threadpool_pause(struct ggml_threadpool * threadpool) {
-#ifndef GGML_USE_OPENMP
-    ggml_mutex_lock(&threadpool->mutex);
-    if (!threadpool->pause) {
-       ggml_threadpool_pause_locked(threadpool);
-    }
-    ggml_mutex_unlock(&threadpool->mutex);
-#else
-    UNUSED(threadpool);
-#endif
-}
-
-void ggml_threadpool_resume(struct ggml_threadpool * threadpool) {
-#ifndef GGML_USE_OPENMP
-    ggml_mutex_lock(&threadpool->mutex);
-    if (threadpool->pause) {
-       ggml_threadpool_resume_locked(threadpool);
-    }
-    ggml_mutex_unlock(&threadpool->mutex);
-#else
-    UNUSED(threadpool);
-#endif
-}
-
-struct ggml_cplan ggml_graph_plan(
-          const struct ggml_cgraph * cgraph,
-                               int   n_threads,
-            struct ggml_threadpool * threadpool) {
-
-    if (threadpool == NULL) {
-        //GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
-    }
-    if (n_threads <= 0) {
-        n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS;
-    }
-
-    size_t work_size = 0;
-
-    struct ggml_cplan cplan;
-    memset(&cplan, 0, sizeof(struct ggml_cplan));
-
-    int max_tasks = 1;
-
-    // thread scheduling for the different operations + work buffer size estimation
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        struct ggml_tensor * node = cgraph->nodes[i];
-
-        const int n_tasks = ggml_get_n_tasks(node, n_threads);
-
-        max_tasks = MAX(max_tasks, n_tasks);
-
-        size_t cur = 0;
-
-        switch (node->op) {
-            case GGML_OP_CPY:
-            case GGML_OP_DUP:
-                {
-                    if (ggml_is_quantized(node->type) ||
-                        // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
-                        (node->src[0]->type == GGML_TYPE_F16  && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
-                        (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
-                        cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
-                    }
-                } break;
-            case GGML_OP_ADD:
-            case GGML_OP_ADD1:
-                {
-                    if (ggml_is_quantized(node->src[0]->type)) {
-                        cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
-                    }
-                } break;
-            case GGML_OP_ACC:
-                {
-                    if (ggml_is_quantized(node->src[0]->type)) {
-                        cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
-                    }
-                } break;
-            case GGML_OP_COUNT_EQUAL:
-                {
-                    cur = ggml_type_size(node->type)*n_tasks;
-                } break;
-            case GGML_OP_MUL_MAT:
-                {
-                    const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
-
-                    if (node->src[1]->type != vec_dot_type) {
-                        cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
-                    }
-                } break;
-            case GGML_OP_MUL_MAT_ID:
-                {
-                    cur = 0;
-                    const struct ggml_tensor * src0 = node->src[0];
-                    const struct ggml_tensor * src1 = node->src[1];
-                    const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
-                    if (src1->type != vec_dot_type) {
-                        cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
-                    }
-                    const int n_as = src0->ne[2];
-                    cur += GGML_PAD(cur, sizeof(int64_t));       // align
-                    cur += n_as * sizeof(int64_t);               // matrix_row_counts
-                    cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
-                } break;
-            case GGML_OP_OUT_PROD:
-                {
-                    if (ggml_is_quantized(node->src[0]->type)) {
-                        cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
-                    }
-                } break;
-            case GGML_OP_SOFT_MAX:
-            case GGML_OP_ROPE:
-                {
-                    cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
-                } break;
-            case GGML_OP_CONV_TRANSPOSE_1D:
-                {
-                    GGML_ASSERT(node->src[0]->ne[3] == 1);
-                    GGML_ASSERT(node->src[1]->ne[2] == 1);
-                    GGML_ASSERT(node->src[1]->ne[3] == 1);
-
-                    const int64_t ne00 = node->src[0]->ne[0];  // K
-                    const int64_t ne01 = node->src[0]->ne[1];  // Cout
-                    const int64_t ne02 = node->src[0]->ne[2];  // Cin
-
-                    const int64_t ne10 = node->src[1]->ne[0];  // L
-                    const int64_t ne11 = node->src[1]->ne[1];  // Cin
-
-                    if ((node->src[0]->type == GGML_TYPE_F16 ||
-                         node->src[0]->type == GGML_TYPE_BF16) &&
-                        node->src[1]->type == GGML_TYPE_F32) {
-                        cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
-                        cur += sizeof(ggml_fp16_t)*ne10*ne11;
-                    } else if (node->src[0]->type == GGML_TYPE_F32 &&
-                               node->src[1]->type == GGML_TYPE_F32) {
-                        cur += sizeof(float)*ne00*ne01*ne02;
-                        cur += sizeof(float)*ne10*ne11;
-                    } else {
-                        GGML_ABORT("fatal error");
-                    }
-                } break;
-            case GGML_OP_CONV_TRANSPOSE_2D:
-                {
-                    const int64_t ne00 = node->src[0]->ne[0]; // W
-                    const int64_t ne01 = node->src[0]->ne[1]; // H
-                    const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
-                    const int64_t ne03 = node->src[0]->ne[3]; // Channels In
-
-                    const int64_t ne10 = node->src[1]->ne[0]; // W
-                    const int64_t ne11 = node->src[1]->ne[1]; // H
-                    const int64_t ne12 = node->src[1]->ne[2]; // Channels In
-
-                    cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
-                    cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
-                } break;
-            case GGML_OP_FLASH_ATTN_EXT:
-                {
-                    const int64_t ne00 = node->src[0]->ne[0]; // D
-
-                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
-                } break;
-            case GGML_OP_FLASH_ATTN_BACK:
-                {
-                    const int64_t    D = node->src[0]->ne[0];
-                    const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
-                    const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
-                    if (node->src[1]->type == GGML_TYPE_F32) {
-                        cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_F16) {
-                        cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
-                        cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
-                    }
-                } break;
-
-            case GGML_OP_CROSS_ENTROPY_LOSS:
-                {
-                    cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
-                } break;
-            case GGML_OP_COUNT:
-                {
-                    GGML_ABORT("fatal error");
-                }
-            default:
-                break;
-        }
-
-        work_size = MAX(work_size, cur);
-    }
-
-    if (work_size > 0) {
-        work_size += CACHE_LINE_SIZE*(n_threads);
-    }
-
-    cplan.threadpool = threadpool;
-    cplan.n_threads  = MIN(max_tasks, n_threads);
-    cplan.work_size  = work_size;
-    cplan.work_data  = NULL;
-
-    return cplan;
-}
-
-static thread_ret_t ggml_graph_compute_thread(void * data) {
-    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
-    struct ggml_threadpool    * tp    = state->threadpool;
-
-    const struct ggml_cgraph * cgraph = tp->cgraph;
-    const struct ggml_cplan  * cplan  = tp->cplan;
-
-    set_numa_thread_affinity(state->ith);
-
-    struct ggml_compute_params params = {
-        /*.ith       =*/ state->ith,
-        /*.nth       =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed),
-        /*.wsize     =*/ cplan->work_size,
-        /*.wdata     =*/ cplan->work_data,
-        /*.threadpool=*/ tp,
-    };
-
-    for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
-        struct ggml_tensor * node = cgraph->nodes[node_n];
-
-        ggml_compute_forward(&params, node);
-
-        if (state->ith == 0 && cplan->abort_callback &&
-                cplan->abort_callback(cplan->abort_callback_data)) {
-            tp->abort = true;
-            tp->ec    = GGML_STATUS_ABORTED;
-        }
-
-        ggml_barrier(state->threadpool);
-    }
-
-    return 0;
-}
-
-#ifndef GGML_USE_OPENMP
-
-// check if thread is active
-static inline bool ggml_graph_compute_thread_active(struct ggml_compute_state * state) {
-    struct ggml_threadpool * threadpool = state->threadpool;
-    int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed);
-    return (state->ith < n_threads);
-}
-
-// check if thread is ready to proceed (exit from polling or sleeping)
-static inline bool ggml_graph_compute_thread_ready(struct ggml_compute_state * state) {
-    struct ggml_threadpool * threadpool = state->threadpool;
-
-    if (state->pending || threadpool->stop || threadpool->pause) { return true; }
-
-    // check for new graph/work
-    int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);
-    if (new_graph != state->last_graph) {
-        state->pending    = ggml_graph_compute_thread_active(state);
-        state->last_graph = new_graph;
-    }
-
-    return state->pending;
-}
-
-// sync thread state after polling
-static inline void ggml_graph_compute_thread_sync(struct ggml_compute_state * state) {
-    // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
-    #ifdef GGML_TSAN_ENABLED
-    atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst);
-    #else
-    atomic_thread_fence(memory_order_seq_cst);
-    #endif
-    UNUSED(state);
-}
-
-static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {
-    struct ggml_threadpool * threadpool = state->threadpool;
-
-    // Skip polling for unused threads
-    if (!ggml_graph_compute_thread_active(state)) {
-        return state->pending;
-    }
-
-    // This seems to make 0 ... 100 a decent range for polling level across modern processors.
-    // Perhaps, we can adjust it dynamically based on load and things.
-    const uint64_t n_rounds = 1024UL * 128 * threadpool->poll;
-
-    for (uint64_t i=0; !ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) {
-        // No new work. Keep polling.
-        ggml_thread_cpu_relax();
-    }
-
-    return state->pending;
-}
-
-static inline bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) {
-    struct ggml_threadpool * threadpool = state->threadpool;
-
-    if (ggml_graph_compute_poll_for_work(state)) {
-        ggml_graph_compute_thread_sync(state);
-        return state->pending;
-    }
-
-    ggml_mutex_lock_shared(&threadpool->mutex);
-    while (!ggml_graph_compute_thread_ready(state)) {
-        // No new work. Wait for the signal.
-        GGML_PRINT_DEBUG("thread #%d waiting for work (sleeping)\n", state->ith);
-        ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
-    }
-    ggml_mutex_unlock_shared(&threadpool->mutex);
-
-    return state->pending;
-}
-
-static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
-    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
-    struct ggml_threadpool * threadpool = state->threadpool;
-
-    ggml_thread_apply_priority(threadpool->prio);
-    if (ggml_thread_cpumask_is_valid(state->cpumask)) {
-        ggml_thread_apply_affinity(state->cpumask);
-    }
-
-    while (true) {
-        // Check if we need to sleep
-        while (threadpool->pause) {
-            GGML_PRINT_DEBUG("thread #%d inside pause loop\n", state->ith);
-            ggml_mutex_lock_shared(&threadpool->mutex);
-            if (threadpool->pause) {
-                ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
-            }
-            GGML_PRINT_DEBUG("thread #%d resuming after wait\n", state->ith);
-            ggml_mutex_unlock_shared(&threadpool->mutex);
-        }
-
-        // This needs to be checked for after the cond_wait
-        if (threadpool->stop) break;
-
-        // Check if there is new work
-        // The main thread is the only one that can dispatch new work
-
-        ggml_graph_compute_check_for_work(state);
-        if (state->pending) {
-            state->pending = false;
-
-            ggml_graph_compute_thread(state);
-        }
-    }
-
-    return (thread_ret_t) 0;
-}
-
-// Start processing new graph
-static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool, int n_threads)
-{
-    // Always take the mutex here because the worker threads are doing hybrid poll/wait
-
-    ggml_mutex_lock(&threadpool->mutex);
-
-    GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads);
-
-    // Update the number of active threads
-    atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
-
-    // Indicate the graph is ready to be processed
-    // We need the full seq-cst fence here because of the polling threads (used in thread_sync)
-    atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst);
-
-    if (threadpool->pause) {
-       // Update main thread prio and affinity to match the threadpool settings
-       ggml_thread_apply_priority(threadpool->prio);
-       if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
-           ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
-       }
-
-       // resume does cond broadcast
-       ggml_threadpool_resume_locked(threadpool);
-    } else {
-       ggml_cond_broadcast(&threadpool->cond);
-    }
-
-    ggml_mutex_unlock(&threadpool->mutex);
-}
-
-#endif // GGML_USE_OPENMP
-
-void ggml_threadpool_params_init(struct ggml_threadpool_params * p, int n_threads) {
-    p->n_threads  = n_threads;
-    p->prio       = 0;     // default priority (usually means normal or inherited)
-    p->poll       = 50;    // hybrid-polling enabled
-    p->strict_cpu = false; // no strict placement (all threads share same cpumask)
-    p->paused     = false; // threads are ready to go
-    memset(p->cpumask, 0, GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited)
-}
-
-struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) {
-    struct ggml_threadpool_params p;
-    ggml_threadpool_params_init(&p, n_threads);
-    return p;
-}
-
-bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) {
-    if (p0->n_threads      != p1->n_threads  )    return false;
-    if (p0->prio           != p1->prio       )    return false;
-    if (p0->poll           != p1->poll       )    return false;
-    if (p0->strict_cpu     != p1->strict_cpu )    return false;
-    return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
-}
-
-static struct ggml_threadpool * ggml_threadpool_new_impl(
-    struct ggml_threadpool_params * tpp,
-               struct ggml_cgraph * cgraph,
-                struct ggml_cplan * cplan) {
-
-    struct ggml_threadpool * threadpool =
-        ggml_aligned_malloc(sizeof(struct ggml_threadpool));
-    {
-        threadpool->cgraph           = cgraph;
-        threadpool->cplan            = cplan;
-        threadpool->n_graph          = 0;
-        threadpool->n_barrier        = 0;
-        threadpool->n_barrier_passed = 0;
-        threadpool->current_chunk    = 0;
-        threadpool->stop             = false;
-        threadpool->pause            = tpp->paused;
-        threadpool->abort            = false;
-        threadpool->workers          = NULL;
-        threadpool->n_threads_max    = tpp->n_threads;
-        threadpool->n_threads_cur    = tpp->n_threads;
-        threadpool->poll             = tpp->poll;
-        threadpool->prio             = tpp->prio;
-        threadpool->ec               = GGML_STATUS_SUCCESS;
-    }
-
-    // Allocate and init workers state
-    const size_t workers_size = sizeof(struct ggml_compute_state) * tpp->n_threads;
-    struct ggml_compute_state * workers = ggml_aligned_malloc(workers_size);
-
-    memset(workers, 0, workers_size);
-    for (int j = 0; j < tpp->n_threads; j++) {
-        workers[j].threadpool = threadpool;
-        workers[j].ith        = j;
-    }
-
-    threadpool->workers = workers;
-
-#ifndef GGML_USE_OPENMP
-    ggml_mutex_init(&threadpool->mutex);
-    ggml_cond_init(&threadpool->cond);
-
-    // Spin the threads for all workers, and update CPU placements.
-    // Place the main thread last (towards the higher numbered CPU cores).
-
-    int32_t cpumask_iter = 0;
-
-    for (int j = 1; j < tpp->n_threads; j++) {
-        ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter);
-
-        int32_t rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_secondary_thread, &workers[j]);
-        GGML_ASSERT(rc == 0);
-    }
-
-    ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter);
-
-    if (!threadpool->pause) {
-        // Update main thread prio and affinity at the start, otherwise we'll do it in resume
-        ggml_thread_apply_priority(threadpool->prio);
-        if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
-            ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
-        }
-    }
-#endif // GGML_USE_OPENMP
-
-    return threadpool;
-}
-
-struct ggml_threadpool * ggml_threadpool_new(struct ggml_threadpool_params * tpp) {
-    return ggml_threadpool_new_impl(tpp, NULL, NULL);
-}
-
-enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
-    ggml_cpu_init();
-
-    GGML_ASSERT(cplan);
-    GGML_ASSERT(cplan->n_threads > 0);
-    GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);
-
-    int n_threads                               = cplan->n_threads;
-    struct ggml_threadpool * threadpool = cplan->threadpool;
-
-    bool disposable_threadpool = false;
-
-    if (threadpool == NULL) {
-        //GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
-        disposable_threadpool = true;
-
-        struct ggml_threadpool_params ttp = ggml_threadpool_params_default(n_threads);
-        threadpool = ggml_threadpool_new_impl(&ttp, cgraph, cplan);
-    } else {
-        // Reset some of the parameters that need resetting
-        // No worker threads should be accessing the parameters below at this stage
-        threadpool->cgraph           = cgraph;
-        threadpool->cplan            = cplan;
-        threadpool->current_chunk    = 0;
-        threadpool->abort            = false;
-        threadpool->ec               = GGML_STATUS_SUCCESS;
-    }
-
-#ifdef GGML_USE_OPENMP
-    if (n_threads > 1) {
-        #pragma omp parallel num_threads(n_threads)
-        {
-            #pragma omp single
-            {
-                // update the number of threads from the actual number of threads that we got from OpenMP
-                n_threads = omp_get_num_threads();
-                atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
-            }
-
-            ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]);
-        }
-    } else {
-        atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed);
-        ggml_graph_compute_thread(&threadpool->workers[0]);
-    }
-#else
-    if (n_threads > threadpool->n_threads_max) {
-        GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max);
-        n_threads = threadpool->n_threads_max;
-    }
-
-    // Kick all threads to start the new graph
-    ggml_graph_compute_kickoff(threadpool, n_threads);
-
-    // This is a work thread too
-    ggml_graph_compute_thread(&threadpool->workers[0]);
-#endif
-
-    // don't leave affinity set on the main thread
-    clear_numa_thread_affinity();
-
-    enum ggml_status ret = threadpool->ec;
-
-    if (disposable_threadpool) {
-        ggml_threadpool_free(threadpool);
-    }
-
-    return ret;
-}
-
-enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
-    struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads, NULL);
-
-    cplan.work_data = (uint8_t *)ggml_new_buffer(ctx, cplan.work_size);
-
-    return ggml_graph_compute(cgraph, &cplan);
-}
-
-int ggml_cpu_has_neon(void) {
-#if defined(__ARM_ARCH)
-    return ggml_arm_arch_features.has_neon;
-#else
-    return 0;
-#endif
-}
-
-int ggml_cpu_has_sve(void) {
-#if defined(__ARM_ARCH)
-    return ggml_arm_arch_features.has_sve;
-#else
-    return 0;
-#endif
-}
-
-int ggml_cpu_has_matmul_int8(void) {
-#if defined(__ARM_ARCH)
-    return ggml_arm_arch_features.has_i8mm;
-#else
-    return 0;
-#endif
-}
-
-int ggml_cpu_get_sve_cnt(void) {
-#if defined(__ARM_ARCH)
-    return ggml_arm_arch_features.sve_cnt;
-#else
-    return 0;
-#endif
-}
-
-void ggml_cpu_init(void) {
-    // needed to initialize f16 tables
-    {
-        struct ggml_init_params params = { 0, NULL, false };
-        struct ggml_context * ctx = ggml_init(params);
-        ggml_free(ctx);
-    }
-
-    ggml_critical_section_start();
-
-    static bool is_first_call = true;
-
-    if (is_first_call) {
-        // initialize GELU, Quick GELU, SILU and EXP F32 tables
-        {
-            const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
-
-            for (int i = 0; i < (1 << 16); ++i) {
-                union {
-                    uint16_t u16;
-                    ggml_fp16_t fp16;
-                } u = {i};
-                float f = GGML_FP16_TO_FP32(u.fp16);
-                ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
-                ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
-            }
-
-            const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
-
-            GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
-        }
-
-#if defined(__ARM_ARCH)
-        ggml_init_arm_arch_features();
-#endif
-
-        is_first_call = false;
-    }
-
-    ggml_critical_section_end();
-}
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
deleted file mode 100644 (file)
index 357cee6..0000000
+++ /dev/null
@@ -1,3364 +0,0 @@
-#include "ggml-cuda.h"
-#include "ggml-impl.h"
-#include "ggml-backend-impl.h"
-
-#include "ggml-cuda/common.cuh"
-#include "ggml-cuda/acc.cuh"
-#include "ggml-cuda/arange.cuh"
-#include "ggml-cuda/argmax.cuh"
-#include "ggml-cuda/argsort.cuh"
-#include "ggml-cuda/binbcast.cuh"
-#include "ggml-cuda/clamp.cuh"
-#include "ggml-cuda/concat.cuh"
-#include "ggml-cuda/conv-transpose-1d.cuh"
-#include "ggml-cuda/convert.cuh"
-#include "ggml-cuda/count-equal.cuh"
-#include "ggml-cuda/cpy.cuh"
-#include "ggml-cuda/cross-entropy-loss.cuh"
-#include "ggml-cuda/diagmask.cuh"
-#include "ggml-cuda/dmmv.cuh"
-#include "ggml-cuda/fattn.cuh"
-#include "ggml-cuda/getrows.cuh"
-#include "ggml-cuda/im2col.cuh"
-#include "ggml-cuda/mmq.cuh"
-#include "ggml-cuda/mmvq.cuh"
-#include "ggml-cuda/norm.cuh"
-#include "ggml-cuda/opt-step-adamw.cuh"
-#include "ggml-cuda/out-prod.cuh"
-#include "ggml-cuda/pad.cuh"
-#include "ggml-cuda/pool2d.cuh"
-#include "ggml-cuda/quantize.cuh"
-#include "ggml-cuda/rope.cuh"
-#include "ggml-cuda/scale.cuh"
-#include "ggml-cuda/softmax.cuh"
-#include "ggml-cuda/sum.cuh"
-#include "ggml-cuda/sumrows.cuh"
-#include "ggml-cuda/tsembd.cuh"
-#include "ggml-cuda/unary.cuh"
-#include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/wkv6.cuh"
-
-#include <algorithm>
-#include <array>
-#include <atomic>
-#include <cinttypes>
-#include <cstddef>
-#include <cstdint>
-#include <float.h>
-#include <limits>
-#include <map>
-#include <memory>
-#include <mutex>
-#include <stdint.h>
-#include <stdio.h>
-#include <stdarg.h>
-#include <stdlib.h>
-#include <string>
-#include <vector>
-
-static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
-
-[[noreturn]]
-void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
-    int id = -1; // in case cudaGetDevice fails
-    cudaGetDevice(&id);
-
-    GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg);
-    GGML_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", id, func, file, line);
-    GGML_LOG_ERROR("  %s\n", stmt);
-    // abort with GGML_ABORT to get a stack trace
-    GGML_ABORT(GGML_CUDA_NAME " error");
-}
-
-// this is faster on Windows
-// probably because the Windows CUDA libraries forget to make this check before invoking the drivers
-void ggml_cuda_set_device(int device) {
-    int current_device;
-    CUDA_CHECK(cudaGetDevice(&current_device));
-
-    if (device == current_device) {
-        return;
-    }
-
-    CUDA_CHECK(cudaSetDevice(device));
-}
-
-int ggml_cuda_get_device() {
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
-    return id;
-}
-
-static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
-    ggml_cuda_set_device(device);
-#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
-    auto res = hipMallocManaged(ptr, size);
-    if (res == hipSuccess) {
-        // if error we "need" to know why...
-        CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
-    }
-    return res;
-#else
-
-#if !defined(GGML_USE_HIPBLAS)
-    cudaError_t err;
-    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
-    {
-        err = cudaMallocManaged(ptr, size);
-    }
-    else
-    {
-        err = cudaMalloc(ptr, size);
-    }
-    return err;
-#else
-    return cudaMalloc(ptr, size);
-#endif // !defined(GGML_USE_HIPBLAS)
-
-#endif
-}
-
-static ggml_cuda_device_info ggml_cuda_init() {
-#ifdef __HIP_PLATFORM_AMD__
-    // Workaround for a rocBLAS bug when using multiple graphics cards:
-    // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
-    rocblas_initialize();
-    CUDA_CHECK(cudaDeviceSynchronize());
-#endif
-
-    ggml_cuda_device_info info = {};
-
-    cudaError_t err = cudaGetDeviceCount(&info.device_count);
-    if (err != cudaSuccess) {
-        GGML_LOG_ERROR("%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err));
-        return info;
-    }
-
-    GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
-
-    int64_t total_vram = 0;
-#ifdef GGML_CUDA_FORCE_MMQ
-    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:    yes\n", __func__);
-#else
-    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:    no\n", __func__);
-#endif // GGML_CUDA_FORCE_MMQ
-#ifdef GGML_CUDA_FORCE_CUBLAS
-    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
-#else
-    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
-#endif // GGML_CUDA_FORCE_CUBLAS
-    GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
-    for (int id = 0; id < info.device_count; ++id) {
-        int device_vmm = 0;
-
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
-        CUdevice device;
-        CU_CHECK(cuDeviceGet(&device, id));
-        CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
-
-        if (device_vmm) {
-            CUmemAllocationProp alloc_prop = {};
-            alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
-            alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
-            alloc_prop.location.id = id;
-            CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
-        }
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
-        info.devices[id].vmm = !!device_vmm;
-
-        cudaDeviceProp prop;
-        CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
-        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
-
-        info.default_tensor_split[id] = total_vram;
-        total_vram += prop.totalGlobalMem;
-
-        info.devices[id].nsm   = prop.multiProcessorCount;
-        info.devices[id].smpb  = prop.sharedMemPerBlock;
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-        info.devices[id].smpbo = prop.sharedMemPerBlock;
-        info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
-#else
-        info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
-        info.devices[id].cc = 100*prop.major + 10*prop.minor;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    }
-
-    for (int id = 0; id < info.device_count; ++id) {
-        info.default_tensor_split[id] /= total_vram;
-    }
-
-    // configure logging to stdout
-    // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
-
-    return info;
-}
-
-const ggml_cuda_device_info & ggml_cuda_info() {
-    static ggml_cuda_device_info info = ggml_cuda_init();
-    return info;
-}
-
-// #define DEBUG_CUDA_MALLOC
-
-// buffer pool for cuda (legacy)
-struct ggml_cuda_pool_leg : public ggml_cuda_pool {
-    static const int MAX_BUFFERS = 256;
-
-    int device;
-    struct ggml_cuda_buffer {
-        void * ptr = nullptr;
-        size_t size = 0;
-    };
-
-    ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
-    size_t pool_size = 0;
-
-    explicit ggml_cuda_pool_leg(int device) :
-        device(device) {
-    }
-
-    ~ggml_cuda_pool_leg() {
-        ggml_cuda_set_device(device);
-        for (int i = 0; i < MAX_BUFFERS; ++i) {
-            ggml_cuda_buffer & b = buffer_pool[i];
-            if (b.ptr != nullptr) {
-                CUDA_CHECK(cudaFree(b.ptr));
-                pool_size -= b.size;
-            }
-        }
-        GGML_ASSERT(pool_size == 0);
-    }
-
-    void * alloc(size_t size, size_t * actual_size) override {
-#ifdef DEBUG_CUDA_MALLOC
-        int nnz = 0;
-        size_t max_size = 0;
-#endif
-        size_t best_diff = 1ull << 36;
-        int ibest = -1;
-        for (int i = 0; i < MAX_BUFFERS; ++i) {
-            ggml_cuda_buffer& b = buffer_pool[i];
-            if (b.ptr != nullptr) {
-#ifdef DEBUG_CUDA_MALLOC
-                ++nnz;
-                if (b.size > max_size) max_size = b.size;
-#endif
-                if (b.size >= size) {
-                    size_t diff = b.size - size;
-                    if (diff < best_diff) {
-                        best_diff = diff;
-                        ibest = i;
-                        if (!best_diff) {
-                            void * ptr = b.ptr;
-                            *actual_size = b.size;
-                            b.ptr = nullptr;
-                            b.size = 0;
-                            return ptr;
-                        }
-                    }
-                }
-            }
-        }
-        if (ibest >= 0) {
-            ggml_cuda_buffer& b = buffer_pool[ibest];
-            void * ptr = b.ptr;
-            *actual_size = b.size;
-            b.ptr = nullptr;
-            b.size = 0;
-            return ptr;
-        }
-        void * ptr;
-        size_t look_ahead_size = (size_t) (1.05 * size);
-        look_ahead_size = 256 * ((look_ahead_size + 255)/256);
-        ggml_cuda_set_device(device);
-        CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
-        *actual_size = look_ahead_size;
-        pool_size += look_ahead_size;
-#ifdef DEBUG_CUDA_MALLOC
-        GGML_LOG_INFO("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
-                           (uint32_t)(max_size / 1024 / 1024), (uint32_t)(pool_size / 1024 / 1024), (uint32_t)(size / 1024 / 1024));
-#endif
-        return ptr;
-    }
-
-    void free(void * ptr, size_t size) override {
-        for (int i = 0; i < MAX_BUFFERS; ++i) {
-            ggml_cuda_buffer& b = buffer_pool[i];
-            if (b.ptr == nullptr) {
-                b.ptr = ptr;
-                b.size = size;
-                return;
-            }
-        }
-        GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n");
-        ggml_cuda_set_device(device);
-        CUDA_CHECK(cudaFree(ptr));
-        pool_size -= size;
-    }
-};
-
-// pool with virtual memory
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
-struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
-    static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
-
-    int device;
-    CUdeviceptr pool_addr = 0;
-    size_t pool_used = 0;
-    size_t pool_size = 0;
-    size_t granularity;
-
-    explicit ggml_cuda_pool_vmm(int device) :
-        device(device),
-        granularity(ggml_cuda_info().devices[device].vmm_granularity) {
-    }
-
-    ~ggml_cuda_pool_vmm() {
-        if (pool_addr != 0) {
-            CU_CHECK(cuMemUnmap(pool_addr, pool_size));
-            CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE));
-        }
-    }
-
-    void * alloc(size_t size, size_t * actual_size) override {
-        // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
-        const size_t alignment = 128;
-        size = alignment * ((size + alignment - 1) / alignment);
-
-        size_t avail = pool_size - pool_used;
-
-        if (size > avail) {
-            // round up to the next multiple of the granularity
-            size_t reserve_size = size - avail;
-            reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
-
-            GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
-
-            // allocate more physical memory
-            CUmemAllocationProp prop = {};
-            prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
-            prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
-            prop.location.id = device;
-            CUmemGenericAllocationHandle handle;
-            CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
-
-            // reserve virtual address space (if not already reserved)
-            if (pool_addr == 0) {
-                CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
-            }
-
-            // map at the end of the pool
-            CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0));
-
-            // the memory allocation handle is no longer needed after mapping
-            CU_CHECK(cuMemRelease(handle));
-
-            // set access
-            CUmemAccessDesc access = {};
-            access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
-            access.location.id = device;
-            access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
-            CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1));
-
-            // add to the pool
-            pool_size += reserve_size;
-
-            //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
-            //       device, (unsigned long long) (pool_size/1024/1024),
-            //       (unsigned long long) (reserve_size/1024/1024));
-        }
-
-        GGML_ASSERT(pool_addr != 0);
-
-        void * ptr = (void *) (pool_addr + pool_used);
-        *actual_size = size;
-        pool_used += size;
-
-#ifdef DEBUG_CUDA_MALLOC
-        printf("cuda pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
-#endif
-
-        return ptr;
-    }
-
-    void free(void * ptr, size_t size) override {
-#ifdef DEBUG_CUDA_MALLOC
-        printf("cuda pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
-#endif
-
-        pool_used -= size;
-
-        // all deallocations must be in reverse order of the allocations
-        GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
-    }
-};
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
-
-std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
-    if (ggml_cuda_info().devices[device].vmm) {
-        return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
-    }
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
-    return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
-}
-
-// cuda buffer
-
-struct ggml_backend_cuda_buffer_context {
-    int device;
-    void * dev_ptr = nullptr;
-    std::string name;
-
-    ggml_backend_cuda_buffer_context(int device, void * dev_ptr) :
-        device(device), dev_ptr(dev_ptr),
-        name(GGML_CUDA_NAME + std::to_string(device)) {
-    }
-
-    ~ggml_backend_cuda_buffer_context() {
-        CUDA_CHECK(cudaFree(dev_ptr));
-    }
-};
-
-static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-    delete ctx;
-}
-
-static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
-    return buffer->iface.free_buffer == ggml_backend_cuda_buffer_free_buffer;
-}
-
-static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
-    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-    return ctx->dev_ptr;
-}
-
-static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
-    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-
-    if (tensor->view_src != NULL) {
-        assert(tensor->view_src->buffer->buft == buffer->buft);
-        return;
-    }
-
-    if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
-        // initialize padding to 0 to avoid possible NaN values
-        size_t original_size = ggml_nbytes(tensor);
-        size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
-
-        if (padded_size > original_size) {
-            ggml_cuda_set_device(ctx->device);
-            CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size));
-        }
-    }
-}
-
-static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
-    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-
-    ggml_cuda_set_device(ctx->device);
-    CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
-    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
-}
-
-static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-
-    ggml_cuda_set_device(ctx->device);
-    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
-    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
-}
-
-static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-
-    ggml_cuda_set_device(ctx->device);
-    CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
-    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
-}
-
-static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
-    if (ggml_backend_buffer_is_cuda(src->buffer)) {
-        ggml_backend_cuda_buffer_context * src_ctx = (ggml_backend_cuda_buffer_context *)src->buffer->context;
-        ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *)dst->buffer->context;
-        if (src_ctx->device == dst_ctx->device) {
-            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src), cudaMemcpyDeviceToDevice, cudaStreamPerThread));
-        } else {
-#ifdef GGML_CUDA_NO_PEER_COPY
-            return false;
-#else
-            CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, ggml_nbytes(src), cudaStreamPerThread));
-#endif
-        }
-        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
-        return true;
-    }
-    return false;
-
-    GGML_UNUSED(buffer);
-}
-
-static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
-    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-
-    ggml_cuda_set_device(ctx->device);
-    CUDA_CHECK(cudaDeviceSynchronize());
-    CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size));
-    CUDA_CHECK(cudaDeviceSynchronize());
-}
-
-static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
-    /* .free_buffer     = */ ggml_backend_cuda_buffer_free_buffer,
-    /* .get_base        = */ ggml_backend_cuda_buffer_get_base,
-    /* .init_tensor     = */ ggml_backend_cuda_buffer_init_tensor,
-    /* .memset_tensor   = */ ggml_backend_cuda_buffer_memset_tensor,
-    /* .set_tensor      = */ ggml_backend_cuda_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_cuda_buffer_get_tensor,
-    /* .cpy_tensor      = */ ggml_backend_cuda_buffer_cpy_tensor,
-    /* .clear           = */ ggml_backend_cuda_buffer_clear,
-    /* .reset           = */ NULL,
-};
-
-// cuda buffer type
-struct ggml_backend_cuda_buffer_type_context {
-    int device;
-    std::string name;
-};
-
-static const char * ggml_backend_cuda_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    ggml_backend_cuda_buffer_type_context * ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
-
-    return ctx->name.c_str();
-}
-
-static bool ggml_backend_buft_is_cuda(ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_cuda_buffer_type_get_name;
-}
-
-static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
-
-    ggml_cuda_set_device(buft_ctx->device);
-
-    void * dev_ptr;
-    cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
-    if (err != cudaSuccess) {
-        // clear the error
-        cudaGetLastError();
-        GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
-        return nullptr;
-    }
-
-    ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
-
-    return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
-}
-
-static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    return 128;
-
-    GGML_UNUSED(buft);
-}
-
-static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
-    size_t size = ggml_nbytes(tensor);
-    int64_t ne0 = tensor->ne[0];
-
-    if (ggml_is_quantized(tensor->type)) {
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
-    }
-
-    return size;
-
-    GGML_UNUSED(buft);
-}
-
-static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
-    /* .get_name         = */ ggml_backend_cuda_buffer_type_get_name,
-    /* .alloc_buffer     = */ ggml_backend_cuda_buffer_type_alloc_buffer,
-    /* .get_alignment    = */ ggml_backend_cuda_buffer_type_get_alignment,
-    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
-    /* .get_alloc_size   = */ ggml_backend_cuda_buffer_type_get_alloc_size,
-    /* .is_host          = */ NULL,
-};
-
-ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
-    static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
-
-    if (device >= ggml_backend_cuda_get_device_count()) {
-        return nullptr;
-    }
-
-    static ggml_backend_buffer_type ggml_backend_cuda_buffer_types[GGML_CUDA_MAX_DEVICES];
-
-    static bool ggml_backend_cuda_buffer_type_initialized = false;
-
-    if (!ggml_backend_cuda_buffer_type_initialized) {
-        for (int i = 0; i < ggml_backend_cuda_get_device_count(); i++) {
-            ggml_backend_cuda_buffer_types[i] = {
-                /* .iface    = */ ggml_backend_cuda_buffer_type_interface,
-                /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), i),
-                /* .context  = */ new ggml_backend_cuda_buffer_type_context{i, GGML_CUDA_NAME + std::to_string(i)},
-            };
-        }
-        ggml_backend_cuda_buffer_type_initialized = true;
-    }
-
-    return &ggml_backend_cuda_buffer_types[device];
-}
-
-// cuda split buffer
-
-static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
-    int64_t row_rounding = 0;
-    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-        if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
-            continue;
-        }
-
-        const int cc = ggml_cuda_info().devices[id].cc;
-        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
-    }
-    return row_rounding;
-}
-
-static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
-    const int64_t nrows = ggml_nrows(tensor);
-    const int64_t rounding = get_row_rounding(tensor_split);
-
-    *row_low = id == 0 ? 0 : nrows*tensor_split[id];
-    *row_low -= *row_low % rounding;
-
-    if (id == ggml_backend_cuda_get_device_count() - 1) {
-        *row_high = nrows;
-    } else {
-        *row_high = nrows*tensor_split[id + 1];
-        *row_high -= *row_high % rounding;
-    }
-}
-
-static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
-    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
-
-    return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
-}
-
-struct ggml_backend_cuda_split_buffer_type_context {
-    int main_device;
-    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
-    std::string name;
-};
-
-struct ggml_backend_cuda_split_buffer_context {
-    ~ggml_backend_cuda_split_buffer_context() {
-        for (ggml_tensor_extra_gpu * extra : tensor_extras) {
-            for (int id = 0; id < GGML_CUDA_MAX_DEVICES; ++id) {
-                for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {
-                    if (extra->events[id][is] != nullptr) {
-                        CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
-                    }
-                }
-                if (extra->data_device[id] != nullptr) {
-                    CUDA_CHECK(cudaFree(extra->data_device[id]));
-                }
-            }
-            delete extra;
-        }
-    }
-
-    std::vector<ggml_tensor_extra_gpu *> tensor_extras;
-};
-
-
-static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
-    delete ctx;
-}
-
-static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
-    // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced
-    return (void *)0x1000;
-
-    GGML_UNUSED(buffer);
-}
-
-static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
-    GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
-
-    ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
-    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
-
-    const int64_t ne0 = tensor->ne[0];
-
-    ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
-    ctx->tensor_extras.push_back(extra);
-
-    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-        int64_t row_low, row_high;
-        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
-
-        int64_t nrows_split = row_high - row_low;
-        if (nrows_split == 0) {
-            continue;
-        }
-
-        size_t size = ggml_nbytes_split(tensor, nrows_split);
-        const size_t original_size = size;
-
-        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
-
-        // FIXME: do not crash if cudaMalloc fails
-        // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
-        ggml_cuda_set_device(id);
-        char * buf;
-        CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
-
-        // set padding to 0 to avoid possible NaN values
-        if (size > original_size) {
-            CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));
-        }
-
-        extra->data_device[id] = buf;
-
-        for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {
-            CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming));
-        }
-    }
-    tensor->extra = extra;
-}
-
-static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    // split tensors must always be set in their entirety at once
-    GGML_ASSERT(offset == 0);
-    GGML_ASSERT(size == ggml_nbytes(tensor));
-
-    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
-
-    const int64_t ne0 = tensor->ne[0];
-    const size_t nb1 = tensor->nb[1];
-    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
-
-    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-        int64_t row_low, row_high;
-        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
-
-        int64_t nrows_split = row_high - row_low;
-        if (nrows_split == 0) {
-            continue;
-        }
-
-        const size_t offset_split = row_low*nb1;
-        size_t size = ggml_nbytes_split(tensor, nrows_split);
-        const size_t original_size = size;
-
-        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
-
-        const char * buf_host = (const char *)data + offset_split;
-        CUDA_CHECK(cudaMemcpyAsync(extra->data_device[id], buf_host, original_size, cudaMemcpyHostToDevice, cudaStreamPerThread));
-    }
-
-    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
-    }
-}
-
-static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    // split tensors must always be set in their entirety at once
-    GGML_ASSERT(offset == 0);
-    GGML_ASSERT(size == ggml_nbytes(tensor));
-
-    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
-
-    const int64_t ne0 = tensor->ne[0];
-    const size_t nb1 = tensor->nb[1];
-    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
-
-    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-        int64_t row_low, row_high;
-        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
-
-        int64_t nrows_split = row_high - row_low;
-        if (nrows_split == 0) {
-            continue;
-        }
-
-        const size_t offset_split = row_low*nb1;
-        size_t size = ggml_nbytes_split(tensor, nrows_split);
-        const size_t original_size = size;
-
-        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
-
-        char * buf_host = (char *)data + offset_split;
-        CUDA_CHECK(cudaMemcpyAsync(buf_host, extra->data_device[id], original_size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
-    }
-
-    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
-    }
-}
-
-static void ggml_backend_cuda_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
-    GGML_UNUSED(buffer);
-    GGML_UNUSED(value);
-}
-
-static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
-    /* .free_buffer     = */ ggml_backend_cuda_split_buffer_free_buffer,
-    /* .get_base        = */ ggml_backend_cuda_split_buffer_get_base,
-    /* .init_tensor     = */ ggml_backend_cuda_split_buffer_init_tensor,
-    /* .memset_tensor   = */ NULL,
-    /* .set_tensor      = */ ggml_backend_cuda_split_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_cuda_split_buffer_get_tensor,
-    /* .cpy_tensor      = */ NULL,
-    /* .clear           = */ ggml_backend_cuda_split_buffer_clear,
-    /* .reset           = */ NULL,
-};
-
-// cuda split buffer type
-
-static const char * ggml_backend_cuda_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
-
-    return ctx->name.c_str();
-}
-
-static bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_get_name;
-}
-
-static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
-    // instead, we allocate them for each tensor separately in init_tensor
-    // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
-    // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
-    ggml_backend_cuda_split_buffer_context * ctx = new ggml_backend_cuda_split_buffer_context();
-
-    return ggml_backend_buffer_init(buft, ggml_backend_cuda_split_buffer_interface, ctx, size);
-}
-
-static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    return 128;
-
-    GGML_UNUSED(buft);
-}
-
-static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
-    ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
-
-    size_t total_size = 0;
-
-    const int64_t ne0 = tensor->ne[0];
-
-    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-        int64_t row_low, row_high;
-        get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, id);
-
-        int64_t nrows_split = row_high - row_low;
-        if (nrows_split == 0) {
-            continue;
-        }
-
-        total_size += ggml_nbytes_split(tensor, nrows_split);
-
-        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
-    }
-
-    return total_size;
-}
-
-static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
-    return false;
-
-    GGML_UNUSED(buft);
-}
-
-static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface = {
-    /* .get_name         = */ ggml_backend_cuda_split_buffer_type_get_name,
-    /* .alloc_buffer     = */ ggml_backend_cuda_split_buffer_type_alloc_buffer,
-    /* .get_alignment    = */ ggml_backend_cuda_split_buffer_type_get_alignment,
-    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
-    /* .get_alloc_size   = */ ggml_backend_cuda_split_buffer_type_get_alloc_size,
-    /* .is_host          = */ ggml_backend_cuda_split_buffer_type_is_host,
-};
-
-ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) {
-    static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
-
-    static std::map<std::pair<int, std::array<float, GGML_CUDA_MAX_DEVICES>>, struct ggml_backend_buffer_type> buft_map;
-
-    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split_arr = {};
-
-    bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_CUDA_MAX_DEVICES, [](float x) { return x == 0.0f; });
-    if (all_zero) {
-        tensor_split_arr = ggml_cuda_info().default_tensor_split;
-    } else {
-        float split_sum = 0.0f;
-        for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
-            tensor_split_arr[i] = split_sum;
-            split_sum += tensor_split[i];
-        }
-        for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
-            tensor_split_arr[i] /= split_sum;
-        }
-    }
-
-    auto it = buft_map.find({main_device, tensor_split_arr});
-    if (it != buft_map.end()) {
-        return &it->second;
-    }
-    auto * ctx = new ggml_backend_cuda_split_buffer_type_context{
-        main_device,
-        tensor_split_arr,
-        GGML_CUDA_NAME + std::to_string(main_device) + "_Split",
-    };
-
-    struct ggml_backend_buffer_type buft {
-        /* .iface   = */ ggml_backend_cuda_split_buffer_type_interface,
-        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), main_device),
-        /* .context = */ ctx,
-    };
-
-    auto result = buft_map.emplace(std::make_pair(main_device, tensor_split_arr), buft);
-    return &result.first->second;
-}
-
-// host buffer type
-
-static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
-    return GGML_CUDA_NAME "_Host";
-
-    GGML_UNUSED(buft);
-}
-
-static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    CUDA_CHECK(cudaFreeHost(buffer->context));
-}
-
-static void * ggml_cuda_host_malloc(size_t size) {
-    if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
-        return nullptr;
-    }
-
-    void * ptr = nullptr;
-    cudaError_t err = cudaMallocHost((void **) &ptr, size);
-    if (err != cudaSuccess) {
-        // clear the error
-        cudaGetLastError();
-        GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
-                           size / 1024.0 / 1024.0, cudaGetErrorString(err));
-        return nullptr;
-    }
-
-    return ptr;
-}
-
-static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    void * ptr = ggml_cuda_host_malloc(size);
-
-    if (ptr == nullptr) {
-        // fallback to cpu buffer
-        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
-    }
-
-    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
-    buffer->buft = buft;
-    buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
-
-    return buffer;
-}
-
-ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
-    static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_type_host = {
-        /* .iface    = */ {
-            /* .get_name         = */ ggml_backend_cuda_host_buffer_type_name,
-            /* .alloc_buffer     = */ ggml_backend_cuda_host_buffer_type_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
-            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
-            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
-            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
-        },
-        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), 0),
-        /* .context  = */ nullptr,
-    };
-
-    return &ggml_backend_cuda_buffer_type_host;
-}
-
-//static bool ggml_backend_buffer_is_cuda_host(ggml_backend_buffer_t buffer) {
-//    return buffer->buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
-//}
-
-/// kernels
-
-typedef void (*ggml_cuda_op_mul_mat_t)(
-    ggml_backend_cuda_context & ctx,
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
-    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, cudaStream_t stream);
-
-#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE
-#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
-#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
-
-#define MUL_MAT_SRC1_COL_STRIDE 128
-
-static __global__ void mul_mat_p021_f16_f32(
-    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
-
-    const half * x = (const half *) vx;
-
-    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
-    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
-    const int channel_x = channel / (nchannels_y / nchannels_x);
-
-    const int nrows_y = ncols_x;
-    const int nrows_dst = nrows_x;
-    const int row_dst = row_x;
-
-    float tmp = 0.0f;
-
-    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
-        const int col_x = col_x0 + threadIdx.x;
-
-        if (col_x >= ncols_x) {
-            break;
-        }
-
-        // x is transposed and permuted
-        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
-        const float xi = __half2float(x[ix]);
-
-        const int row_y = col_x;
-
-        // y is not transposed but permuted
-        const int iy = channel*nrows_y + row_y;
-
-        tmp += xi * y[iy];
-    }
-
-    // dst is not transposed and not permuted
-    const int idst = channel*nrows_dst + row_dst;
-
-    // sum up partial sums and write back result
-    tmp = warp_reduce_sum(tmp);
-
-    if (threadIdx.x == 0) {
-        dst[idst] = tmp;
-    }
-}
-
-static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
-    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
-
-    const half * x = (const half *) vx;
-
-    const int row_x     = blockDim.y*blockIdx.y + threadIdx.y;
-    const int channel   = blockDim.z*blockIdx.z + threadIdx.z;
-    const int channel_x = channel / channel_x_divisor;
-
-    const int nrows_y   = ncols_x;
-    const int nrows_dst = nrows_x;
-    const int row_dst   = row_x;
-
-    const int idst = channel*nrows_dst + row_dst;
-
-    float tmp = 0.0f;
-
-    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
-        const int col_x = col_x0 + threadIdx.x;
-
-        if (col_x >= ncols_x) {
-            break;
-        }
-
-        const int row_y = col_x;
-
-        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const int iy = channel*nrows_y + row_y;
-
-        const float xi = __half2float(x[ix]);
-
-        tmp += xi * y[iy];
-    }
-
-    // sum up partial sums and write back result
-    tmp = warp_reduce_sum(tmp);
-
-    if (threadIdx.x == 0) {
-        dst[idst] = tmp;
-    }
-}
-
-static void ggml_mul_mat_p021_f16_f32_cuda(
-    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
-    const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
-
-    const dim3 block_nums(1, nrows_x, nchannels_y);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
-}
-
-static void ggml_mul_mat_vec_nc_f16_f32_cuda(
-    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
-    const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
-
-    const dim3 block_nums(1, nrows_x, nchannels_y);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
-        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
-}
-
-static cudaError_t ggml_cuda_cpy_tensor_2d(
-    void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
-
-    GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
-    const char * src_ptr = (const char *) src->data;
-    char       * dst_ptr = (char       *) dst;
-
-    const int64_t ne0 = src->ne[0];
-    const int64_t nb0 = src->nb[0];
-    const int64_t nb1 = src->nb[1];
-    const int64_t nb2 = src->nb[2];
-    const int64_t nb3 = src->nb[3];
-    const enum ggml_type type = src->type;
-    const int64_t ts = ggml_type_size(type);
-    const int64_t bs = ggml_blck_size(type);
-    const int64_t i1_diff = i1_high - i1_low;
-
-    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
-    if (nb0 == ts && nb1 == ts*ne0/bs) {
-        return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyDeviceToDevice, stream);
-    } else if (nb0 == ts) {
-        return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyDeviceToDevice, stream);
-    } else {
-        for (int64_t i1 = 0; i1 < i1_diff; i1++) {
-            const void * rx = (const void *) ((const char *) x + i1*nb1);
-            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
-            // pretend the row is a matrix with cols=1
-            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyDeviceToDevice, stream);
-            if (r != cudaSuccess) {
-                return r;
-            }
-        }
-        return cudaSuccess;
-    }
-}
-
-static void ggml_cuda_op_mul_mat_cublas(
-    ggml_backend_cuda_context & ctx,
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
-    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, cudaStream_t stream) {
-
-    GGML_ASSERT(src0_dd_i  != nullptr);
-    GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_dd_i   != nullptr);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne10 = src1->ne[0];
-
-    const int64_t ne0 = dst->ne[0];
-
-    const int64_t row_diff = row_high - row_low;
-
-    int id = ggml_cuda_get_device();
-
-    // the main device has a larger memory buffer to hold the results from all GPUs
-    // ldc == nrows of the matrix that cuBLAS writes into
-    int64_t ldc = id == ctx.device ? ne0 : row_diff;
-
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
-
-    if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
-        // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
-        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
-        if (src0->type != GGML_TYPE_F16) {
-            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
-            GGML_ASSERT(to_fp16_cuda != nullptr);
-            size_t ne = row_diff*ne00;
-            src0_as_f16.alloc(ne);
-            to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
-        }
-        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
-
-        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
-        if (src1->type != GGML_TYPE_F16) {
-            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
-            GGML_ASSERT(to_fp16_cuda != nullptr);
-            size_t ne = src1_ncols*ne10;
-            src1_as_f16.alloc(ne);
-            to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
-        }
-        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
-
-        const half alpha_f16 = 1.0f;
-        const half beta_f16 = 0.0f;
-
-        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
-        CUBLAS_CHECK(
-            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
-                    row_diff, src1_ncols, ne10,
-                    &alpha_f16, src0_ptr,       CUDA_R_16F, ne00,
-                                src1_ptr,       CUDA_R_16F, ne10,
-                    &beta_f16,   dst_f16.get(), CUDA_R_16F, ldc,
-                    CUBLAS_COMPUTE_16F,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-    } else {
-        ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
-        ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
-
-        if (src0->type != GGML_TYPE_F32) {
-            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
-            GGML_ASSERT(to_fp32_cuda != nullptr);
-            src0_ddq_as_f32.alloc(row_diff*ne00);
-            to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
-        }
-        if (src1->type != GGML_TYPE_F32) {
-            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
-            GGML_ASSERT(to_fp32_cuda != nullptr);
-            src1_ddq_as_f32.alloc(src1_ncols*ne10);
-            to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
-        }
-
-        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
-        const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
-
-        const float alpha = 1.0f;
-        const float beta = 0.0f;
-
-        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
-        CUBLAS_CHECK(
-            cublasSgemm(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
-                    row_diff, src1_ncols, ne10,
-                    &alpha, src0_ddf_i,  ne00,
-                            src1_ddf1_i, ne10,
-                    &beta,  dst_dd_i,    ldc));
-    }
-
-    GGML_UNUSED(dst);
-    GGML_UNUSED(src1_ddq_i);
-    GGML_UNUSED(src1_padded_row_size);
-}
-
-static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
-    static bool peer_access_enabled = false;
-
-    const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
-
-    if (peer_access_enabled == enable_peer_access) {
-        return;
-    }
-
-#ifdef NDEBUG
-    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-        ggml_cuda_set_device(id);
-        CUDA_CHECK(cudaDeviceSynchronize());
-    }
-
-    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-        ggml_cuda_set_device(id);
-
-        for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) {
-            if (id == id_other) {
-                continue;
-            }
-            if (id != main_device && id_other != main_device) {
-                continue;
-            }
-
-            int can_access_peer;
-            CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
-            if (can_access_peer) {
-                if (enable_peer_access) {
-                    cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);
-                    if (err != cudaErrorPeerAccessAlreadyEnabled) {
-                        CUDA_CHECK(err);
-                    } else {
-                        // reset the error
-                        cudaGetLastError();
-                    }
-                } else {
-                    cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
-                    if (err != cudaErrorPeerAccessNotEnabled) {
-                        CUDA_CHECK(err);
-                    } else {
-                        // reset the error
-                        cudaGetLastError();
-                    }
-                }
-            }
-        }
-    }
-
-    ggml_cuda_set_device(main_device);
-#endif // NDEBUG
-
-    peer_access_enabled = enable_peer_access;
-
-    GGML_UNUSED(main_device);
-}
-
-static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
-    void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
-
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
-    // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
-    cudaMemcpy3DPeerParms p = {};
-    p.dstDevice = dstDevice;
-    p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height);
-    p.srcDevice = srcDevice;
-    p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height);
-    p.extent = make_cudaExtent(width, height, 1);
-    return cudaMemcpy3DPeerAsync(&p, stream);
-#else
-    // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
-    GGML_UNUSED(dstDevice);
-    GGML_UNUSED(srcDevice);
-    return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
-}
-
-static void ggml_cuda_op_mul_mat(
-    ggml_backend_cuda_context & ctx,
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
-    quantize_cuda_t quantize_src1) {
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-    const int64_t nrows1 = ggml_nrows(src1);
-
-    GGML_ASSERT(ne03 == ne13);
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-
-    const int64_t nb2 = dst->nb[2];
-    const int64_t nb3 = dst->nb[3];
-
-    GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
-    GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
-    ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
-    ggml_backend_cuda_buffer_context * dst_ctx  = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
-
-    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
-
-    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
-
-    const int64_t i02_divisor = ne12 / ne02;
-
-    const size_t src0_ts = ggml_type_size(src0->type);
-    const size_t src0_bs = ggml_blck_size(src0->type);
-    const size_t q8_1_ts = sizeof(block_q8_1);
-    const size_t q8_1_bs = QK8_1;
-
-    const bool src0_is_contiguous = ggml_is_contiguous(src0);
-    const bool src1_is_contiguous = ggml_is_contiguous(src1);
-
-    const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
-
-    const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
-    GGML_ASSERT(!(split && ne02 > 1));
-    GGML_ASSERT(!(split && ne03 > 1));
-    GGML_ASSERT(!(split && ne02 < ne12));
-
-    ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
-
-
-    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
-    if (split) {
-        ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
-        tensor_split = buft_ctx->tensor_split;
-    }
-
-    struct dev_data {
-        int cc;
-
-        ggml_cuda_pool_alloc<char>   src0_dd_alloc;
-        ggml_cuda_pool_alloc<float> src1_ddf_alloc;
-        ggml_cuda_pool_alloc<char>  src1_ddq_alloc;
-        ggml_cuda_pool_alloc<float>   dst_dd_alloc;
-
-        char  *  src0_dd = nullptr;
-        float * src1_ddf = nullptr; // float
-        char  * src1_ddq = nullptr; // q8_1
-        float *   dst_dd = nullptr;
-
-        int64_t  row_low;
-        int64_t row_high;
-    };
-
-    dev_data dev[GGML_CUDA_MAX_DEVICES];
-
-    int used_devices = 0;
-
-    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-        dev[id].cc = ggml_cuda_info().devices[id].cc;
-
-        // by default, use all rows
-        dev[id].row_low  = 0;
-        dev[id].row_high = ne01;
-
-        // for multi GPU, get the row boundaries from tensor split
-        // and round to mul_mat_q tile sizes
-        if (split) {
-            const int64_t rounding = get_row_rounding(tensor_split);
-
-            if (id != 0) {
-                dev[id].row_low  = ne01*tensor_split[id];
-                if (dev[id].row_low < ne01) {
-                    dev[id].row_low -= dev[id].row_low % rounding;
-                }
-            }
-
-            if (id != ggml_backend_cuda_get_device_count() - 1) {
-                dev[id].row_high  = ne01*tensor_split[id + 1];
-                if (dev[id].row_high < ne01) {
-                    dev[id].row_high -= dev[id].row_high % rounding;
-                }
-            }
-        }
-    }
-
-    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-        if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
-            continue;
-        }
-
-        used_devices++;
-
-        const bool src1_on_device = id == src1_ctx->device;
-        const bool  dst_on_device = id == dst_ctx->device;
-
-        ggml_cuda_set_device(id);
-        cudaStream_t stream = ctx.stream(id, 0);
-
-        if (src0_is_contiguous) {
-            dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
-        } else {
-            // If src0 is not contiguous it will be copied to a temporary buffer.
-            // This buffer needs to be cleared entirely because multiple regions will function as padding.
-            const size_t nbytes_data    = ggml_nbytes(src0);
-            const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
-            dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
-        // TODO: remove this for MUSA once the Guilty Lockup issue is resolved
-#ifndef GGML_USE_MUSA
-            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
-#else // GGML_USE_MUSA
-            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
-#endif // !GGML_USE_MUSA
-        }
-
-        // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
-        if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
-            const size_t nbytes_data    = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
-            const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
-            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
-        }
-
-        if (src1_on_device && src1_is_contiguous) {
-            dev[id].src1_ddf = (float *) src1->data;
-        } else {
-            dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));
-        }
-
-        if (quantize_src1) {
-            size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
-            if (quantize_src1 == quantize_mmq_q8_1_cuda) {
-                src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq);
-            }
-            dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
-
-            if (src1_on_device && src1_is_contiguous) {
-                quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
-                CUDA_CHECK(cudaGetLastError());
-            }
-        }
-
-        if (dst_on_device) {
-            dev[id].dst_dd = (float *) dst->data;
-        } else {
-            const size_t size_dst_ddf = split ? (dev[id].row_high - dev[id].row_low)*ne1 : ggml_nelements(dst);
-            dev[id].dst_dd = dev[id].dst_dd_alloc.alloc(ctx.pool(id), size_dst_ddf);
-        }
-    }
-
-    // if multiple devices are used they need to wait for the main device
-    // here an event is recorded that signals that the main device has finished calculating the input data
-    if (split && used_devices > 1) {
-        ggml_cuda_set_device(ctx.device);
-        CUDA_CHECK(cudaEventRecord(src0_extra->events[ctx.device][0], ctx.stream()));
-    }
-
-    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
-    for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
-        const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_CUDA_MAX_STREAMS : 0;
-        const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
-
-        for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-            if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
-                continue;
-            }
-
-            const bool src1_on_device = id == src1_ctx->device;
-            const bool  dst_on_device = id == dst_ctx->device;
-            const int64_t row_diff = dev[id].row_high - dev[id].row_low;
-
-            ggml_cuda_set_device(id);
-            cudaStream_t stream = ctx.stream(id, is);
-
-            // wait for main GPU data if necessary
-            if (split && (id != ctx.device || is != 0)) {
-                CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[ctx.device][0], 0));
-            }
-
-            for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
-                const int64_t i03 = i0 / ne12;
-                const int64_t i02 = i0 % ne12;
-
-                size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs;
-                if (quantize_src1 == quantize_mmq_q8_1_cuda) {
-                    src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq);
-                } else {
-                    src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs;
-                }
-
-                // for split tensors the data begins at i0 == i0_offset_low
-                char  *  src0_dd_i =  dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
-                float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
-                char  * src1_ddq_i = dev[id].src1_ddq +  src1_ddq_i_offset;
-                float *   dst_dd_i =   dev[id].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
-
-                // the main device memory buffer can be on VRAM scratch, with space for all partial results
-                // in that case an offset on dst_ddf_i is needed
-                if (id == ctx.device) {
-                    dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split
-                }
-
-                // copy src0, src1 to device if necessary
-                if (src1_is_contiguous) {
-                    if (id != ctx.device) {
-                        if (quantize_src1) {
-                            char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
-                            if (quantize_src1 == quantize_mmq_q8_1_cuda) {
-                                const size_t pitch = ne11*sizeof(block_q8_1_mmq);
-                                const size_t width = src1_ncols*sizeof(block_q8_1_mmq);
-                                const size_t height = src1_padded_col_size/(4*QK8_1);
-                                CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream));
-                            } else {
-                                CUDA_CHECK(cudaMemcpyPeerAsync(
-                                    src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
-                            }
-                        } else {
-                            float * src1_ddf_i_source = (float *) src1->data;
-                            src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
-                            CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddf_i, id, src1_ddf_i_source, ctx.device,
-                                                            src1_ncols*ne10*sizeof(float), stream));
-                        }
-                    }
-                } else if (src1_on_device && !src1_is_contiguous) {
-                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
-                                src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
-                } else {
-                    GGML_ABORT("fatal error");
-                }
-
-                if (quantize_src1 && !src1_is_contiguous) {
-                    quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
-                    CUDA_CHECK(cudaGetLastError());
-                }
-
-                if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
-                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
-                }
-
-                // do the computation
-                op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
-                    dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream);
-                CUDA_CHECK(cudaGetLastError());
-
-                // copy dst to host or other device if necessary
-                if (!dst_on_device) {
-                    void * dst_off_device = dst->data;
-                    if (split) {
-                        // src0 = weight matrix is saved as a transposed matrix for better memory layout.
-                        // dst is NOT transposed.
-                        // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
-                        // Instead they need to be copied to the correct slice in ne0 = dst row index.
-                        // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
-                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
-                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
-                        dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
-                        CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(
-                            dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream));
-                    } else {
-                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
-                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
-                        dhf_dst_i += src1_col_0*ne0;
-                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), cudaMemcpyDeviceToDevice, stream));
-                    }
-                }
-
-                // add event for the main device to wait on until other device is done
-                if (split && (id != ctx.device || is != 0)) {
-                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream));
-                }
-            }
-        }
-    }
-
-    // main device waits for all other devices to be finished
-    if (split && ggml_backend_cuda_get_device_count() > 1) {
-        int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
-        is_max = is_max <= GGML_CUDA_MAX_STREAMS ? is_max : GGML_CUDA_MAX_STREAMS;
-
-        ggml_cuda_set_device(ctx.device);
-        for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-            if (dev[id].row_low == dev[id].row_high) {
-                continue;
-            }
-            for (int64_t is = 0; is < is_max; ++is) {
-                CUDA_CHECK(cudaStreamWaitEvent(ctx.stream(), src0_extra->events[id][is], 0));
-            }
-        }
-    }
-}
-
-static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
-    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
-    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
-    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-
-    const int64_t ne12 = src1->ne[2];
-
-    cudaStream_t main_stream = ctx.stream();
-
-    void  * src0_ddq = src0->data;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
-
-    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
-}
-
-static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(!ggml_is_transposed(src0));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-    GGML_ASSERT(!ggml_is_permuted(src0));
-    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-
-    const int64_t nb01 = src0->nb[1];
-    const int64_t nb02 = src0->nb[2];
-
-    const int64_t ne12 = src1->ne[2];
-
-    cudaStream_t main_stream = ctx.stream();
-
-    void  * src0_ddq = src0->data;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
-
-    const int64_t row_stride_x = nb01 / sizeof(half);
-    const int64_t channel_stride_x = nb02 / sizeof(half);
-
-    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
-}
-
-static __global__ void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, char * dst,
-        const void ** ptrs_src, void ** ptrs_dst,
-        int64_t ne12, int64_t ne13,
-        int64_t ne23,
-        size_t  nb02, size_t  nb03,
-        size_t  nb12, size_t  nb13,
-        size_t  nbd2, size_t  nbd3,
-        int64_t r2,   int64_t r3) {
-    int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
-    int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
-
-    if (i13 >= ne13 || i12 >= ne12) {
-        return;
-    }
-
-    int64_t i03 = i13 / r3;
-    int64_t i02 = i12 / r2;
-
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
-}
-
-static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(!ggml_is_transposed(src0));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-
-    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int64_t ne_dst = ggml_nelements(dst);
-
-    cudaStream_t main_stream = ctx.stream();
-
-    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
-
-    void * src0_ddq = src0->data;
-    half * src0_f16 = (half *) src0_ddq;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
-
-    // convert src1 to fp16
-    ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
-        to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
-    }
-    half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get();
-
-    ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
-    char * dst_t;
-
-    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-    cudaDataType_t      cu_data_type    = CUDA_R_16F;
-
-    // dst strides
-    size_t nbd2 = dst->nb[2];
-    size_t nbd3 = dst->nb[3];
-
-    const half  alpha_f16 = 1.0f;
-    const half  beta_f16  = 0.0f;
-
-    const float alpha_f32 = 1.0f;
-    const float beta_f32  = 0.0f;
-
-    const void * alpha = &alpha_f16;
-    const void * beta  = &beta_f16;
-
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne_dst);
-
-        nbd2 /= sizeof(float) / sizeof(half);
-        nbd3 /= sizeof(float) / sizeof(half);
-    } else {
-        dst_t = (char *) dst_ddf;
-
-        cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type    = CUDA_R_32F;
-
-        alpha = &alpha_f32;
-        beta  = &beta_f32;
-    }
-
-    GGML_ASSERT(ne12 % ne02 == 0);
-    GGML_ASSERT(ne13 % ne03 == 0);
-
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
-#if 0
-    // use cublasGemmEx
-    {
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                int i03 = i13 / r3;
-                int i02 = i12 / r2;
-
-                CUBLAS_CHECK(
-                        cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
-                            ne01, ne11, ne10,
-                            alpha, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F,   nb01/sizeof(half),
-                                   (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F,   nb11/sizeof(float),
-                            beta,  (      char *)       dst_t + i12*nbd2          + i13*nbd3,          cu_data_type, ne01,
-                            cu_compute_type,
-                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-            }
-        }
-    }
-#else
-#ifdef GGML_USE_MUSA
-    GGML_ASSERT(false);
-#else // !GGML_USE_MUSA
-    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
-        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-        // use cublasGemmStridedBatchedEx
-        CUBLAS_CHECK(
-        cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
-                ne01, ne11, ne10,
-                alpha, (const char *) src0_f16, CUDA_R_16F,   nb01/nb00, nb02/nb00,  // strideA
-                       (const char *) src1_f16, CUDA_R_16F,   nb11/nb10, nb12/nb10,  // strideB
-                beta,  (      char *)    dst_t, cu_data_type, ne01,       nb2/nb0,   // strideC
-                ne12*ne13,
-                cu_compute_type,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-    } else {
-        // use cublasGemmBatchedEx
-        const int ne23 = ne12*ne13;
-
-        ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
-        ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
-
-        dim3 block_dims(ne13, ne12);
-        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_f16, src1_f16, dst_t,
-                ptrs_src.get(), ptrs_dst.get(),
-                ne12, ne13,
-                ne23,
-                nb02, nb03,
-                src1->type == GGML_TYPE_F16 ? nb12 : nb12/2,
-                src1->type == GGML_TYPE_F16 ? nb13 : nb13/2,
-                nbd2, nbd3,
-                r2, r3);
-        CUDA_CHECK(cudaGetLastError());
-
-        CUBLAS_CHECK(
-        cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
-                ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   nb11/nb10,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
-                ne23,
-                cu_compute_type,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-    }
-#endif // GGML_USE_MUSA
-#endif
-
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
-    }
-}
-
-static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
-
-    bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
-    bool          use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
-    bool              use_mul_mat_q =  ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
-
-    // if mmvq is available it's a better choice than dmmv:
-#ifndef GGML_CUDA_FORCE_DMMV
-    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
-#endif // GGML_CUDA_FORCE_DMMV
-
-    bool any_gpus_with_slow_fp16 = false;
-
-    if (split) {
-        ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
-        auto & tensor_split = buft_ctx->tensor_split;
-        for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-            // skip devices that are not going to do any work:
-            if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
-                continue;
-            }
-
-            const int cc            = ggml_cuda_info().devices[id].cc;
-            use_mul_mat_q           = use_mul_mat_q           && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
-        }
-    } else {
-        const int cc            = ggml_cuda_info().devices[ctx.device].cc;
-        use_mul_mat_q           = use_mul_mat_q           && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
-    }
-
-    // debug helpers
-    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
-    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
-    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
-    //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
-    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
-    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
-
-    if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
-        // FP32 precision KQ single-batch for batch size 1 without FlashAttention
-        ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
-    } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
-        // FP32 precision KQV single-batch for batch size 1 without FlashAttention
-        ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
-               && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
-        // KQ + KQV multi-batch without FlashAttention
-        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
-    } else if (use_dequantize_mul_mat_vec) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
-    } else if (use_mul_mat_vec_q) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
-    } else if (use_mul_mat_q) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
-    } else {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
-    }
-}
-
-struct mmid_row_mapping {
-    int32_t i1;
-    int32_t i2;
-};
-
-static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
-                                                 int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
-                                                 const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
-                                                 int64_t ne11, int64_t ne10,
-                                                 size_t nb11, size_t nb12) {
-    int32_t iid1 = blockIdx.x;
-    int32_t id = blockIdx.y;
-
-    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
-
-    if (row_id_i != i02) {
-        return;
-    }
-
-    const int64_t i11 = id % ne11;
-    const int64_t i12 = iid1;
-
-    __shared__ int src1_row;
-    if (threadIdx.x == 0) {
-        src1_row = atomicAdd(cur_src1_row, 1);
-        row_mapping[src1_row] = {id, iid1};
-    }
-    __syncthreads();
-
-    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
-    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
-
-    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
-        src1_row_contiguous[i] = src1_row_original[i];
-    }
-}
-
-static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
-                                                  const mmid_row_mapping * __restrict__ row_mapping,
-                                                  int64_t ne0,
-                                                  size_t nb1, size_t nb2) {
-    int32_t i = blockIdx.x;
-
-    const int32_t i1 = row_mapping[i].i1;
-    const int32_t i2 = row_mapping[i].i2;
-
-    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
-    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
-
-    for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
-        dst_row_original[j] = dst_row_contiguous[j];
-    }
-}
-
-static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-    const ggml_tensor * ids  = dst->src[2];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
-
-    cudaStream_t stream = ctx.stream();
-
-    const int64_t n_as = ne02;
-    const int64_t n_ids = ids->ne[0];
-
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    const char * ids_dev = (const char *) ids->data;
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
-    CUDA_CHECK(cudaStreamSynchronize(stream));
-
-    ggml_tensor src0_row = *src0;
-    ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row  = *dst;
-
-    char * src0_original = (char *) src0->data;
-    char * src1_original = (char *) src1->data;
-    char * dst_original  = (char *)  dst->data;
-
-    src0_row.ne[2] = 1;
-    src0_row.ne[3] = 1;
-    src0_row.nb[3] = nb02;
-
-    src1_row.ne[1] = 1;
-    src1_row.ne[2] = 1;
-    src1_row.ne[3] = 1;
-    src1_row.nb[2] = nb11;
-    src1_row.nb[3] = nb11;
-
-    dst_row.ne[1] = 1;
-    dst_row.ne[2] = 1;
-    dst_row.ne[3] = 1;
-    dst_row.nb[2] = nb1;
-    dst_row.nb[3] = nb1;
-
-    if (ne12 == 1) {
-        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-            for (int64_t id = 0; id < n_ids; id++) {
-                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-
-                GGML_ASSERT(i02 >= 0 && i02 < n_as);
-
-                const int64_t i11 = id % ne11;
-                const int64_t i12 = iid1;
-
-                const int64_t i1 = id;
-                const int64_t i2 = i12;
-
-                src0_row.data = src0_original + i02*nb02;
-                src1_row.data = src1_original + i11*nb11 + i12*nb12;
-                dst_row.data  =  dst_original + i1*nb1   + i2*nb2;
-
-                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
-            }
-        }
-    } else {
-        ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
-        ggml_cuda_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
-
-        src1_row.data = src1_contiguous.get();
-        dst_row.data  =  dst_contiguous.get();
-
-        for (int64_t i02 = 0; i02 < n_as; i02++) {
-            int64_t num_src1_rows = 0;
-
-            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-                for (int64_t id = 0; id < n_ids; id++) {
-                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-
-                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
-
-                    if (row_id_i != i02) {
-                        continue;
-                    }
-
-                    num_src1_rows++;
-                }
-            }
-
-            if (num_src1_rows == 0) {
-                continue;
-            }
-
-            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
-
-            {
-                dim3 block_dims(std::min((unsigned int)ne10, 768u));
-                dim3 grid_dims(ids->ne[1], n_ids);
-                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                        src1_original, src1_contiguous.get(),
-                        dev_cur_src1_row.get(), dev_row_mapping.get(),
-                        ids_dev, i02, ids->nb[1], ids->nb[0],
-                        ne11, ne10,
-                        nb11, nb12);
-                CUDA_CHECK(cudaGetLastError());
-            }
-
-            src0_row.data = src0_original + i02*nb02;
-
-            GGML_ASSERT(nb11 == sizeof(float)*ne10);
-            GGML_ASSERT(nb1 == sizeof(float)*ne0);
-
-            src1_row.ne[1] = num_src1_rows;
-            src1_row.nb[1] = nb11;
-            src1_row.nb[2] = num_src1_rows*nb11;
-            src1_row.nb[3] = num_src1_rows*nb11;
-
-            dst_row.ne[1] = num_src1_rows;
-            dst_row.nb[1] = nb1;
-            dst_row.nb[2] = num_src1_rows*nb1;
-            dst_row.nb[3] = num_src1_rows*nb1;
-
-            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
-
-            {
-                dim3 block_dims(std::min((unsigned int)ne0, 768u));
-                dim3 grid_dims(num_src1_rows);
-                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                        dst_original, dst_contiguous.get(),
-                        dev_row_mapping.get(),
-                        ne0,
-                        nb1, nb2);
-                CUDA_CHECK(cudaGetLastError());
-            }
-        }
-    }
-}
-
-static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
-    // why is this here instead of mul_mat?
-    if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) {
-        ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
-    }
-
-    switch (dst->op) {
-        case GGML_OP_ARGMAX:
-            ggml_cuda_argmax(ctx, dst);
-            break;
-        case GGML_OP_COUNT_EQUAL:
-            ggml_cuda_count_equal(ctx, dst);
-            break;
-        case GGML_OP_REPEAT:
-            ggml_cuda_op_repeat(ctx, dst);
-            break;
-        case GGML_OP_REPEAT_BACK:
-            ggml_cuda_op_repeat_back(ctx, dst);
-            break;
-        case GGML_OP_GET_ROWS:
-            ggml_cuda_op_get_rows(ctx, dst);
-            break;
-        case GGML_OP_DUP:
-            ggml_cuda_dup(ctx, dst);
-            break;
-        case GGML_OP_CPY:
-            ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
-            break;
-        case GGML_OP_CONT:
-            ggml_cuda_dup(ctx, dst);
-            break;
-        case GGML_OP_ADD:
-        case GGML_OP_ADD1: // TODO: more efficient implementation
-            ggml_cuda_op_add(ctx, dst);
-            break;
-        case GGML_OP_SUB:
-            ggml_cuda_op_sub(ctx, dst);
-            break;
-        case GGML_OP_ACC:
-            ggml_cuda_op_acc(ctx, dst);
-            break;
-        case GGML_OP_MUL:
-            ggml_cuda_op_mul(ctx, dst);
-            break;
-        case GGML_OP_DIV:
-            ggml_cuda_op_div(ctx, dst);
-            break;
-        case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(dst)) {
-                case GGML_UNARY_OP_NEG:
-                    ggml_cuda_op_neg(ctx, dst);
-                    break;
-                case GGML_UNARY_OP_STEP:
-                    ggml_cuda_op_step(ctx, dst);
-                    break;
-                case GGML_UNARY_OP_GELU:
-                    ggml_cuda_op_gelu(ctx, dst);
-                    break;
-                case GGML_UNARY_OP_SILU:
-                    ggml_cuda_op_silu(ctx, dst);
-                    break;
-                case GGML_UNARY_OP_GELU_QUICK:
-                    ggml_cuda_op_gelu_quick(ctx, dst);
-                    break;
-                case GGML_UNARY_OP_TANH:
-                    ggml_cuda_op_tanh(ctx, dst);
-                    break;
-                case GGML_UNARY_OP_RELU:
-                    ggml_cuda_op_relu(ctx, dst);
-                    break;
-                case GGML_UNARY_OP_SIGMOID:
-                    ggml_cuda_op_sigmoid(ctx, dst);
-                    break;
-                case GGML_UNARY_OP_HARDSIGMOID:
-                    ggml_cuda_op_hardsigmoid(ctx, dst);
-                    break;
-                case GGML_UNARY_OP_HARDSWISH:
-                    ggml_cuda_op_hardswish(ctx, dst);
-                    break;
-                case GGML_UNARY_OP_EXP:
-                    ggml_cuda_op_exp(ctx, dst);
-                    break;
-                default:
-                    return false;
-            }
-            break;
-        case GGML_OP_NORM:
-            ggml_cuda_op_norm(ctx, dst);
-            break;
-        case GGML_OP_GROUP_NORM:
-            ggml_cuda_op_group_norm(ctx, dst);
-            break;
-        case GGML_OP_CONCAT:
-            ggml_cuda_op_concat(ctx, dst);
-            break;
-        case GGML_OP_UPSCALE:
-            ggml_cuda_op_upscale(ctx, dst);
-            break;
-        case GGML_OP_PAD:
-            ggml_cuda_op_pad(ctx, dst);
-            break;
-        case GGML_OP_ARANGE:
-            ggml_cuda_op_arange(ctx, dst);
-            break;
-        case GGML_OP_TIMESTEP_EMBEDDING:
-            ggml_cuda_op_timestep_embedding(ctx, dst);
-            break;
-        case GGML_OP_LEAKY_RELU:
-            ggml_cuda_op_leaky_relu(ctx, dst);
-            break;
-        case GGML_OP_RMS_NORM:
-            ggml_cuda_op_rms_norm(ctx, dst);
-            break;
-        case GGML_OP_MUL_MAT:
-            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
-                GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
-                return false;
-            } else {
-                ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
-            }
-            break;
-        case GGML_OP_MUL_MAT_ID:
-            ggml_cuda_mul_mat_id(ctx, dst);
-            break;
-        case GGML_OP_OUT_PROD:
-            ggml_cuda_out_prod(ctx, dst);
-            break;
-        case GGML_OP_SCALE:
-            ggml_cuda_op_scale(ctx, dst);
-            break;
-        case GGML_OP_SQR:
-            ggml_cuda_op_sqr(ctx, dst);
-            break;
-        case GGML_OP_SQRT:
-            ggml_cuda_op_sqrt(ctx, dst);
-            break;
-        case GGML_OP_SIN:
-            ggml_cuda_op_sin(ctx, dst);
-            break;
-        case GGML_OP_COS:
-            ggml_cuda_op_cos(ctx, dst);
-            break;
-        case GGML_OP_CLAMP:
-            ggml_cuda_op_clamp(ctx, dst);
-            break;
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_TRANSPOSE:
-                break;
-        case GGML_OP_DIAG_MASK_INF:
-            ggml_cuda_op_diag_mask_inf(ctx, dst);
-            break;
-        case GGML_OP_SOFT_MAX:
-            ggml_cuda_op_soft_max(ctx, dst);
-            break;
-        case GGML_OP_ROPE:
-            ggml_cuda_op_rope(ctx, dst);
-            break;
-        case GGML_OP_IM2COL:
-            ggml_cuda_op_im2col(ctx, dst);
-            break;
-        case GGML_OP_CONV_TRANSPOSE_1D:
-            ggml_cuda_op_conv_transpose_1d(ctx,dst);
-            break;
-        case GGML_OP_POOL_2D:
-            ggml_cuda_op_pool2d(ctx, dst);
-            break;
-        case GGML_OP_SUM:
-            ggml_cuda_op_sum(ctx, dst);
-            break;
-        case GGML_OP_SUM_ROWS:
-            ggml_cuda_op_sum_rows(ctx, dst);
-            break;
-        case GGML_OP_ARGSORT:
-            ggml_cuda_op_argsort(ctx, dst);
-            break;
-        case GGML_OP_FLASH_ATTN_EXT:
-            ggml_cuda_flash_attn_ext(ctx, dst);
-            break;
-        case GGML_OP_CROSS_ENTROPY_LOSS:
-            ggml_cuda_cross_entropy_loss(ctx, dst);
-            break;
-        case GGML_OP_RWKV_WKV6:
-            ggml_cuda_op_rwkv_wkv6(ctx, dst);
-            break;
-        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
-            ggml_cuda_cross_entropy_loss_back(ctx, dst);
-            break;
-        case GGML_OP_OPT_STEP_ADAMW:
-            ggml_cuda_opt_step_adamw(ctx, dst);
-            break;
-        default:
-            return false;
-    }
-
-    cudaError_t err = cudaGetLastError();
-    if (err != cudaSuccess) {
-        GGML_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst));
-        CUDA_CHECK(err);
-    }
-
-    return true;
-}
-
-////////////////////////////////////////////////////////////////////////////////
-
-// backend
-
-static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
-    return cuda_ctx->name.c_str();
-}
-
-static void ggml_backend_cuda_free(ggml_backend_t backend) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
-    delete cuda_ctx;
-    delete backend;
-}
-
-static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
-
-    GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
-
-    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
-}
-
-static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
-
-    GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
-
-    CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
-}
-
-static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
-    ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
-    ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
-
-    if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
-        return false;
-    }
-
-    if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) {
-        return false;
-    }
-
-    // device -> device copy
-    ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
-    ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
-
-    ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
-    ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
-
-    if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
-#ifndef NDEBUG
-        GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
-#endif
-        return false;
-    }
-
-    if (backend_src != backend_dst) {
-        // copy on src stream
-        if (cuda_ctx_src->device == cuda_ctx_dst->device) {
-            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
-        } else {
-#ifdef GGML_CUDA_NO_PEER_COPY
-            return false;
-#else
-            CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
-#endif
-        }
-
-        // record event on src stream after the copy
-        if (!cuda_ctx_src->copy_event) {
-            ggml_cuda_set_device(cuda_ctx_src->device);
-            CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
-        }
-
-        CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));
-
-        // wait on dst stream for the copy to complete
-        CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0));
-    } else {
-        // src and dst are on the same backend
-        CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
-    }
-    return true;
-}
-
-static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
-    CUDA_CHECK(cudaStreamSynchronize(cuda_ctx->stream()));
-
-    GGML_UNUSED(backend);
-}
-
-#ifdef USE_CUDA_GRAPH
-static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
-    graph_node_properties->node_address = node->data;
-    graph_node_properties->node_op = node->op;
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        graph_node_properties->ne[i] = node->ne[i];
-        graph_node_properties->nb[i] = node->nb[i];
-    }
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
-    }
-    memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
-}
-
-static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
-    if (node->data != graph_node_properties->node_address &&
-          node->op != GGML_OP_CPY &&
-          node->op != GGML_OP_VIEW) {
-        return false;
-    }
-
-    if (node->op != graph_node_properties->node_op) {
-        return false;
-    }
-
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (node->ne[i] != graph_node_properties->ne[i]) {
-            return false;
-        }
-        if (node->nb[i] != graph_node_properties->nb[i]) {
-            return false;
-        }
-    }
-
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (node->src[i] &&
-            node->src[i]->data != graph_node_properties->src_address[i] &&
-            node->op != GGML_OP_CPY &&
-            node->op != GGML_OP_VIEW
-        ) {
-            return false;
-        }
-    }
-
-    if (node->op == GGML_OP_SCALE &&
-        memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
-        return false;
-    }
-
-    return true;
-}
-#endif
-
-static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
-    ggml_cuda_set_device(cuda_ctx->device);
-
-#ifdef USE_CUDA_GRAPH
-    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
-
-    // Objects required for CUDA Graph
-    if (cuda_ctx->cuda_graph == nullptr) {
-        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
-    }
-
-    bool use_cuda_graph = true;
-    bool cuda_graph_update_required = false;
-    // vector of pointers to CUDA cpy kernels, which are required to identify
-    // kernel parameters which need updated in the graph for each token
-    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
-
-    if (cuda_ctx->cuda_graph->graph == nullptr) {
-        if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
-            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
-#endif
-        }
-    }
-
-    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
-    // or previous graph capture failure.
-    // Also disable for multi-gpu for now. TO DO investigate
-    if (disable_cuda_graphs_due_to_env
-        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
-        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
-        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
-        use_cuda_graph = false;
-    }
-
-    if (use_cuda_graph) {
-        if (cuda_ctx->cuda_graph->instance == nullptr) {
-            cuda_graph_update_required = true;
-        }
-
-        // Check if the graph size has changed
-        if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
-            cuda_graph_update_required = true;
-            cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
-        }
-
-        // Loop over nodes in GGML graph to determine if CUDA graph update is required
-        // and store properties to allow this comparison for the next token
-        for (int i = 0; i < cgraph->n_nodes; i++) {
-            bool has_matching_properties = true;
-            if (!cuda_graph_update_required) {
-                has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
-            }
-            if (!has_matching_properties) {
-                cuda_graph_update_required = true;
-            }
-            set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
-        }
-
-        // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-        cuda_ctx->cuda_graph->updated_kernel_arg.clear();
-        for (int i = 0; i < cgraph->n_nodes; i++) {
-            ggml_tensor * node = cgraph->nodes[i];
-
-            if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
-                continue;
-            }
-
-            if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
-                use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_MUL_MAT_ID) {
-                use_cuda_graph = false; // This node type is not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-                // disable CUDA graphs for batch size > 1 for now.
-                // Changes in batch size or context size can cause changes to the grid size of some kernels.
-                use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-#endif
-            }
-
-            if (node->op == GGML_OP_CPY) {
-                // store the copy op parameter which changes with each token.
-                cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-                // store a pointer to each copy op CUDA kernel to identify it later
-                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-                if (!ptr) {
-                    use_cuda_graph = false;
-#ifndef NDEBUG
-                    GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
-#endif
-                } else {
-                    if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
-                        ggml_cuda_cpy_fn_ptrs.push_back(ptr);
-                    }
-                }
-            }
-
-            if (!use_cuda_graph) {
-                break;
-            }
-        }
-
-        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
-        if (use_cuda_graph && cuda_graph_update_required) {
-            cuda_ctx->cuda_graph->number_consecutive_updates++;
-        } else {
-            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
-        }
-
-        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
-            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
-#endif
-        }
-    }
-
-    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
-        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
-    }
-
-#else
-    bool use_cuda_graph = false;
-    bool cuda_graph_update_required = false;
-#endif // USE_CUDA_GRAPH
-
-    bool graph_evaluated_or_captured = false;
-
-    while (!graph_evaluated_or_captured) {
-        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
-        // With the use of CUDA graphs, the execution will be performed by the graph launch.
-        if (!use_cuda_graph || cuda_graph_update_required) {
-            for (int i = 0; i < cgraph->n_nodes; i++) {
-                ggml_tensor * node = cgraph->nodes[i];
-
-                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
-                    continue;
-                }
-
-#ifndef NDEBUG
-                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
-                for (int j = 0; j < GGML_MAX_SRC; j++) {
-                    if (node->src[j] != nullptr) {
-                        assert(node->src[j]->buffer);
-                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
-                               ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
-                    }
-                }
-#endif
-
-                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
-                if (!ok) {
-                    GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
-                }
-                GGML_ASSERT(ok);
-            }
-        }
-
-#ifdef USE_CUDA_GRAPH
-        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
-            if (cuda_ctx->cuda_graph->graph != nullptr) {
-                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
-                cuda_ctx->cuda_graph->graph = nullptr;
-            }
-            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
-
-#if 0
-            if (disable_cuda_graphs_due_to_failed_capture) {
-                use_cuda_graph = false;
-                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
-#endif
-            } else {
-                graph_evaluated_or_captured = true; // CUDA graph has been captured
-            }
-#endif
-            graph_evaluated_or_captured = true; // CUDA graph has been captured
-        } else {
-            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
-        }
-    }
-
-    if (use_cuda_graph) {
-        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
-            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
-        }
-
-        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
-
-        if (cuda_graph_update_required) {
-            // Extract nodes from graph
-            // First call with null argument gets number of nodes in graph
-            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-            // Subsequent call with non-null argument gets nodes
-            cuda_ctx->cuda_graph->nodes.clear();
-            cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
-            cuda_ctx->cuda_graph->params.clear();
-            cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
-            if (cuda_ctx->cuda_graph->num_nodes > 0) {
-                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
-
-                // Loop over nodes, and extract kernel parameters from each node
-                for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                    cudaGraphNodeType node_type;
-                    CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
-                    if (node_type == cudaGraphNodeTypeKernel) {
-                        cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
-                        if (stat == cudaErrorInvalidDeviceFunction) {
-                            // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
-                            // We don't need to update blas nodes, so clear error and move on.
-                            cudaGetLastError();
-                        } else {
-                            GGML_ASSERT(stat == cudaSuccess);
-                        }
-                    }
-                }
-            }
-        }
-
-        // One of the arguments to the copy kernel is updated for each token, hence we need to
-        // replace that argument with the updated value in the CUDA graph
-        if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
-            int k = 0;
-            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
-                    char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-                    cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
-                    CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
-                }
-            }
-        }
-
-        // Update graph executable
-        cudaGraphExecUpdateResultInfo result_info;
-        cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
-        if (stat == cudaErrorGraphExecUpdateFailure) {
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
-#endif
-            // The pre-existing graph exec cannot be updated due to violated constraints
-            // so instead clear error and re-instantiate
-            cudaGetLastError();
-            CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
-            cuda_ctx->cuda_graph->instance = nullptr;
-            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
-        } else {
-            GGML_ASSERT(stat == cudaSuccess);
-        }
-        // Launch graph
-        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
-#else
-        graph_evaluated_or_captured = true;
-#endif // USE_CUDA_GRAPH
-    }
-
-    return GGML_STATUS_SUCCESS;
-}
-
-static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
-    CUDA_CHECK(cudaEventRecord((cudaEvent_t)event->context, cuda_ctx->stream()));
-}
-
-static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
-    if (ggml_backend_is_cuda(backend)) {
-        CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), (cudaEvent_t)event->context, 0));
-    } else {
-#if 0
-        // untested
-        auto wait_fn = [](void * user_data) {
-            ggml_backend_event_t event = (ggml_backend_event_t)user_data;
-            ggml_backend_event_synchronize(event);
-        };
-
-        CUDA_CHECK(cudaLaunchHostFunc(cuda_ctx->stream(), wait_fn, event));
-#endif
-        GGML_ABORT("fatal error");
-    }
-}
-
-static const ggml_backend_i ggml_backend_cuda_interface = {
-    /* .get_name                = */ ggml_backend_cuda_get_name,
-    /* .free                    = */ ggml_backend_cuda_free,
-    /* .set_tensor_async        = */ ggml_backend_cuda_set_tensor_async,
-    /* .get_tensor_async        = */ ggml_backend_cuda_get_tensor_async,
-    /* .cpy_tensor_async        = */ ggml_backend_cuda_cpy_tensor_async,
-    /* .synchronize             = */ ggml_backend_cuda_synchronize,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_cuda_graph_compute,
-    /* .event_record            = */ ggml_backend_cuda_event_record,
-    /* .event_wait              = */ ggml_backend_cuda_event_wait,
-};
-
-static ggml_guid_t ggml_backend_cuda_guid() {
-    static ggml_guid guid = { 0x2c, 0xdd, 0xe8, 0x1c, 0x65, 0xb3, 0x65, 0x73, 0x6a, 0x12, 0x88, 0x61, 0x1c, 0xc9, 0xdc, 0x25 };
-    return &guid;
-}
-
-bool ggml_backend_is_cuda(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cuda_guid());
-}
-
-int ggml_backend_cuda_get_device_count() {
-    return ggml_cuda_info().device_count;
-}
-
-void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size) {
-    cudaDeviceProp prop;
-    CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
-    snprintf(description, description_size, "%s", prop.name);
-}
-
-void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total) {
-    ggml_cuda_set_device(device);
-
-    CUDA_CHECK(cudaMemGetInfo(free, total));
-}
-
-bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
-    if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
-        return false;
-    }
-
-#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
-    cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
-    if (err != cudaSuccess) {
-        // clear the error
-        cudaGetLastError();
-
-        GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
-                           size / 1024.0 / 1024.0, cudaGetErrorString(err));
-        return false;
-    }
-    return true;
-#else
-    return false;
-#endif
-}
-
-void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
-    if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
-        return;
-    }
-
-    cudaError_t err = cudaHostUnregister(buffer);
-    if (err != cudaSuccess) {
-        // clear the error
-        cudaGetLastError();
-    }
-}
-
-
-// backend device
-
-struct ggml_backend_cuda_device_context {
-    int device;
-    std::string name;
-    std::string description;
-};
-
-static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
-    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
-    return ctx->name.c_str();
-}
-
-static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t dev) {
-    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
-    return ctx->description.c_str();
-}
-
-static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
-    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
-    ggml_cuda_set_device(ctx->device);
-    CUDA_CHECK(cudaMemGetInfo(free, total));
-}
-
-static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
-    GGML_UNUSED(dev);
-    return GGML_BACKEND_DEVICE_TYPE_GPU;
-}
-
-static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
-    props->name        = ggml_backend_cuda_device_get_name(dev);
-    props->description = ggml_backend_cuda_device_get_description(dev);
-    props->type        = ggml_backend_cuda_device_get_type(dev);
-    ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
-
-    bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
-#ifdef GGML_CUDA_NO_PEER_COPY
-    bool events = false;
-#else
-    bool events = true;
-#endif
-
-    props->caps = {
-        /* .async                 = */ true,
-        /* .host_buffer           = */ host_buffer,
-        /* .buffer_from_host_ptr  = */ false,
-        /* .events                = */ events,
-    };
-}
-
-static ggml_backend_t ggml_backend_cuda_device_init_backend(ggml_backend_dev_t dev, const char * params) {
-    GGML_UNUSED(params);
-    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
-    return ggml_backend_cuda_init(ctx->device);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_buffer_type(ggml_backend_dev_t dev) {
-    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
-    return ggml_backend_cuda_buffer_type(ctx->device);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_host_buffer_type(ggml_backend_dev_t dev) {
-    GGML_UNUSED(dev);
-    return ggml_backend_cuda_host_buffer_type();
-}
-
-// TODO: move these functions here
-static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
-    ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
-
-    // split buffers can only be used with GGML_OP_MUL_MAT
-    if (op->op != GGML_OP_MUL_MAT) {
-        for (int i = 0; i < GGML_MAX_SRC; i++) {
-            if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda_split(op->src[i]->buffer->buft)) {
-                return false;
-            }
-        }
-    }
-
-    // check if all the sources are allocated on this device
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda(op->src[i]->buffer->buft)) {
-            ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)op->src[i]->buffer->buft->context;
-            if (buft_ctx->device != dev_ctx->device) {
-                return false;
-            }
-        }
-    }
-
-    switch (op->op) {
-        case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_NEG:
-                case GGML_UNARY_OP_STEP:
-                case GGML_UNARY_OP_GELU:
-                case GGML_UNARY_OP_SILU:
-                case GGML_UNARY_OP_RELU:
-                case GGML_UNARY_OP_SIGMOID:
-                case GGML_UNARY_OP_HARDSIGMOID:
-                case GGML_UNARY_OP_HARDSWISH:
-                case GGML_UNARY_OP_GELU_QUICK:
-                case GGML_UNARY_OP_TANH:
-                case GGML_UNARY_OP_EXP:
-                    return ggml_is_contiguous(op->src[0]);
-                default:
-                    return false;
-            }
-            break;
-        case GGML_OP_MUL_MAT:
-        case GGML_OP_MUL_MAT_ID:
-            {
-                struct ggml_tensor * a = op->src[0];
-                struct ggml_tensor * b = op->src[1];
-                if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
-                    return false;
-                }
-                if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
-                    return false;
-                }
-#ifdef GGML_USE_MUSA
-                if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
-                    !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
-                    return false;
-                }
-#endif // GGML_USE_MUSA
-                switch (a->type) {
-                    case GGML_TYPE_F32:
-                    case GGML_TYPE_F16:
-                    case GGML_TYPE_Q4_0:
-                    case GGML_TYPE_Q4_1:
-                    case GGML_TYPE_Q5_0:
-                    case GGML_TYPE_Q5_1:
-                    case GGML_TYPE_Q8_0:
-                    case GGML_TYPE_Q2_K:
-                    case GGML_TYPE_Q3_K:
-                    case GGML_TYPE_Q4_K:
-                    case GGML_TYPE_Q5_K:
-                    case GGML_TYPE_Q6_K:
-                    case GGML_TYPE_Q8_K:
-                    case GGML_TYPE_IQ1_M:
-                    case GGML_TYPE_IQ1_S:
-                    case GGML_TYPE_IQ2_S:
-                    case GGML_TYPE_IQ2_XS:
-                    case GGML_TYPE_IQ2_XXS:
-                    case GGML_TYPE_IQ3_S:
-                    case GGML_TYPE_IQ3_XXS:
-                    case GGML_TYPE_IQ4_NL:
-                    case GGML_TYPE_IQ4_XS:
-#ifdef GGML_USE_MUSA
-                        if (a->type == GGML_TYPE_Q3_K) {
-                            return false;
-                        }
-#endif // GGML_USE_MUSA
-                        return true;
-                    default:
-                        return false;
-                }
-            } break;
-        case GGML_OP_OUT_PROD:
-            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
-        case GGML_OP_GET_ROWS:
-            {
-                switch (op->src[0]->type) {
-                    case GGML_TYPE_F16:
-                    case GGML_TYPE_F32:
-                    case GGML_TYPE_Q4_0:
-                    case GGML_TYPE_Q4_1:
-                    case GGML_TYPE_Q5_0:
-                    case GGML_TYPE_Q5_1:
-                    case GGML_TYPE_Q8_0:
-                        return true;
-                    default:
-                        return false;
-                }
-            } break;
-        case GGML_OP_CPY:
-            {
-                ggml_type src0_type = op->src[0]->type;
-                ggml_type src1_type = op->src[1]->type;
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
-                    return true;
-                }
-                if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
-                    return true;
-                }
-                return false;
-            } break;
-        case GGML_OP_DUP:
-            {
-                ggml_type src0_type = op->src[0]->type;
-                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
-            } break;
-        case GGML_OP_ARGMAX:
-        case GGML_OP_COUNT_EQUAL:
-            {
-                return true;
-            } break;
-        case GGML_OP_REPEAT:
-            {
-                ggml_type src0_type = op->src[0]->type;
-                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
-            } break;
-        case GGML_OP_REPEAT_BACK:
-                return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
-        case GGML_OP_CONCAT:
-            {
-                ggml_type src0_type = op->src[0]->type;
-                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
-            } break;
-        case GGML_OP_CONV_TRANSPOSE_1D:
-            {
-                ggml_type src0_type = op->src[0]->type;
-                ggml_type src1_type = op->src[1]->type;
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
-                    return true;
-                }
-                return false;
-            } break;
-        case GGML_OP_NORM:
-        case GGML_OP_RMS_NORM:
-            return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
-            break;
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_TRANSPOSE:
-        case GGML_OP_ADD:
-        case GGML_OP_ADD1:
-        case GGML_OP_SUB:
-        case GGML_OP_MUL:
-        case GGML_OP_DIV:
-        case GGML_OP_SCALE:
-        case GGML_OP_SQR:
-        case GGML_OP_SQRT:
-        case GGML_OP_SIN:
-        case GGML_OP_COS:
-        case GGML_OP_CLAMP:
-            return true;
-        case GGML_OP_CONT:
-            return op->src[0]->type != GGML_TYPE_BF16;
-        case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_SOFT_MAX:
-            return true;
-        case GGML_OP_ROPE:
-            return ggml_is_contiguous(op->src[0]);
-        case GGML_OP_IM2COL:
-        case GGML_OP_POOL_2D:
-        case GGML_OP_SUM:
-        case GGML_OP_SUM_ROWS:
-        case GGML_OP_ARGSORT:
-        case GGML_OP_ACC:
-        case GGML_OP_GROUP_NORM:
-        case GGML_OP_UPSCALE:
-        case GGML_OP_PAD:
-        case GGML_OP_ARANGE:
-        case GGML_OP_TIMESTEP_EMBEDDING:
-        case GGML_OP_LEAKY_RELU:
-        case GGML_OP_RWKV_WKV6:
-            return true;
-        case GGML_OP_FLASH_ATTN_EXT: {
-#ifndef FLASH_ATTN_AVAILABLE
-            return false;
-#endif
-            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
-                return false;
-            }
-            if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) {
-                return true;
-            }
-            if (op->src[0]->ne[0] == 128) {
-                return true;
-            }
-            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
-                return true;
-            }
-            const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-            return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
-        }
-        case GGML_OP_CROSS_ENTROPY_LOSS:
-        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
-        case GGML_OP_OPT_STEP_ADAMW:
-            return true;
-        default:
-            return false;
-    }
-}
-
-static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    return (ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev;
-}
-
-static int64_t get_op_batch_size(const ggml_tensor * op) {
-    switch (op->op) {
-        case GGML_OP_GET_ROWS:
-            return 0;
-        case GGML_OP_MUL_MAT:
-            return op->ne[1];
-        case GGML_OP_MUL_MAT_ID:
-        case GGML_OP_ROPE:
-            return op->ne[2];
-        default:
-            return ggml_nrows(op);
-    }
-}
-
-static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
-    const int min_batch_size = 32;
-
-    return get_op_batch_size(op) >= min_batch_size;
-
-    GGML_UNUSED(dev);
-}
-
-static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) {
-#ifdef GGML_CUDA_NO_PEER_COPY
-    return nullptr;
-#else
-    ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *)dev->context;
-
-    ggml_cuda_set_device(dev_ctx->device);
-
-    cudaEvent_t event;
-    CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
-
-    return new ggml_backend_event {
-        /* .device  = */ dev,
-        /* .context = */ event,
-    };
-#endif
-}
-
-static void ggml_backend_cuda_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
-    GGML_UNUSED(dev);
-
-    CUDA_CHECK(cudaEventDestroy((cudaEvent_t)event->context));
-    delete event;
-}
-
-static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
-    GGML_UNUSED(dev);
-    CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
-}
-
-static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
-    /* .get_name                = */ ggml_backend_cuda_device_get_name,
-    /* .get_description         = */ ggml_backend_cuda_device_get_description,
-    /* .get_memory              = */ ggml_backend_cuda_device_get_memory,
-    /* .get_type                = */ ggml_backend_cuda_device_get_type,
-    /* .get_props               = */ ggml_backend_cuda_device_get_props,
-    /* .init_backend            = */ ggml_backend_cuda_device_init_backend,
-    /* .get_buffer_type         = */ ggml_backend_cuda_device_get_buffer_type,
-    /* .get_host_buffer_type    = */ ggml_backend_cuda_device_get_host_buffer_type,
-    /* .buffer_from_host_ptr    = */ NULL,
-    /* .supports_op             = */ ggml_backend_cuda_device_supports_op,
-    /* .supports_buft           = */ ggml_backend_cuda_device_supports_buft,
-    /* .offload_op              = */ ggml_backend_cuda_device_offload_op,
-    /* .event_new               = */ ggml_backend_cuda_device_event_new,
-    /* .event_free              = */ ggml_backend_cuda_device_event_free,
-    /* .event_synchronize       = */ ggml_backend_cuda_device_event_synchronize,
-};
-
-// backend reg
-
-struct ggml_backend_cuda_reg_context {
-    std::vector<ggml_backend_dev_t> devices;
-};
-
-static const char * ggml_backend_cuda_reg_get_name(ggml_backend_reg_t reg) {
-    GGML_UNUSED(reg);
-    return GGML_CUDA_NAME;
-}
-
-static size_t ggml_backend_cuda_reg_get_device_count(ggml_backend_reg_t reg) {
-    ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context;
-    return ctx->devices.size();
-}
-
-static ggml_backend_dev_t ggml_backend_cuda_reg_get_device(ggml_backend_reg_t reg, size_t index) {
-    ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context;
-    GGML_ASSERT(index < ctx->devices.size());
-    return ctx->devices[index];
-}
-
-static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
-    GGML_UNUSED(reg);
-    if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
-        return (void *)ggml_backend_cuda_split_buffer_type;
-    }
-    if (strcmp(name, "ggml_backend_register_host_buffer") == 0) {
-        return (void *)ggml_backend_cuda_register_host_buffer;
-    }
-    if (strcmp(name, "ggml_backend_unregister_host_buffer") == 0) {
-        return (void *)ggml_backend_cuda_unregister_host_buffer;
-    }
-    return nullptr;
-}
-
-static const ggml_backend_reg_i ggml_backend_cuda_reg_interface = {
-    /* .get_name          = */ ggml_backend_cuda_reg_get_name,
-    /* .get_device_count  = */ ggml_backend_cuda_reg_get_device_count,
-    /* .get_device_get    = */ ggml_backend_cuda_reg_get_device,
-    /* .get_proc_address  = */ ggml_backend_cuda_reg_get_proc_address,
-};
-
-// backend registry
-ggml_backend_reg_t ggml_backend_cuda_reg() {
-    static ggml_backend_reg reg;
-    static bool initialized = false;
-
-    {
-        static std::mutex mutex;
-        std::lock_guard<std::mutex> lock(mutex);
-        if (!initialized) {
-            ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
-
-            for (int i = 0; i < ggml_cuda_info().device_count; i++) {
-                ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
-                dev_ctx->device = i;
-                dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
-
-                ggml_cuda_set_device(i);
-                cudaDeviceProp prop;
-                CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
-                dev_ctx->description = prop.name;
-
-                ggml_backend_dev_t dev = new ggml_backend_device {
-                    /* .interface = */ ggml_backend_cuda_device_interface,
-                    /* .reg       = */ &reg,
-                    /* .context   = */ dev_ctx
-                };
-                ctx->devices.push_back(dev);
-            }
-
-            reg = ggml_backend_reg {
-                /* .interface = */ ggml_backend_cuda_reg_interface,
-                /* .context   = */ ctx
-            };
-        }
-
-        initialized = true;
-    }
-
-    return &reg;
-}
-
-ggml_backend_t ggml_backend_cuda_init(int device) {
-    if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {
-        GGML_LOG_ERROR("%s: invalid device %d\n", __func__, device);
-        return nullptr;
-    }
-
-    ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device);
-    if (ctx == nullptr) {
-        GGML_LOG_ERROR("%s: failed to allocate context\n", __func__);
-        return nullptr;
-    }
-
-    ggml_backend_t cuda_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_cuda_guid(),
-        /* .interface = */ ggml_backend_cuda_interface,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
-        /* .context   = */ ctx,
-    };
-
-    return cuda_backend;
-}
diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
new file mode 100644 (file)
index 0000000..e1482a2
--- /dev/null
@@ -0,0 +1,155 @@
+cmake_minimum_required(VERSION 3.18)  # for CMAKE_CUDA_ARCHITECTURES
+
+find_package(CUDAToolkit)
+
+if (CUDAToolkit_FOUND)
+    message(STATUS "CUDA Toolkit found")
+
+    if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+        # native == GPUs available at build time
+        # 52     == Maxwell, lowest CUDA 12 standard
+        # 60     == P100, FP16 CUDA intrinsics
+        # 61     == Pascal, __dp4a instruction (per-byte integer dot product)
+        # 70     == V100, FP16 tensor cores
+        # 75     == Turing, int8 tensor cores
+        if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
+            set(CMAKE_CUDA_ARCHITECTURES "native")
+        elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
+            set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
+        else()
+            set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75")
+        endif()
+    endif()
+    message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
+
+    enable_language(CUDA)
+
+    file(GLOB   GGML_HEADERS_CUDA "*.cuh")
+    list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
+
+    file(GLOB   GGML_SOURCES_CUDA "*.cu")
+    file(GLOB   SRCS "template-instances/fattn-wmma*.cu")
+    list(APPEND GGML_SOURCES_CUDA ${SRCS})
+    file(GLOB   SRCS "template-instances/mmq*.cu")
+    list(APPEND GGML_SOURCES_CUDA ${SRCS})
+
+    if (GGML_CUDA_FA_ALL_QUANTS)
+        file(GLOB   SRCS "template-instances/fattn-vec*.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+    else()
+        file(GLOB   SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        file(GLOB   SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        file(GLOB   SRCS "template-instances/fattn-vec*f16-f16.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+    endif()
+
+    add_library(ggml-cuda
+                ${GGML_HEADERS_CUDA}
+                ${GGML_SOURCES_CUDA}
+                )
+
+    target_link_libraries(ggml-cuda PRIVATE ggml-base)
+    target_include_directories(ggml-cuda PRIVATE . ..)
+
+    add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
+
+    if (GGML_CUDA_GRAPHS)
+        add_compile_definitions(GGML_CUDA_USE_GRAPHS)
+    endif()
+
+    if (GGML_CUDA_FORCE_MMQ)
+        add_compile_definitions(GGML_CUDA_FORCE_MMQ)
+    endif()
+
+    if (GGML_CUDA_FORCE_CUBLAS)
+        add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
+    endif()
+
+    if (GGML_CUDA_NO_VMM)
+        add_compile_definitions(GGML_CUDA_NO_VMM)
+    endif()
+
+    if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
+        add_compile_definitions(GGML_CUDA_F16)
+    endif()
+
+    if (GGML_CUDA_NO_PEER_COPY)
+        add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
+    endif()
+
+    if (GGML_STATIC)
+        if (WIN32)
+            # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
+            target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
+        else ()
+            target_link_libraries(ggml-cuda PRIVATE  CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+        endif()
+    else()
+        target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt)
+    endif()
+
+    if (GGML_CUDA_NO_VMM)
+        # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
+    else()
+        target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver)
+    endif()
+
+    set(CUDA_CXX_FLAGS "")
+
+    set(CUDA_FLAGS -use_fast_math)
+
+    if (GGML_FATAL_WARNINGS)
+        list(APPEND CUDA_FLAGS -Werror all-warnings)
+    endif()
+
+    if (GGML_ALL_WARNINGS AND NOT MSVC)
+        set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c)
+        if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "")
+            list(APPEND NVCC_CMD -ccbin ${CMAKE_CUDA_HOST_COMPILER})
+        endif()
+
+        execute_process(
+            COMMAND ${NVCC_CMD} -Xcompiler --version
+            OUTPUT_VARIABLE CUDA_CCFULLVER
+            ERROR_QUIET
+        )
+
+        if (NOT CUDA_CCFULLVER MATCHES clang)
+            set(CUDA_CCID "GNU")
+            execute_process(
+                COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion"
+                OUTPUT_VARIABLE CUDA_CCVER
+                ERROR_QUIET
+            )
+        else()
+            if (CUDA_CCFULLVER MATCHES Apple)
+                set(CUDA_CCID "AppleClang")
+            else()
+                set(CUDA_CCID "Clang")
+            endif()
+            string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER})
+        endif()
+
+        message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
+
+        get_flags(${CUDA_CCID} ${CUDA_CCVER})
+        list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS})  # This is passed to -Xcompiler later
+    endif()
+
+    if (NOT MSVC)
+        list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
+    endif()
+
+    list(JOIN   CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED)  # pass host compiler flags as a single argument
+
+    if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "")
+        list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})
+    endif()
+
+    target_compile_options(ggml-cuda PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>")
+else()
+    message(FATAL_ERROR "CUDA Toolkit not found")
+endif()
diff --git a/ggml/src/ggml-kompute.cpp b/ggml/src/ggml-kompute.cpp
deleted file mode 100644 (file)
index 2fea9e4..0000000
+++ /dev/null
@@ -1,2184 +0,0 @@
-#include "ggml-impl.h"
-#include "ggml-backend.h"
-#include "ggml-backend-impl.h"
-#include "ggml-kompute.h"
-
-// These are generated at build time by cmake custom command
-#include "shaderop_scale.h"
-#include "shaderop_scale_8.h"
-#include "shaderop_add.h"
-#include "shaderop_addrow.h"
-#include "shaderop_mul.h"
-#include "shaderop_silu.h"
-#include "shaderop_relu.h"
-#include "shaderop_gelu.h"
-#include "shaderop_softmax.h"
-#include "shaderop_norm.h"
-#include "shaderop_rmsnorm.h"
-#include "shaderop_diagmask.h"
-#include "shaderop_mul_mat_f16.h"
-#include "shaderop_mul_mat_q8_0.h"
-#include "shaderop_mul_mat_q4_0.h"
-#include "shaderop_mul_mat_q4_1.h"
-#include "shaderop_mul_mat_q4_k.h"
-#include "shaderop_mul_mat_q6_k.h"
-#include "shaderop_mul_mat_mat_f32.h"
-#include "shaderop_getrows_f32.h"
-#include "shaderop_getrows_f16.h"
-#include "shaderop_getrows_q4_0.h"
-#include "shaderop_getrows_q4_1.h"
-#include "shaderop_getrows_q6_k.h"
-#include "shaderop_rope_f16.h"
-#include "shaderop_rope_f32.h"
-#include "shaderop_cpy_f16_f16.h"
-#include "shaderop_cpy_f16_f32.h"
-#include "shaderop_cpy_f32_f16.h"
-#include "shaderop_cpy_f32_f32.h"
-
-#include <algorithm>
-#include <array>
-#include <cassert>
-#include <cstdint>
-#include <cstdio>
-#include <cstring>
-#include <iostream>
-#include <memory>
-#include <mutex>
-#include <stdexcept>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-#include <kompute/Kompute.hpp>
-#include <vulkan/vulkan.hpp>
-
-#ifdef __linux__
-#include <cstdlib> // for setenv
-#endif
-
-#define QK4_0 32
-#define QR4_0 2
-#define QK4_1 32
-#define QK_NL 16
-
-typedef ggml_fp16_t half;
-
-static std::string ggml_kompute_format_name(int device) {
-    return "Kompute" + std::to_string(device);
-}
-
-struct ggml_kompute_context {
-    int device;
-    std::string name;
-    std::shared_ptr<vk::DescriptorPool> pool;
-
-    ggml_kompute_context(int device)
-        : device(device), name(ggml_kompute_format_name(device)) {}
-};
-
-// FIXME: It would be good to consolidate the kompute manager and the kompute context into one object
-// and consolidate the init functions and simplify object lifetime management. As it currently stands,
-// we *have* to have the kompute manager no matter what for device discovery, but the kompute context
-// is only created when a device is set and vulkan is explicitly turned on.
-static ggml_kompute_context *s_kompute_context = nullptr;
-
-class kompute_manager {
-    kp::Manager *s_mgr = nullptr;
-
-public:
-    kp::Manager *operator()() {
-        if (s_mgr && !s_mgr->hasInstance()) {
-            destroy();
-        }
-        if (!s_mgr) {
-            s_mgr = new kp::Manager;
-        }
-        return s_mgr;
-    }
-
-    void destroy() {
-        delete s_mgr;
-        s_mgr = nullptr;
-    }
-};
-
-static kompute_manager komputeManager;
-
-struct ggml_vk_memory {
-    void *data = nullptr;
-    size_t size = 0;
-    vk::DeviceMemory *primaryMemory = nullptr;
-    vk::Buffer *primaryBuffer = nullptr;
-    vk::DeviceMemory *stagingMemory = nullptr;
-    vk::Buffer *stagingBuffer = nullptr;
-};
-
-#ifdef __linux__
-__attribute__((constructor))
-static void enable_sam() {
-    setenv("RADV_PERFTEST", "sam", false);
-}
-#endif
-
-static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_device) {
-    vk::PhysicalDeviceFeatures availableFeatures;
-    physical_device.getFeatures(&availableFeatures);
-
-    if (!availableFeatures.shaderInt16)
-        return false;
-
-    vk::PhysicalDeviceVulkan11Features availableFeatures11;
-    vk::PhysicalDeviceVulkan12Features availableFeatures12;
-
-    availableFeatures11.pNext = &availableFeatures12;
-    availableFeatures12.pNext = nullptr;
-
-    vk::PhysicalDeviceFeatures2 features2;
-    features2.pNext = &availableFeatures11;
-
-    physical_device.getFeatures2(&features2);
-
-    if (!availableFeatures11.uniformAndStorageBuffer16BitAccess ||
-        !availableFeatures11.storageBuffer16BitAccess) {
-        return false;
-    }
-
-    if (!availableFeatures12.storageBuffer8BitAccess ||
-        !availableFeatures12.uniformAndStorageBuffer8BitAccess ||
-        !availableFeatures12.shaderFloat16 ||
-        !availableFeatures12.shaderInt8) {
-        return false;
-    }
-
-    return true;
-}
-
-static const char * ggml_vk_getVendorName(uint32_t vendorID) {
-    switch (vendorID) {
-        case 0x10DE:
-            return "nvidia";
-        case 0x1002:
-            return "amd";
-        case 0x8086:
-            return "intel";
-        default:
-            return "unknown";
-    }
-}
-
-static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t memoryRequired) {
-    std::vector<ggml_vk_device> results;
-    if (!komputeManager()->hasVulkan() || !komputeManager()->hasInstance())
-        return results;
-
-    std::vector<vk::PhysicalDevice> physical_devices;
-    try {
-        physical_devices = komputeManager()->listDevices();
-    } catch (vk::SystemError & err) {
-        std::cerr << __func__ << ": ignoring Vulkan exception: " << err.what() << "\n";
-        return results;
-    }
-
-    uint32_t deviceCount = physical_devices.size();
-    if (deviceCount == 0)
-        return results;
-
-    std::unordered_map<std::string, size_t> count_by_name;
-
-    for (uint32_t i = 0; i < deviceCount; i++) {
-        const auto & physical_device = physical_devices[i];
-
-        VkPhysicalDeviceProperties dev_props = physical_device.getProperties();
-        VkPhysicalDeviceMemoryProperties memoryProperties = physical_device.getMemoryProperties();
-        const uint32_t major = VK_VERSION_MAJOR(dev_props.apiVersion);
-        const uint32_t minor = VK_VERSION_MINOR(dev_props.apiVersion);
-        if (major < 1 || minor < 2)
-            continue;
-
-        if (!ggml_vk_checkPhysicalDeviceFeatures(physical_device))
-            continue;
-
-        size_t heapSize = 0;
-        for (uint32_t j = 0; j < memoryProperties.memoryHeapCount; ++j) {
-            VkMemoryHeap heap = memoryProperties.memoryHeaps[j];
-            if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) {
-                heapSize = heap.size;
-                break;
-            }
-        }
-
-        if (heapSize < memoryRequired)
-            continue;
-
-        auto ext_props = physical_device.enumerateDeviceExtensionProperties();
-        bool has_maintenance4 = false;
-
-        // Check if maintenance4 is supported
-        for (const auto & properties : ext_props) {
-            if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
-                has_maintenance4 = true;
-            }
-        }
-
-        vk::PhysicalDeviceSubgroupProperties subgroup_props;
-        vk::PhysicalDeviceProperties2 dev_props2;
-        vk::PhysicalDeviceMaintenance3Properties dev_props3;
-        vk::PhysicalDeviceMaintenance4Properties dev_props4;
-        dev_props2.pNext = &dev_props3;
-        dev_props3.pNext = &subgroup_props;
-        if (has_maintenance4) {
-            subgroup_props.pNext = &dev_props4;
-        }
-        physical_device.getProperties2(&dev_props2);
-
-        if (subgroup_props.subgroupSize < 32)
-            continue;
-
-        ggml_vk_device d;
-        d.index = i;
-        d.type = dev_props.deviceType;
-        d.heapSize = heapSize;
-        d.vendor = strdup(ggml_vk_getVendorName(dev_props.vendorID));
-        d.subgroupSize = subgroup_props.subgroupSize;
-        d.bufferAlignment = dev_props.limits.minStorageBufferOffsetAlignment;
-
-        if (has_maintenance4) {
-            d.maxAlloc = std::min(dev_props3.maxMemoryAllocationSize, dev_props4.maxBufferSize);
-        } else {
-            d.maxAlloc = dev_props3.maxMemoryAllocationSize;
-        }
-
-        std::string name(dev_props.deviceName);
-        size_t n_idx = ++count_by_name[name];
-        if (n_idx > 1) {
-            name += " (" + std::to_string(n_idx) + ")";
-        }
-        d.name = strdup(name.c_str());
-
-        results.push_back(d);
-    }
-
-    std::stable_sort(results.begin(), results.end(),
-        [](const ggml_vk_device& lhs, const ggml_vk_device& rhs) -> bool {
-            if (lhs.type != rhs.type) {
-                if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return true;
-                if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return false;
-
-                if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return true;
-                if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return false;
-            }
-            return lhs.heapSize < rhs.heapSize;
-        }
-    );
-
-    return results;
-}
-
-static std::vector<ggml_vk_device>& ggml_vk_available_devices() {
-    static std::vector<ggml_vk_device> devices = ggml_vk_available_devices_internal(0);
-    return devices;
-}
-
-static void ggml_vk_filterByVendor(std::vector<ggml_vk_device>& devices, const std::string& targetVendor) {
-    devices.erase(
-        std::remove_if(devices.begin(), devices.end(),
-            [&targetVendor](const ggml_vk_device& device) {
-                return device.vendor != targetVendor;
-            }),
-        devices.end()
-    );
-}
-
-static void ggml_vk_filterByName(std::vector<ggml_vk_device>& devices, const std::string& targetName) {
-    devices.erase(
-        std::remove_if(devices.begin(), devices.end(),
-            [&targetName](const ggml_vk_device& device) {
-                return device.name != targetName;
-            }),
-        devices.end()
-    );
-}
-
-static bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const std::string & name) {
-    if (name.empty())
-        return false;
-
-    auto devices = ggml_vk_available_devices_internal(memoryRequired);
-    if (name == "amd" || name == "nvidia" || name == "intel") {
-        ggml_vk_filterByVendor(devices, name);
-    } else if (name != "gpu") {
-        ggml_vk_filterByName(devices, name);
-    }
-
-    if (devices.empty())
-        return false;
-
-    *device = devices.front();
-    return true;
-}
-
-bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const char * name) {
-    return ggml_vk_get_device(device, memoryRequired, std::string(name));
-}
-
-bool ggml_vk_has_vulkan() {
-    return komputeManager()->hasVulkan();
-}
-
-bool ggml_vk_has_device() {
-    return komputeManager()->hasDevice();
-}
-
-ggml_vk_device ggml_vk_current_device() {
-    if (!komputeManager()->hasDevice())
-        return ggml_vk_device();
-
-    auto devices = ggml_vk_available_devices();
-    ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data());
-    GGML_ASSERT(!devices.empty());
-    return devices.front();
-}
-
-static
-void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t size) {
-    std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
-        vk::DescriptorPoolSize(
-          vk::DescriptorType::eStorageBuffer,
-          3 * size // Descriptor count is number of possible tensors to pass into an algorithm
-          )
-    };
-
-    vk::DescriptorPoolCreateInfo descriptorPoolInfo(
-      vk::DescriptorPoolCreateFlags(),
-      size, // Max sets
-      static_cast<uint32_t>(descriptorPoolSizes.size()),
-      descriptorPoolSizes.data());
-
-    ctx->pool = std::make_shared<vk::DescriptorPool>();
-    vk::Result r = komputeManager()->device()->createDescriptorPool(
-      &descriptorPoolInfo, nullptr, ctx->pool.get());
-    if (r != vk::Result::eSuccess)
-        std::cerr << "Error allocating descriptor pool" << vk::to_string(r);
-}
-
-static
-void ggml_vk_free_descriptor_pool(struct ggml_kompute_context * ctx) {
-    if (ctx->pool) {
-        komputeManager()->device()->destroy(
-          *ctx->pool,
-          (vk::Optional<const vk::AllocationCallbacks>)nullptr);
-        ctx->pool = nullptr;
-    }
-}
-
-static
-vk::Buffer *ggml_vk_allocate_buffer(size_t size) {
-    vk::BufferCreateInfo bufferCreateInfo;
-    bufferCreateInfo.size = size;
-    bufferCreateInfo.usage = vk::BufferUsageFlagBits::eStorageBuffer |
-                             vk::BufferUsageFlagBits::eTransferSrc |
-                             vk::BufferUsageFlagBits::eTransferDst;
-    bufferCreateInfo.sharingMode = vk::SharingMode::eExclusive;
-
-    vk::Buffer *vkBuffer = new vk::Buffer;
-    vk::Result r = komputeManager()->device()->createBuffer(&bufferCreateInfo, nullptr, vkBuffer);
-    if (r != vk::Result::eSuccess)
-        std::cerr << "Error allocating buffer " << vk::to_string(r) << std::endl;
-    return vkBuffer;
-}
-
-static
-vk::DeviceMemory *ggml_vk_allocate(size_t size, vk::MemoryPropertyFlags flags, vk::MemoryRequirements requirements, bool *isHostVisible) {
-
-    uint32_t memoryTypeIndex = -1;
-    bool memoryTypeIndexFound = false;
-    vk::PhysicalDeviceMemoryProperties memoryProperties = komputeManager()->physicalDevice()->getMemoryProperties();
-    for (uint32_t i = 0; i < memoryProperties.memoryTypeCount; i++) {
-        const vk::MemoryType &memoryType = memoryProperties.memoryTypes[i];
-        const vk::MemoryHeap &memoryHeap = memoryProperties.memoryHeaps[memoryType.heapIndex];
-        if (memoryHeap.size < size) {
-            continue;
-        }
-
-        if (requirements.memoryTypeBits & (1 << i)) {
-            if (((memoryProperties.memoryTypes[i]).propertyFlags &
-                 flags) == flags) {
-                memoryTypeIndex = i;
-                memoryTypeIndexFound = true;
-                if (isHostVisible && (memoryProperties.memoryTypes[i].propertyFlags & vk::MemoryPropertyFlagBits::eHostVisible)) {
-                    *isHostVisible = true;
-                }
-                break;
-            }
-        }
-    }
-    if (!memoryTypeIndexFound) {
-        throw std::runtime_error(
-          "Memory type index for buffer creation not found");
-    }
-
-    vk::MemoryAllocateInfo allocInfo;
-    allocInfo.allocationSize = size;
-    allocInfo.memoryTypeIndex = memoryTypeIndex;
-    vk::DeviceMemory *vkDeviceMemory =  new vk::DeviceMemory;
-    vk::Result r = komputeManager()->device()->allocateMemory(&allocInfo, nullptr, vkDeviceMemory);
-    if (r != vk::Result::eSuccess) {
-        std::cerr << "Error allocating memory " << vk::to_string(r) << std::endl;
-        throw std::runtime_error("Error allocating vulkan memory.");
-    }
-    return vkDeviceMemory;
-}
-
-static size_t ggml_vk_aligned_offset(ggml_backend_buffer_t buffer, size_t offset) {
-    size_t minStorageBufferOffsetAlignment = ggml_backend_buffer_get_alignment(buffer);
-
-    // If offset is already aligned, return it directly
-    if (offset % minStorageBufferOffsetAlignment == 0) {
-        return offset;
-    }
-
-    // Otherwise, return the largest multiple of minStorageBufferOffsetAlignment less than offset
-    return (offset / minStorageBufferOffsetAlignment) * minStorageBufferOffsetAlignment;
-}
-
-static ggml_vk_memory ggml_vk_allocate(size_t size) {
-    ggml_vk_memory memory;
-    bool isHostVisible = false;
-    {
-        memory.primaryBuffer = ggml_vk_allocate_buffer(size);
-        vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.primaryBuffer);
-        vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eDeviceLocal;
-        memory.primaryMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
-        komputeManager()->device()->bindBufferMemory(*memory.primaryBuffer, *memory.primaryMemory, 0);
-        if (isHostVisible) {
-            vk::Result r = komputeManager()->device()->mapMemory(*memory.primaryMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
-            if (r != vk::Result::eSuccess)
-                std::cerr << "Error mapping memory" << vk::to_string(r);
-        }
-    }
-
-    if (!isHostVisible) {
-        memory.stagingBuffer = ggml_vk_allocate_buffer(size);
-        vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.stagingBuffer);
-        vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eHostVisible |
-                                                      vk::MemoryPropertyFlagBits::eHostCoherent |
-                                                      vk::MemoryPropertyFlagBits::eHostCached;
-        memory.stagingMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
-        komputeManager()->device()->bindBufferMemory(*memory.stagingBuffer, *memory.stagingMemory, 0);
-        vk::Result r = komputeManager()->device()->mapMemory(*memory.stagingMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
-        if (r != vk::Result::eSuccess)
-            std::cerr << "Error mapping memory" << vk::to_string(r);
-    }
-
-    memory.size = size;
-    return memory;
-}
-
-static void ggml_vk_free_memory(ggml_vk_memory &memory)
-{
-    komputeManager()->device()->destroy(
-      *memory.primaryBuffer,
-      (vk::Optional<const vk::AllocationCallbacks>)nullptr);
-    if (memory.stagingBuffer) {
-        komputeManager()->device()->destroy(
-          *memory.stagingBuffer,
-          (vk::Optional<const vk::AllocationCallbacks>)nullptr);
-    }
-    komputeManager()->device()->freeMemory(
-      *memory.primaryMemory,
-      (vk::Optional<const vk::AllocationCallbacks>)nullptr);
-    if (memory.stagingMemory) {
-        komputeManager()->device()->freeMemory(
-          *memory.stagingMemory,
-          (vk::Optional<const vk::AllocationCallbacks>)nullptr);
-    }
-}
-
-static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft);
-
-static
-ggml_vk_memory * ggml_vk_find_tensor(const struct ggml_tensor * t, uint64_t & offset) {
-    ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
-
-    // compatibility with ggml-backend
-    GGML_ASSERT(buffer && buffer->buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name);
-
-    ggml_vk_memory * buf_ctx = static_cast<ggml_vk_memory *>(buffer->context);
-
-    const intptr_t ioffs = intptr_t(t->data) - intptr_t(buf_ctx->data);
-
-    GGML_ASSERT(ioffs >= 0 && ioffs + int64_t(ggml_nbytes(t)) <= int64_t(buffer->size));
-
-    offset = uint64_t(ioffs);
-    return buf_ctx;
-}
-
-static
-const std::shared_ptr<kp::Tensor> ggml_vk_get_tensor(const struct ggml_tensor * t, uint32_t * alignedOffset = nullptr) {
-    uint64_t originalOffset = 0;
-    auto * res = ggml_vk_find_tensor(t, originalOffset);
-    if (!res) {
-        static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
-        return nullTensor;
-    }
-
-    // Create a tensor whose memory will be composed of our buffers at the correct offset
-    const size_t nelements = ggml_nelements(t);
-    size_t nbytes = ggml_nbytes(t);
-
-    size_t vulkanOffset = ggml_vk_aligned_offset(t->buffer, originalOffset);
-    if (alignedOffset) {
-        *alignedOffset = originalOffset - vulkanOffset;
-        nbytes += *alignedOffset;
-    }
-
-    return komputeManager()->tensor(
-        t->data,
-        nelements,
-        nbytes, kp::Tensor::TensorDataTypes::eFloat,
-        res->primaryMemory, res->primaryBuffer,
-        res->stagingMemory, res->stagingBuffer,
-        vulkanOffset);
-}
-
-static std::vector<uint32_t> getSpirvShader(const unsigned char* rawData, size_t size) {
-    if (size % sizeof(uint32_t) != 0) {
-        throw std::runtime_error("Invalid size: must be divisible by sizeof(uint32_t)");
-    }
-
-    const uint32_t* data_ptr = reinterpret_cast<const uint32_t*>(rawData);
-    size_t count = size / sizeof(uint32_t);
-    return std::vector<uint32_t>(data_ptr, data_ptr + count);
-}
-
-inline static
-uint32_t safe_divide(uint32_t a, uint32_t b) {
-    if (b <= 1) {
-        return a;
-    }
-    if ((a % b) != 0) {
-        fprintf(stderr, "((%u %% %u) == %u) != 0\n", a, b, a % b);
-        GGML_ABORT("safe_divide result would've had remainder");
-    }
-    return a / b;
-}
-
-static void ggml_vk_add(
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
-    int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
-    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
-    int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
-    int32_t ne0,
-    int32_t nb0,  int32_t nb1,  int32_t nb2,  int32_t nb3
-) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_add_comp_spv,
-        kp::shader_data::op_add_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00;
-        int32_t nb00, nb01, nb02, nb03;
-        int32_t ne10, ne11, ne12, ne13;
-        int32_t nb10, nb11, nb12, nb13;
-        int32_t ne0;
-        int32_t nb0, nb1, nb2, nb3;
-    } const pushConsts {
-        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00,
-        nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, ne13,
-        nb10, nb11, nb12, nb13,
-        ne0,
-        nb0, nb1, nb2, nb3
-    };
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__)) {
-        s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_addrow(kp::Sequence& seq,
-                 const std::shared_ptr<kp::Tensor>& inA,
-                 const std::shared_ptr<kp::Tensor>& inB,
-                 const std::shared_ptr<kp::Tensor>& out,
-                 uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-                 uint32_t size, uint32_t row = 0) {
-
-    const static auto spirv = getSpirvShader(kp::shader_data::op_addrow_comp_spv,
-        kp::shader_data::op_addrow_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        uint32_t row;
-    } const pushConsts {
-        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        row
-    };
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__))
-        s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
-    else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({size});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_mul(
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
-    int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
-    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
-    int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
-    int32_t ne0,
-    int32_t nb0,  int32_t nb1,  int32_t nb2,  int32_t nb3
-) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_comp_spv,
-        kp::shader_data::op_mul_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00;
-        int32_t nb00, nb01, nb02, nb03;
-        int32_t ne10, ne11, ne12, ne13;
-        int32_t nb10, nb11, nb12, nb13;
-        int32_t ne0;
-        int32_t nb0, nb1, nb2, nb3;
-    } const pushConsts {
-        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00,
-        nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, ne13,
-        nb10, nb11, nb12, nb13,
-        ne0,
-        nb0, nb1, nb2, nb3
-    };
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__)) {
-        s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_scale(kp::Sequence& seq,
-                   const std::shared_ptr<kp::Tensor>& in,
-                   const std::shared_ptr<kp::Tensor>& out,
-                   uint32_t inOff, uint32_t outOff,
-                   uint32_t size, float scale) {
-    const static auto spirv_1 = getSpirvShader(
-        kp::shader_data::op_scale_comp_spv, kp::shader_data::op_scale_comp_spv_len
-    );
-    const static auto spirv_8 = getSpirvShader(
-        kp::shader_data::op_scale_8_comp_spv, kp::shader_data::op_scale_8_comp_spv_len
-    );
-
-    struct PushConstants {
-        uint32_t inOff, outOff;
-        float scale;
-    } const pushConsts {
-        safe_divide(inOff, 4), safe_divide(outOff, 4),
-        scale
-    };
-
-    const auto * spirv = &spirv_1;
-    std::string name(__func__);
-    if (size % 8 == 0) {
-        size /= 8;
-        name += "_8";
-        spirv = &spirv_8;
-    }
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(name)) {
-        s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, *spirv, {size}, {}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({in, out});
-        s_algo->setWorkgroup({size});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_xxlu(
-    const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& in,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inOff, uint32_t outOff,
-    uint32_t size
-) {
-    struct PushConstants {
-        uint32_t inOff, outOff;
-    } const pushConsts {
-        safe_divide(inOff, 4), safe_divide(outOff, 4),
-    };
-
-    auto name = std::string(__func__) + "_" + suffix;
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(name)) {
-        s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({in, out});
-        s_algo->setWorkgroup({size});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_silu(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_silu_comp_spv,
-        kp::shader_data::op_silu_comp_spv_len);
-
-    ggml_vk_xxlu(spirv, "silu", std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_relu(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_relu_comp_spv,
-        kp::shader_data::op_relu_comp_spv_len);
-
-    ggml_vk_xxlu(spirv, "relu", std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_gelu(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_gelu_comp_spv,
-        kp::shader_data::op_gelu_comp_spv_len);
-
-    ggml_vk_xxlu(spirv, "gelu", std::forward<Args>(args)...);
-}
-
-static void ggml_vk_soft_max(
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
-    float scale
-) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
-        kp::shader_data::op_softmax_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne01, ne02;
-        float scale;
-        int32_t mask;
-    } pushConsts {
-        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, ne01, ne02,
-        scale,
-        bool(inB)
-    };
-
-    auto & inB_ = inB ? inB : inA;
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__)) {
-        // FIXME: The softmax kernel needs to be fixed to use the subgroupsize which can vary by device
-        const uint32_t local_x = 32;
-        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB_, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB_, out});
-        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_norm_(
-    const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& in,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inOff, uint32_t outOff,
-    int32_t ne00, int32_t nb01,
-    int32_t nrows, float epsilon
-) {
-    GGML_ASSERT(nb01%sizeof(float) == 0);
-    GGML_ASSERT(ne00%sizeof(float) == 0);
-
-    struct PushConstants {
-        uint32_t inOff, outOff;
-        uint32_t ne00, nb01;
-        float eps;
-    } pushConsts {
-        safe_divide(inOff, 4), safe_divide(outOff, 4),
-        (uint32_t)ne00, (uint32_t)nb01, epsilon
-    };
-
-    auto name = std::string(__func__) + "_" + suffix;
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(name)) {
-        s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {(uint32_t)nrows}, {}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({in, out});
-        s_algo->setWorkgroup({(uint32_t)nrows});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_norm(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_norm_comp_spv,
-        kp::shader_data::op_norm_comp_spv_len);
-
-    ggml_vk_norm_(spirv, "norm", std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_rms_norm(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_rmsnorm_comp_spv,
-        kp::shader_data::op_rmsnorm_comp_spv_len);
-
-    ggml_vk_norm_(spirv, "rms", std::forward<Args>(args)...);
-}
-
-static void ggml_vk_diag_mask_inf(kp::Sequence& seq,
-                           const std::shared_ptr<kp::Tensor>& in,
-                           const std::shared_ptr<kp::Tensor>& out,
-                           uint32_t inOff, uint32_t outOff,
-                           uint32_t n_past,
-                           int32_t ne00, int32_t ne01, int32_t ne02) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_diagmask_comp_spv,
-        kp::shader_data::op_diagmask_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inOff, outOff;
-        uint32_t n_past;
-        int32_t ne00, ne01;
-    } pushConsts {
-        safe_divide(inOff, 4), safe_divide(outOff, 4),
-        n_past,
-        ne00, ne01
-    };
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__))
-        s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne00), unsigned(ne01), unsigned(ne02)}, {}, {pushConsts});
-    else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({in, out});
-        s_algo->setWorkgroup({unsigned(ne00), unsigned(ne01), unsigned(ne02)});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_mul_mat_f16(
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02,
-    uint32_t nb00, uint32_t nb01, uint32_t nb02,
-    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
-    uint32_t nb10, uint32_t nb11, uint32_t nb12,
-    int32_t ne0, int32_t ne1,
-    uint32_t r2, uint32_t r3
-) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_f16_comp_spv,
-        kp::shader_data::op_mul_mat_f16_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne01, ne02;
-        uint32_t nb00, nb01, nb02;
-        int32_t ne10, ne11, ne12;
-        uint32_t nb10, nb11, nb12;
-        int32_t ne0, ne1;
-        uint32_t r2, r3;
-    } pushConsts {
-        safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, ne01, ne02,
-        nb00, nb01, nb02,
-        ne10, ne11, ne12,
-        nb10, nb11, nb12,
-        ne0, ne1,
-        r2, r3
-    };
-
-    const unsigned ny = unsigned((ne11 + 4 - 1)/4);
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__)) {
-        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
-        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), ny, unsigned(ne12*ne13)}, {local_x}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned(ne01), ny, unsigned(ne12*ne13)});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq,
-                         const std::shared_ptr<kp::Tensor>& inA,
-                         const std::shared_ptr<kp::Tensor>& inB,
-                         const std::shared_ptr<kp::Tensor>& out,
-                         uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-                         int32_t ne00, int32_t ne01, int32_t ne02,
-                         uint32_t nb01, uint32_t nb02,
-                         int32_t ne11, int32_t ne12,
-                         uint32_t nb11, uint32_t nb12,
-                         uint32_t nb1, uint32_t nb2) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_mat_f32_comp_spv,
-        kp::shader_data::op_mul_mat_mat_f32_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne01, ne02, ne11, ne12;
-        uint32_t nb01, nb02;
-        uint32_t nb11, nb12;
-        uint32_t nb1, nb2;
-    } pushConsts {
-        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, ne01, ne02, ne11, ne12,
-        nb01, nb02, nb11, nb12,
-        nb1, nb2
-    };
-
-    const uint32_t local_x = ggml_vk_current_device().subgroupSize;
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__)) {
-        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(),
-        {inA, inB, out}, spirv,
-        {unsigned(ne01),
-         unsigned(ne11),
-         unsigned(std::max(ne12, ne02))
-         },
-        {local_x},
-        {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned(ne01),
-                              unsigned(ne11),
-                              unsigned(std::max(ne12, ne02)),
-                              });
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_mul_mat_impl(
-    const std::vector<uint32_t>& spirv, const char * suffix, uint32_t block_size, kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02,
-    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
-    int32_t ne0, int32_t ne1,
-    uint32_t r2, uint32_t r3
-) {
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne01, ne02;
-        int32_t ne10, ne12;
-        int32_t ne0, ne1;
-        uint32_t r2, r3;
-    } pushConsts {
-        safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, ne01, ne02,
-        ne10, ne12,
-        ne0, ne1,
-        r2, r3
-    };
-
-    auto name = std::string(__func__) + "_" + suffix;
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(name)) {
-        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
-        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_mul_mat_q4_0(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_0_comp_spv,
-        kp::shader_data::op_mul_mat_q4_0_comp_spv_len);
-
-    ggml_vk_mul_mat_impl(spirv, "q4_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_mul_mat_q4_1(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_1_comp_spv,
-        kp::shader_data::op_mul_mat_q4_1_comp_spv_len);
-
-    ggml_vk_mul_mat_impl(spirv, "q4_1", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_mul_mat_q8_0(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q8_0_comp_spv,
-        kp::shader_data::op_mul_mat_q8_0_comp_spv_len);
-
-    ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
-}
-
-static void ggml_vk_mul_mat_q4_k(
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
-    int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
-    int32_t ne1, int32_t r2, int32_t r3
-) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
-        kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
-    } pushConsts {
-        0, 0, 0,
-        ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
-    };
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__)) {
-        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_mul_mat_q6_k(
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
-    int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
-) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
-        kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne10, ne0, ne1, ne01, gqa;
-    } pushConsts {
-        inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, ne10, ne0, ne1, ne01, ne12/ne02
-    };
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__)) {
-        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
-        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_get_rows(
-    const std::vector<uint32_t>& spirv,
-    const char * suffix,
-    unsigned element_size, unsigned qk,
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t nb01, int32_t nb1,
-    uint32_t size
-) {
-    GGML_ASSERT(nb01%element_size == 0);
-    GGML_ASSERT(nb1%sizeof(float) == 0);
-    if (qk) GGML_ASSERT(ne00%qk == 0);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, nb01, nb1;
-    } pushConsts {
-        safe_divide(inAOff, element_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, nb01, nb1
-    };
-
-    auto name = std::string(__func__) + "_" + suffix;
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(name)) {
-        s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({size});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_f32(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
-        kp::shader_data::op_getrows_f32_comp_spv_len);
-
-    ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_f16(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
-        kp::shader_data::op_getrows_f16_comp_spv_len);
-
-    ggml_vk_get_rows(spirv, "f16", sizeof(half), 0, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_q4_0(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_0_comp_spv,
-        kp::shader_data::op_getrows_q4_0_comp_spv_len);
-
-    ggml_vk_get_rows(spirv, "q4_0", 1/*We access blocks unaligned*/, QK4_0, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_q4_1(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_1_comp_spv,
-        kp::shader_data::op_getrows_q4_1_comp_spv_len);
-
-    ggml_vk_get_rows(spirv, "q4_1", 1/*We access blocks unaligned*/, QK4_1, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_q6_k(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q6_k_comp_spv,
-        kp::shader_data::op_getrows_q6_k_comp_spv_len);
-    ggml_vk_get_rows(spirv, "q6_k", 1/*We access blocks unaligned*/, QK_NL, std::forward<Args>(args)...);
-}
-
-static void ggml_vk_rope(
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
-    float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
-    int32_t ne01, int32_t ne02, int32_t ne03,
-    uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
-    int32_t ne0,
-    uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
-) {
-    GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
-
-    static const auto spirv_f16 = getSpirvShader(
-        kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
-    );
-    static const auto spirv_f32 = getSpirvShader(
-        kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
-    );
-
-    int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
-
-    GGML_ASSERT(nb03 % type_size == 0);
-    GGML_ASSERT(nb02 % type_size == 0);
-    GGML_ASSERT(nb01 % type_size == 0);
-    GGML_ASSERT(nb00 % type_size == 0);
-    GGML_ASSERT(nb3  % type_size == 0);
-    GGML_ASSERT(nb2  % type_size == 0);
-    GGML_ASSERT(nb1  % type_size == 0);
-    GGML_ASSERT(nb0  % type_size == 0);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t n_dims, mode, n_ctx_orig;
-        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-        uint32_t nb00, nb01, nb02, nb03;
-        int32_t ne0;
-        uint32_t nb0, nb1, nb2, nb3;
-    } pushConsts {
-        safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
-        n_dims, mode, n_ctx_orig,
-        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
-        nb00, nb01, nb02, nb03,
-        ne0,
-        nb0, nb1, nb2, nb3
-    };
-
-    auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(name)) {
-        s_algo = komputeManager()->algorithm<float, PushConstants>(
-            name, s_kompute_context->pool.get(), {inA, inB, out},
-            src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
-            {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
-        );
-    } else {
-        s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-static void ggml_vk_cpy(
-    const std::vector<uint32_t>& spirv,
-    uint32_t in_element_size, uint32_t out_element_size,
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& in,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
-    uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
-    int32_t ne0, int32_t ne1, int32_t ne2,
-    uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
-) {
-    struct PushConstants {
-        uint32_t inOff, outOff;
-        int32_t ne00, ne01, ne02;
-        uint32_t nb00, nb01, nb02, nb03;
-        int32_t ne0, ne1, ne2;
-        uint32_t nb0, nb1, nb2, nb3;
-    } pushConsts {
-        safe_divide(inOff, in_element_size), safe_divide(outOff, out_element_size),
-        ne00, ne01, ne02,
-        nb00, nb01, nb02, nb03,
-        ne0, ne1, ne2,
-        nb0, nb1, nb2, nb3
-    };
-
-    std::string name = std::string(__func__)
-                       + "_i_" + std::to_string(in_element_size)
-                       + "_o_" + std::to_string(out_element_size);
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(name))
-        s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
-    else {
-        s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({in, out});
-        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
-        s_algo->setPushConstants<PushConstants>({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record<kp::OpAlgoDispatch>(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_cpy_f32_f16(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f16_comp_spv,
-        kp::shader_data::op_cpy_f32_f16_comp_spv_len);
-    ggml_vk_cpy(spirv, 4, 2, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_cpy_f32_f32(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f32_comp_spv,
-        kp::shader_data::op_cpy_f32_f32_comp_spv_len);
-    ggml_vk_cpy(spirv, 4, 4, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_cpy_f16_f16(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f16_comp_spv,
-        kp::shader_data::op_cpy_f16_f16_comp_spv_len);
-    ggml_vk_cpy(spirv, 2, 2, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_cpy_f16_f32(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f32_comp_spv,
-        kp::shader_data::op_cpy_f16_f32_comp_spv_len);
-    ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
-}
-
-static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-    switch (op->op) {
-        case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_RELU:
-                case GGML_UNARY_OP_GELU:
-                case GGML_UNARY_OP_SILU:
-                    return ggml_is_contiguous(op->src[0]);
-                default:
-                    ;
-            }
-            break;
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_TRANSPOSE:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_ADD:
-        case GGML_OP_MUL:
-        case GGML_OP_SCALE:
-        case GGML_OP_SOFT_MAX:
-        case GGML_OP_RMS_NORM:
-        case GGML_OP_NORM:
-        case GGML_OP_ROPE:
-            return true;
-        case GGML_OP_DUP:
-        case GGML_OP_CPY:
-        case GGML_OP_CONT:
-            switch (op->src[0]->type) {
-                case GGML_TYPE_F32:
-                case GGML_TYPE_F16:
-                    break;
-                default:
-                    return false;
-            }
-            switch (op->type) {
-                case GGML_TYPE_F32:
-                case GGML_TYPE_F16:
-                    break;
-                default:
-                    return false;
-            }
-            return true;
-        case GGML_OP_DIAG_MASK_INF:
-            return op->ne[3] == 1;
-        case GGML_OP_GET_ROWS:
-            switch (op->src[0]->type) {
-                case GGML_TYPE_F32:
-                case GGML_TYPE_F16:
-                case GGML_TYPE_Q4_0:
-                case GGML_TYPE_Q4_1:
-                case GGML_TYPE_Q6_K:
-                    return op->ne[2] == 1 && op->ne[3] == 1;
-                default:
-                    ;
-            }
-            return false;
-        case GGML_OP_MUL_MAT:
-            if (op->src[1]->type != GGML_TYPE_F32 || ggml_is_transposed(op->src[0]) || ggml_is_transposed(op->src[1]))
-                return false;
-
-            switch (op->src[0]->type) {
-                case GGML_TYPE_F32:
-                case GGML_TYPE_Q6_K:
-                    return op->ne[3] == 1;
-                case GGML_TYPE_F16:
-                case GGML_TYPE_Q8_0:
-                case GGML_TYPE_Q4_0:
-                case GGML_TYPE_Q4_1:
-                case GGML_TYPE_Q4_K:
-                    return true;
-                default:
-                    ;
-            }
-        default:
-            ;
-    }
-    return false;
-
-    GGML_UNUSED(dev);
-}
-
-static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
-    const int n_seq = 8;
-
-    // FIXME: Figure out if we can somehow optimize the size of the pool... right now we're setting
-    // it to the size of the graph, but I think it can be made smaller?
-    ggml_vk_allocate_descriptor_pool(ctx, gf->n_nodes);
-
-    std::vector<std::shared_ptr<kp::Sequence>> sequences(n_seq);
-
-    for (auto& sequence : sequences) {
-        sequence = komputeManager()->sequence();
-    }
-    for (int seq_idx = 0; seq_idx < n_seq; ++seq_idx) {
-        const int n_nodes_per_seq = (gf->n_nodes + n_seq - 1) / n_seq;
-
-        auto& seq = *sequences[seq_idx];
-
-        const int node_start = (seq_idx + 0) * n_nodes_per_seq;
-        const int node_end   = std::min((seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq, gf->n_nodes);
-
-        bool any_commands_recorded = false;
-
-        for (int i = node_start; i < node_end; ++i) {
-            struct ggml_tensor * src0 = gf->nodes[i]->src[0];
-            struct ggml_tensor * src1 = gf->nodes[i]->src[1];
-            struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
-            struct ggml_tensor * dst = gf->nodes[i];
-            GGML_ASSERT(dst->data != nullptr);
-
-            if (ggml_is_empty(dst)) {
-                continue;
-            }
-
-            switch (dst->op) {
-                case GGML_OP_NONE:
-                case GGML_OP_RESHAPE:
-                case GGML_OP_VIEW:
-                case GGML_OP_TRANSPOSE:
-                case GGML_OP_PERMUTE:
-                    continue; // noop -> next node
-                default:
-                    break;
-            }
-
-            any_commands_recorded = true;
-
-            const int32_t ne00 = src0 ? src0->ne[0] : 0;
-            const int32_t ne01 = src0 ? src0->ne[1] : 0;
-            const int32_t ne02 = src0 ? src0->ne[2] : 0;
-            const int32_t ne03 = src0 ? src0->ne[3] : 0;
-
-            const uint32_t nb00 = src0 ? src0->nb[0] : 0;
-            const uint32_t nb01 = src0 ? src0->nb[1] : 0;
-            const uint32_t nb02 = src0 ? src0->nb[2] : 0;
-            const uint32_t nb03 = src0 ? src0->nb[3] : 0;
-
-            const int32_t ne10 = src1 ? src1->ne[0] : 0;
-            const int32_t ne11 = src1 ? src1->ne[1] : 0;
-            const int32_t ne12 = src1 ? src1->ne[2] : 0;
-            const int32_t ne13 = src1 ? src1->ne[3] : 0;
-
-            const uint32_t nb10 = src1 ? src1->nb[0] : 0;
-            const uint32_t nb11 = src1 ? src1->nb[1] : 0;
-            const uint32_t nb12 = src1 ? src1->nb[2] : 0;
-            const uint32_t nb13 = src1 ? src1->nb[3] : 0;
-
-            const int32_t ne0 = dst ? dst->ne[0] : 0;
-            const int32_t ne1 = dst ? dst->ne[1] : 0;
-            const int32_t ne2 = dst ? dst->ne[2] : 0;
-//            const int32_t ne3 = dst ? dst->ne[3] : 0;
-
-            const uint32_t nb0 = dst ? dst->nb[0] : 0;
-            const uint32_t nb1 = dst ? dst->nb[1] : 0;
-            const uint32_t nb2 = dst ? dst->nb[2] : 0;
-            const uint32_t nb3 = dst ? dst->nb[3] : 0;
-
-            const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
-            const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
-            const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
-
-            const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
-            uint32_t off_src0 = 0;
-            uint32_t off_src1 = 0;
-            uint32_t off_dst  = 0;
-            const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
-            const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
-            const std::shared_ptr<kp::Tensor>& id_dst  = dst  ? ggml_vk_get_tensor(dst,  &off_dst)  : nullTensor;
-
-            switch (dst->op) {
-                case GGML_OP_ADD:
-                    {
-                        if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-                            // src1 is a row
-                            ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
-                        } else {
-                            ggml_vk_add(
-                                seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                ne00, ne01, ne02, ne03,
-                                nb00, nb01, nb02, nb03,
-                                ne10, ne11, ne12, ne13,
-                                nb10, nb11, nb12, nb13,
-                                ne0,
-                                nb0, nb1, nb2, nb3
-                            );
-                        }
-                    } break;
-                case GGML_OP_MUL:
-                    {
-                        ggml_vk_mul(
-                            seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                            ne00, ne01, ne02, ne03,
-                            nb00, nb01, nb02, nb03,
-                            ne10, ne11, ne12, ne13,
-                            nb10, nb11, nb12, nb13,
-                            ne0,
-                            nb0, nb1, nb2, nb3
-                        );
-                    } break;
-                case GGML_OP_SCALE:
-                    {
-                        float scale; memcpy(&scale, dst->op_params, sizeof(float));
-
-                        ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
-                    } break;
-                case GGML_OP_UNARY:
-                    {
-                        int64_t n = ggml_nelements(dst);
-                        GGML_ASSERT(n % 4 == 0);
-                        switch (ggml_get_unary_op(gf->nodes[i])) {
-                            case GGML_UNARY_OP_SILU:
-                                {
-                                    ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
-                                } break;
-                            case GGML_UNARY_OP_RELU:
-                                {
-                                    ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
-                                } break;
-                            case GGML_UNARY_OP_GELU:
-                                {
-                                    GGML_ASSERT(n % 8 == 0);
-                                    ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, n/8);
-                                } break;
-                            default:
-                                {
-                                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-                                    GGML_ABORT("fatal error");
-                                }
-                        }
-                    } break;
-                case GGML_OP_SOFT_MAX:
-                    {
-                        float scale;
-                        float max_bias;
-
-                        memcpy(&scale,    (float *)dst->op_params + 0, sizeof(float));
-                        memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
-
-#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
-#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
-                        GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
-
-#pragma message("TODO: add ALiBi support")
-#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/7192")
-                        GGML_ASSERT(max_bias == 0.0f);
-
-                        ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
-                    } break;
-                case GGML_OP_DIAG_MASK_INF:
-                    {
-                        const int n_past = ((int32_t *)(dst->op_params))[0];
-                        ggml_vk_diag_mask_inf(seq, id_src0, id_dst, off_src0, off_dst, n_past, ne00, ne01, ne02);
-                    } break;
-                case GGML_OP_NORM:
-                    {
-                        float eps;
-                        memcpy(&eps, dst->op_params, sizeof(float));
-                        ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
-                    } break;
-                case GGML_OP_RMS_NORM:
-                    {
-                        GGML_ASSERT(ne00 % 4 == 0);
-
-                        float eps;
-                        memcpy(&eps, dst->op_params, sizeof(float));
-                        ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
-                    } break;
-                case GGML_OP_MUL_MAT:
-                    {
-                        GGML_ASSERT(ne00 == ne10);
-
-                        GGML_ASSERT(ne12 % ne02 == 0);
-                        GGML_ASSERT(ne13 % ne03 == 0);
-
-                        const uint32_t r2 = ne12/ne02;
-                        const uint32_t r3 = ne13/ne03;
-
-                        if (src1t != GGML_TYPE_F32) {
-                            fprintf(stderr, "%s: %s: Unsupported src1 type: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
-                            goto not_implemented;
-                        }
-
-                        if (ggml_is_transposed(src0) ||
-                            ggml_is_transposed(src1)) {
-                            fprintf(stderr, "%s: %s: matmul on tranposed tensor not supported: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
-                            goto not_implemented;
-                        }
-
-                        switch (src0t) {
-                            case GGML_TYPE_F32:
-                                ggml_vk_mul_mat_mat_f32(
-                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, nb01, nb02, ne11, ne12, nb11, nb12, nb1, nb2
-                                );
-                                break;
-                            case GGML_TYPE_F16:
-                                ggml_vk_mul_mat_f16(
-                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12,
-                                    ne0, ne1, r2, r3
-                                );
-                                break;
-                            case GGML_TYPE_Q8_0:
-                                ggml_vk_mul_mat_q8_0(
-                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
-                                );
-                                break;
-                            case GGML_TYPE_Q4_0:
-                                ggml_vk_mul_mat_q4_0(
-                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
-                                );
-                                break;
-                            case GGML_TYPE_Q4_1:
-                                ggml_vk_mul_mat_q4_1(
-                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
-                                );
-                                break;
-                            case GGML_TYPE_Q4_K:
-                                ggml_vk_mul_mat_q4_k(
-                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
-                                );
-                                break;
-                            case GGML_TYPE_Q6_K:
-                                ggml_vk_mul_mat_q6_k(
-                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02
-                                );
-                                break;
-                            default: {
-                                fprintf(stderr, "%s: %s: Unsupported quantization: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
-                                goto not_implemented;
-                            }
-                        }
-
-                    } break;
-                case GGML_OP_GET_ROWS:
-                    {
-                        if (src0t == GGML_TYPE_F32) {
-                            ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
-                        } else if (src0t == GGML_TYPE_F16) {
-                            ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
-                        } else if (src0t == GGML_TYPE_Q4_0) {
-                            ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
-                        } else if (src0t == GGML_TYPE_Q4_1) {
-                            ggml_vk_get_rows_q4_1(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
-                        } else if (src0t == GGML_TYPE_Q6_K) {
-                            ggml_vk_get_rows_q6_k(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
-                        } else {
-                            fprintf(stderr, "%s: %s: Unsupported quantization: %u\n", __func__, ggml_op_name(dst->op), src0t);
-                            goto not_implemented;
-                        }
-                    } break;
-                case GGML_OP_ROPE:
-                    {
-#pragma message("TODO: implement phi3 frequency factors support")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
-                        GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
-
-#pragma message("TODO: update rope NORM mode to match NEOX mode")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
-
-                        GGML_ASSERT(ne10 == ne02);
-                        GGML_ASSERT(src0t == dstt);
-                        // const int n_past = ((int32_t *) dst->op_params)[0];
-                        const int n_dims     = ((int32_t *) dst->op_params)[1];
-                        const int mode       = ((int32_t *) dst->op_params)[2];
-                        // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
-                        const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
-                        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-                        memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-                        memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-                        memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-                        memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-                        memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-                        memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-                        ggml_vk_rope(
-                            seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
-                            freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
-                            ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
-                        );
-                    } break;
-                case GGML_OP_DUP:
-                case GGML_OP_CPY:
-                case GGML_OP_CONT:
-                    {
-                        switch (src0t) {
-                            case GGML_TYPE_F32:
-                                {
-                                    switch (dstt) {
-                                        case GGML_TYPE_F16: ggml_vk_cpy_f32_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
-                                        case GGML_TYPE_F32: ggml_vk_cpy_f32_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
-                                        default: goto not_implemented;
-                                    }
-                                } break;
-                            case GGML_TYPE_F16:
-                                {
-                                    switch (dstt) {
-                                        case GGML_TYPE_F16: ggml_vk_cpy_f16_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
-                                        case GGML_TYPE_F32: ggml_vk_cpy_f16_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
-                                    default: goto not_implemented;
-                                } break;
-                            default: goto not_implemented;
-                            }
-                        }
-                    } break;
-                default: goto not_implemented;
-            }
-            continue;
-            not_implemented: {}
-            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-            //GGML_ABORT("fatal error");
-        }
-
-        // Evaluate sequence
-        if (any_commands_recorded) {
-            seq.evalAsync();
-        }
-    }
-
-    // Wait for all sequences to finish
-    for (auto& sequence : sequences) {
-        if (sequence->isRunning())
-            sequence->evalAwait();
-    }
-
-    ggml_vk_free_descriptor_pool(ctx);
-}
-
-template<>
-kp::Tensor::TensorDataTypes
-kp::TensorT<half>::dataType()
-{
-    return TensorDataTypes::eFloat;
-}
-
-template<>
-kp::Tensor::TensorDataTypes
-kp::TensorT<uint8_t>::dataType()
-{
-    return TensorDataTypes::eUnsignedInt;
-}
-
-////////////////////////////////////////////////////////////////////////////////
-
-// backend interface
-
-struct ggml_backend_kompute_buffer_type_context {
-    int         device;
-    int         device_ref = 0;
-    uint64_t    buffer_alignment;
-    uint64_t    max_alloc;
-    std::string name;
-
-    ggml_backend_kompute_buffer_type_context(int device, uint64_t buffer_alignment, uint64_t max_alloc)
-        : device(device), buffer_alignment(buffer_alignment), max_alloc(max_alloc), name(ggml_kompute_format_name(device)) {}
-};
-
-static void ggml_backend_kompute_device_ref(ggml_backend_buffer_type_t buft) {
-    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
-
-    if (!ctx->device_ref) {
-        komputeManager()->initializeDevice(
-            ctx->device, {}, {
-                "VK_KHR_shader_float16_int8", "VK_KHR_8bit_storage",
-                "VK_KHR_16bit_storage", "VK_KHR_shader_non_semantic_info"
-            }
-        );
-    }
-
-    assert(ggml_vk_has_device());
-    ctx->device_ref++;
-}
-
-static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) {
-    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
-
-    assert(ctx->device_ref > 0);
-
-    ctx->device_ref--;
-
-    if (!ctx->device_ref) {
-        komputeManager.destroy();
-    }
-}
-
-static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    auto * memory = (ggml_vk_memory *)buffer->context;
-    if (ggml_vk_has_device()) {
-        ggml_vk_free_memory(*memory);
-    }
-    delete memory;
-}
-
-static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) {
-    return ((ggml_vk_memory *)buffer->context)->data;
-}
-
-static void ggml_backend_kompute_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    GGML_UNUSED(buffer);
-
-    const auto res = ggml_vk_get_tensor(tensor);
-    GGML_ASSERT(res);
-
-    memcpy((char *)tensor->data + offset, data, size);
-
-    komputeManager()->sequence()->eval<kp::OpTensorSyncDevice>({res});
-}
-
-static void ggml_backend_kompute_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    GGML_UNUSED(buffer);
-
-    const auto res = ggml_vk_get_tensor(tensor);
-    GGML_ASSERT(res);
-
-    komputeManager()->sequence()->eval<kp::OpTensorSyncLocal>({res});
-
-    memcpy(data, (const char *)tensor->data + offset, size);
-}
-
-static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
-    auto * memory = (ggml_vk_memory *)buffer->context;
-    memset(memory->data, value, buffer->size);
-
-    if (memory->stagingBuffer)
-        komputeManager()->sequence()->eval<kp::OpBufferSyncDevice>(memory->primaryBuffer, memory->stagingBuffer, memory->size);
-}
-
-static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
-    /* .free_buffer     = */ ggml_backend_kompute_buffer_free_buffer,
-    /* .get_base        = */ ggml_backend_kompute_buffer_get_base,
-    /* .init_tensor     = */ NULL,
-    /* .memset_tensor   = */ NULL,
-    /* .set_tensor      = */ ggml_backend_kompute_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_kompute_buffer_get_tensor,
-    /* .cpy_tensor      = */ NULL,
-    /* .clear           = */ ggml_backend_kompute_buffer_clear,
-    /* .reset           = */ NULL,
-};
-
-// default buffer type
-
-static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
-    return ctx->name.c_str();
-}
-
-static ggml_backend_buffer_t ggml_backend_kompute_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    ggml_backend_kompute_device_ref(buft);
-    auto * ctx = new ggml_vk_memory(ggml_vk_allocate(size));
-    return ggml_backend_buffer_init(buft, ggml_backend_kompute_buffer_i, ctx, size);
-}
-
-static size_t ggml_backend_kompute_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
-    return ctx->buffer_alignment;
-}
-
-static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
-    return ctx->max_alloc;
-}
-
-static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
-    /* .get_name         = */ ggml_backend_kompute_buffer_type_get_name,
-    /* .alloc_buffer     = */ ggml_backend_kompute_buffer_type_alloc_buffer,
-    /* .get_alignment    = */ ggml_backend_kompute_buffer_type_get_alignment,
-    /* .get_max_size     = */ ggml_backend_vk_buffer_type_get_max_size,
-    /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
-    /* .is_host          = */ NULL,
-};
-
-ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
-    static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
-
-    auto devices = ggml_vk_available_devices();
-    int32_t device_count = (int32_t) devices.size();
-    GGML_ASSERT(device < device_count);
-    GGML_ASSERT(devices.size() <= GGML_KOMPUTE_MAX_DEVICES);
-
-    static ggml_backend_buffer_type
-        ggml_backend_kompute_buffer_types[GGML_KOMPUTE_MAX_DEVICES];
-
-    static bool ggml_backend_kompute_buffer_type_initialized = false;
-
-    if (!ggml_backend_kompute_buffer_type_initialized) {
-        for (int32_t i = 0; i < device_count; i++) {
-            ggml_backend_kompute_buffer_types[i] = {
-                /* .iface    = */ ggml_backend_kompute_buffer_type_interface,
-                /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), i),
-                /* .context  = */ new ggml_backend_kompute_buffer_type_context{ i, devices[i].bufferAlignment, devices[i].maxAlloc },
-            };
-        }
-        ggml_backend_kompute_buffer_type_initialized = true;
-    }
-
-    return &ggml_backend_kompute_buffer_types[device];
-}
-
-// backend
-
-static const char * ggml_backend_kompute_name(ggml_backend_t backend) {
-    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
-    return ctx->name.c_str();
-}
-
-static void ggml_backend_kompute_free(ggml_backend_t backend) {
-    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
-
-    assert(ctx == s_kompute_context);
-    s_kompute_context = nullptr;
-    if (ctx != nullptr) {
-        delete ctx;
-    }
-
-    delete backend;
-}
-
-static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
-    ggml_vk_graph_compute(ctx, cgraph);
-    return GGML_STATUS_SUCCESS;
-}
-
-static struct ggml_backend_i kompute_backend_i = {
-    /* .get_name                = */ ggml_backend_kompute_name,
-    /* .free                    = */ ggml_backend_kompute_free,
-    /* .set_tensor_async        = */ NULL,
-    /* .get_tensor_async        = */ NULL,
-    /* .cpy_tensor_async        = */ NULL,
-    /* .synchronize             = */ NULL,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_kompute_graph_compute,
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_kompute_guid() {
-    static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
-    return &guid;
-}
-
-ggml_backend_t ggml_backend_kompute_init(int device) {
-    GGML_ASSERT(s_kompute_context == nullptr);
-    s_kompute_context = new ggml_kompute_context(device);
-
-    ggml_backend_t kompute_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_kompute_guid(),
-        /* .interface = */ kompute_backend_i,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), device),
-        /* .context   = */ s_kompute_context,
-    };
-
-    return kompute_backend;
-}
-
-bool ggml_backend_is_kompute(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
-}
-
-static size_t ggml_backend_kompute_get_device_count() {
-    auto devices = ggml_vk_available_devices();
-    return devices.size();
-}
-
-static void ggml_backend_kompute_get_device_description(int device, char * description, size_t description_size) {
-    auto devices = ggml_vk_available_devices();
-    GGML_ASSERT((size_t) device < devices.size());
-    snprintf(description, description_size, "%s", devices[device].name);
-}
-
-static void ggml_backend_kompute_get_device_memory(int device, size_t * free, size_t * total) {
-    auto devices = ggml_vk_available_devices();
-    GGML_ASSERT((size_t) device < devices.size());
-    *total = devices[device].heapSize;
-    *free = devices[device].heapSize;
-}
-
-//////////////////////////
-
-struct ggml_backend_kompute_device_context {
-    int device;
-    std::string name;
-    std::string description;
-};
-
-static const char * ggml_backend_kompute_device_get_name(ggml_backend_dev_t dev) {
-    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
-    return ctx->name.c_str();
-}
-
-static const char * ggml_backend_kompute_device_get_description(ggml_backend_dev_t dev) {
-    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
-    return ctx->description.c_str();
-}
-
-static void ggml_backend_kompute_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
-    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
-    ggml_backend_kompute_get_device_memory(ctx->device, free, total);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type(ggml_backend_dev_t dev) {
-    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
-    return ggml_backend_kompute_buffer_type(ctx->device);
-}
-
-static bool ggml_backend_kompute_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    if (buft->iface.get_name != ggml_backend_kompute_buffer_type_get_name) {
-        return false;
-    }
-
-    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
-    ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context;
-
-    return buft_ctx->device == ctx->device;
-}
-
-static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) {
-    GGML_UNUSED(dev);
-    return GGML_BACKEND_DEVICE_TYPE_GPU;
-}
-
-static void ggml_backend_kompute_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
-    props->name        = ggml_backend_kompute_device_get_name(dev);
-    props->description = ggml_backend_kompute_device_get_description(dev);
-    props->type        = ggml_backend_kompute_device_get_type(dev);
-    ggml_backend_kompute_device_get_memory(dev, &props->memory_free, &props->memory_total);
-    props->caps = {
-        /* async                  = */ false,
-        /* host_buffer            = */ false,
-        /* .buffer_from_host_ptr  = */ false,
-        /* events                 = */ false,
-    };
-}
-
-static ggml_backend_t ggml_backend_kompute_device_init(ggml_backend_dev_t dev, const char * params) {
-    GGML_UNUSED(params);
-    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
-    return ggml_backend_kompute_init(ctx->device);
-}
-
-static bool ggml_backend_kompute_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
-    const int min_batch_size = 32;
-
-    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
-           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
-
-    GGML_UNUSED(dev);
-}
-
-static const struct ggml_backend_device_i ggml_backend_kompute_device_i = {
-    /* .get_name             = */ ggml_backend_kompute_device_get_name,
-    /* .get_description      = */ ggml_backend_kompute_device_get_description,
-    /* .get_memory           = */ ggml_backend_kompute_device_get_memory,
-    /* .get_type             = */ ggml_backend_kompute_device_get_type,
-    /* .get_props            = */ ggml_backend_kompute_device_get_props,
-    /* .init_backend         = */ ggml_backend_kompute_device_init,
-    /* .get_buffer_type      = */ ggml_backend_kompute_device_get_buffer_type,
-    /* .get_host_buffer_type = */ NULL,
-    /* .buffer_from_host_ptr = */ NULL,
-    /* .supports_op          = */ ggml_backend_kompute_device_supports_op,
-    /* .supports_buft        = */ ggml_backend_kompute_device_supports_buft,
-    /* .offload_op           = */ ggml_backend_kompute_device_offload_op,
-    /* .event_new            = */ NULL,
-    /* .event_free           = */ NULL,
-    /* .event_synchronize    = */ NULL,
-};
-
-static const char * ggml_backend_kompute_reg_get_name(ggml_backend_reg_t reg) {
-    GGML_UNUSED(reg);
-    return "Kompute";
-}
-
-static size_t ggml_backend_kompute_reg_get_device_count(ggml_backend_reg_t reg) {
-    GGML_UNUSED(reg);
-    return ggml_backend_kompute_get_device_count();
-}
-
-static ggml_backend_dev_t ggml_backend_kompute_reg_get_device(ggml_backend_reg_t reg, size_t device) {
-    static std::vector<ggml_backend_dev_t> devices;
-
-    static bool initialized = false;
-
-    {
-        static std::mutex mutex;
-        std::lock_guard<std::mutex> lock(mutex);
-        if (!initialized) {
-            for (size_t i = 0; i < ggml_backend_kompute_get_device_count(); i++) {
-                ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context;
-                char desc[256];
-                ggml_backend_kompute_get_device_description(i, desc, sizeof(desc));
-                ctx->device = i;
-                ctx->name = "Kompute" + std::to_string(i);
-                ctx->description = desc;
-                devices.push_back(new ggml_backend_device {
-                    /* .iface   = */ ggml_backend_kompute_device_i,
-                    /* .reg     = */ reg,
-                    /* .context = */ ctx,
-                });
-            }
-            initialized = true;
-        }
-    }
-
-    GGML_ASSERT(device < devices.size());
-    return devices[device];
-}
-
-static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
-    /* .get_name         = */ ggml_backend_kompute_reg_get_name,
-    /* .get_device_count = */ ggml_backend_kompute_reg_get_device_count,
-    /* .get_device       = */ ggml_backend_kompute_reg_get_device,
-    /* .get_proc_address = */ NULL,
-};
-
-ggml_backend_reg_t ggml_backend_kompute_reg() {
-    static ggml_backend_reg reg = {
-        /* .iface   = */ ggml_backend_kompute_reg_i,
-        /* .context = */ nullptr,
-    };
-
-    return &reg;
-}
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
deleted file mode 100644 (file)
index 04ec511..0000000
+++ /dev/null
@@ -1,4290 +0,0 @@
-#import "ggml-metal.h"
-
-#import "ggml-impl.h"
-#import "ggml-backend-impl.h"
-
-#import <Foundation/Foundation.h>
-
-#import <Metal/Metal.h>
-
-#undef MIN
-#undef MAX
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-
-// max memory buffers that can be mapped to the device
-#define GGML_METAL_MAX_BUFFERS 64
-
-// max number of MTLCommandBuffer used to submit a graph for processing
-#define GGML_METAL_MAX_COMMAND_BUFFERS 8
-
-#define UNUSED(x) (void)(x)
-
-// globals
-
-// overload of MTLGPUFamilyMetal3 (not available in some environments)
-static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
-
-// initialized in ggml_backend_metal_reg
-static struct ggml_backend_reg    g_ggml_backend_metal_reg;
-static struct ggml_backend_device g_ggml_backend_metal_device;
-
-// information about a Metal device
-// note: assumes single GPU device - the default one
-// TODO: support multiple GPU devices
-static struct ggml_backend_metal_device_context {
-    id<MTLDevice> mtl_device;
-    int           mtl_device_ref_count;
-
-    bool has_simdgroup_reduction;
-    bool has_simdgroup_mm;
-    bool has_bfloat;
-    bool use_bfloat;
-
-    char name[128];
-} g_ggml_ctx_dev_main = {
-    /*.mtl_device              =*/ nil,
-    /*.mtl_device_ref_count    =*/ 0,
-    /*.has_simdgroup_reduction =*/ false,
-    /*.has_simdgroup_mm        =*/ false,
-    /*.has_bfloat              =*/ false,
-    /*.use_bfloat              =*/ false,
-    /*.name                    =*/ "",
-};
-
-// acquire
-static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
-    assert(ctx != NULL);
-
-    if (ctx->mtl_device == nil) {
-        ctx->mtl_device = MTLCreateSystemDefaultDevice();
-
-        ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
-        ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
-
-        ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
-
-        ctx->has_bfloat  = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
-        ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
-
-#if defined(GGML_METAL_USE_BF16)
-        ctx->use_bfloat = ctx->has_bfloat;
-#else
-        ctx->use_bfloat = false;
-#endif
-
-        strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
-    }
-
-    ctx->mtl_device_ref_count++;
-
-    return ctx->mtl_device;
-}
-
-// release
-static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) {
-    assert(ctx != NULL);
-    assert(ctx->mtl_device_ref_count > 0);
-
-    ctx->mtl_device_ref_count--;
-
-    if (ctx->mtl_device_ref_count == 0) {
-        [ctx->mtl_device release];
-        ctx->mtl_device = nil;
-    }
-}
-
-// kernels
-
-struct ggml_metal_kernel {
-    id<MTLComputePipelineState> pipeline;
-};
-
-enum ggml_metal_kernel_type {
-    GGML_METAL_KERNEL_TYPE_ADD,
-    GGML_METAL_KERNEL_TYPE_ADD_ROW,
-    GGML_METAL_KERNEL_TYPE_SUB,
-    GGML_METAL_KERNEL_TYPE_SUB_ROW,
-    GGML_METAL_KERNEL_TYPE_MUL,
-    GGML_METAL_KERNEL_TYPE_MUL_ROW,
-    GGML_METAL_KERNEL_TYPE_DIV,
-    GGML_METAL_KERNEL_TYPE_DIV_ROW,
-    GGML_METAL_KERNEL_TYPE_REPEAT_F32,
-    GGML_METAL_KERNEL_TYPE_REPEAT_F16,
-    GGML_METAL_KERNEL_TYPE_REPEAT_I32,
-    GGML_METAL_KERNEL_TYPE_REPEAT_I16,
-    GGML_METAL_KERNEL_TYPE_SCALE,
-    GGML_METAL_KERNEL_TYPE_SCALE_4,
-    GGML_METAL_KERNEL_TYPE_CLAMP,
-    GGML_METAL_KERNEL_TYPE_TANH,
-    GGML_METAL_KERNEL_TYPE_RELU,
-    GGML_METAL_KERNEL_TYPE_SIGMOID,
-    GGML_METAL_KERNEL_TYPE_GELU,
-    GGML_METAL_KERNEL_TYPE_GELU_4,
-    GGML_METAL_KERNEL_TYPE_GELU_QUICK,
-    GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
-    GGML_METAL_KERNEL_TYPE_SILU,
-    GGML_METAL_KERNEL_TYPE_SILU_4,
-    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
-    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
-    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
-    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
-    GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
-    GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
-    GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
-    GGML_METAL_KERNEL_TYPE_RMS_NORM,
-    GGML_METAL_KERNEL_TYPE_GROUP_NORM,
-    GGML_METAL_KERNEL_TYPE_NORM,
-    GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
-    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
-  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
-  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
-  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
-    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
-    GGML_METAL_KERNEL_TYPE_IM2COL_F16,
-    GGML_METAL_KERNEL_TYPE_IM2COL_F32,
-    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
-    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
-    GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
-    GGML_METAL_KERNEL_TYPE_PAD_F32,
-    GGML_METAL_KERNEL_TYPE_ARANGE_F32,
-    GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
-    GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
-    GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
-    GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
-    GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
-    GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
-    GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
-    GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
-    GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
-    GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
-    GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
-    GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
-    GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
-    GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
-    GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
-    GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
-    GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
-    GGML_METAL_KERNEL_TYPE_CONCAT,
-    GGML_METAL_KERNEL_TYPE_SQR,
-    GGML_METAL_KERNEL_TYPE_SQRT,
-    GGML_METAL_KERNEL_TYPE_SIN,
-    GGML_METAL_KERNEL_TYPE_COS,
-    GGML_METAL_KERNEL_TYPE_SUM_ROWS,
-    GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
-    GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
-
-    GGML_METAL_KERNEL_TYPE_COUNT
-};
-
-struct ggml_backend_metal_context {
-    id<MTLCommandQueue> queue;
-
-    dispatch_queue_t d_queue;
-
-    struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
-
-    // capture state
-    bool capture_next_compute;
-    bool capture_started;
-
-    id<MTLCaptureScope> capture_scope;
-
-    // command buffer state
-    int n_cb;           // number of extra threads used to submit the command buffers
-    int n_nodes_0;      // number of nodes submitted by the main thread
-    int n_nodes_1;      // remaining number of nodes submitted by the n_cb threads
-    int n_nodes_per_cb;
-
-    struct ggml_cgraph * gf;
-
-    // the callback given to the thread pool
-    void (^encode_async)(size_t ith);
-
-    // n_cb command buffers + 1 used by the main thread
-    id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
-
-    // abort ggml_metal_graph_compute if callback returns true
-    ggml_abort_callback abort_callback;
-    void *              abort_callback_data;
-};
-
-// MSL code
-// TODO: move the contents here when ready
-//       for now it is easier to work in a separate file
-// static NSString * const msl_library_source = @"see metal.metal";
-
-// Here to assist with NSBundle Path Hack
-@interface GGMLMetalClass : NSObject
-@end
-@implementation GGMLMetalClass
-@end
-
-static void * ggml_metal_host_malloc(size_t n) {
-    void * data = NULL;
-
-#if TARGET_OS_OSX
-    kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);
-    if (err != KERN_SUCCESS) {
-        GGML_LOG_ERROR("%s: error: vm_allocate failed\n", __func__);
-        return NULL;
-    }
-#else
-    const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
-    if (result != 0) {
-        GGML_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
-        return NULL;
-    }
-#endif
-
-    return data;
-}
-
-static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
-    GGML_LOG_INFO("%s: allocating\n", __func__);
-
-#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
-    // Show all the Metal device instances in the system
-    NSArray * devices = MTLCopyAllDevices();
-    for (id<MTLDevice> device in devices) {
-        GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
-    }
-    [devices release]; // since it was created by a *Copy* C method
-#endif
-
-    // init context
-    struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
-    struct ggml_backend_metal_device_context * ctx_dev = dev->context;
-
-    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
-    GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
-
-    ctx->queue  = [device newCommandQueue];
-    ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
-
-    id<MTLLibrary> metal_library;
-
-    // load library
-    //
-    // - first check if the library is embedded
-    // - then check if the library is in the bundle
-    // - if not found, load the source and compile it
-    // - if that fails, return NULL
-    {
-        NSBundle * bundle = nil;
-#ifdef SWIFT_PACKAGE
-        bundle = SWIFTPM_MODULE_BUNDLE;
-#else
-        bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
-#endif
-
-        NSError * error = nil;
-
-#if GGML_METAL_EMBED_LIBRARY
-        const bool try_metallib = false;
-#else
-        const bool try_metallib = true;
-#endif
-
-        NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
-        if (try_metallib && path_lib != nil) {
-            // pre-compiled library found
-            NSURL * libURL = [NSURL fileURLWithPath:path_lib];
-            GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
-
-            metal_library = [device newLibraryWithURL:libURL error:&error];
-            if (error) {
-                GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
-                return NULL;
-            }
-        } else {
-#if GGML_METAL_EMBED_LIBRARY
-            GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
-
-            extern const char ggml_metallib_start[];
-            extern const char ggml_metallib_end[];
-
-            NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
-#else
-            GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
-
-            NSString * path_source;
-            NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
-
-            GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
-
-            if (path_resource) {
-                path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
-            } else {
-                path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
-            }
-
-            if (path_source == nil) {
-                GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
-                path_source = @"ggml-metal.metal";
-            }
-
-            GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
-
-            NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
-            if (error) {
-                GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
-                return NULL;
-            }
-#endif // GGML_METAL_EMBED_LIBRARY
-
-            @autoreleasepool {
-                // dictionary of preprocessor macros
-                NSMutableDictionary * prep = [NSMutableDictionary dictionary];
-
-                if (ctx_dev->use_bfloat) {
-                    [prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
-                }
-
-                MTLCompileOptions * options = [MTLCompileOptions new];
-                options.preprocessorMacros = prep;
-
-                //[options setFastMathEnabled:false];
-
-                metal_library = [device newLibraryWithSource:src options:options error:&error];
-                if (error) {
-                    GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
-                    return NULL;
-                }
-
-#if !__has_feature(objc_arc)
-                [options release];
-#endif
-            }
-#if GGML_METAL_EMBED_LIBRARY
-            [src release];
-#endif // GGML_METAL_EMBED_LIBRARY
-        }
-    }
-
-    // print MTL GPU family:
-    GGML_LOG_INFO("%s: GPU name:   %s\n", __func__, [[device name] UTF8String]);
-
-    // determine max supported GPU family
-    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
-    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
-    {
-        for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
-            if ([device supportsFamily:i]) {
-                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d  (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
-                break;
-            }
-        }
-
-        for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
-            if ([device supportsFamily:i]) {
-                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
-                break;
-            }
-        }
-
-        for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) {
-            if ([device supportsFamily:i]) {
-                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d  (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i);
-                break;
-            }
-        }
-    }
-
-    GGML_LOG_INFO("%s: simdgroup reduction   = %s\n", __func__, ctx_dev->has_simdgroup_reduction     ? "true" : "false");
-    GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm            ? "true" : "false");
-    GGML_LOG_INFO("%s: has bfloat            = %s\n", __func__, ctx_dev->has_bfloat                  ? "true" : "false");
-    GGML_LOG_INFO("%s: use bfloat            = %s\n", __func__, ctx_dev->use_bfloat                  ? "true" : "false");
-    GGML_LOG_INFO("%s: hasUnifiedMemory      = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
-
-    ctx->capture_next_compute = false;
-    ctx->capture_started = false;
-    ctx->capture_scope = nil;
-
-    ctx->gf = nil;
-    ctx->encode_async = nil;
-    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
-        ctx->command_buffers[i] = nil;
-    }
-
-#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
-    if (@available(macOS 10.12, iOS 16.0, *)) {
-        GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
-    }
-#endif
-
-    // load kernels
-    {
-        NSError * error = nil;
-
-        for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
-            ctx->kernels[i].pipeline = nil;
-        }
-
-#define GGML_METAL_ADD_KERNEL(e, name, supported) \
-        if (supported) { \
-            struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
-            id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
-            kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
-            GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
-                    (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
-                    (int) kernel->pipeline.threadExecutionWidth); \
-            [metal_function release]; \
-            if (error) { \
-                GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
-                [metal_library release]; \
-                return NULL; \
-            } \
-        } else { \
-            GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
-        }
-
-        const bool has_simdgroup_mm        = ctx_dev->has_simdgroup_mm;
-        const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
-        const bool use_bfloat              = ctx_dev->use_bfloat;
-
-        // simd_sum and simd_max requires MTLGPUFamilyApple7
-
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                           add,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                       add_row,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,                           sub,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW,                       sub_row,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                           mul,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                       mul_row,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                           div,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                       div_row,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32,                    repeat_f32,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16,                    repeat_f16,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32,                    repeat_i32,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16,                    repeat_i16,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                         scale,                          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                       scale_4,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,                         clamp,                          true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                          tanh,                           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                          relu,                           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID,                       sigmoid,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                          gelu,                           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,                        gelu_4,                         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,                    gelu_quick,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,                  gelu_quick_4,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                          silu,                           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                        silu_4,                         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                  soft_max_f16,                   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                soft_max_f16_4,                 has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                  soft_max_f32,                   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,                soft_max_f32_4,                 has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,                 diag_mask_inf,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,               diag_mask_inf_8,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,                  get_rows_f32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,                  get_rows_f16,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,                 get_rows_bf16,                  use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,                 get_rows_q4_0,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,                 get_rows_q4_1,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,                 get_rows_q5_0,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,                 get_rows_q5_1,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,                 get_rows_q8_0,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,                 get_rows_q2_K,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,                 get_rows_q3_K,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,                 get_rows_q4_K,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,                 get_rows_q5_K,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,                 get_rows_q6_K,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,              get_rows_iq2_xxs,               true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,               get_rows_iq2_xs,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,              get_rows_iq3_xxs,               true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,                get_rows_iq3_s,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,                get_rows_iq2_s,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,                get_rows_iq1_s,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,                get_rows_iq1_m,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,               get_rows_iq4_nl,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,               get_rows_iq4_xs,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,                  get_rows_i32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                          norm,                           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                  ssm_conv_f32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                  ssm_scan_f32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,               mul_mv_bf16_f32,                has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,          mul_mv_bf16_f32_1row,           has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,            mul_mv_bf16_f32_l4,             has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,              mul_mv_bf16_bf16,               has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,           mul_mv_f16_f32_1row,            has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,             mul_mv_f16_f32_l4,              has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                mul_mv_f16_f16,                 has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,               mul_mv_q4_0_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,               mul_mv_q4_1_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,               mul_mv_q5_0_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,               mul_mv_q5_1_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,               mul_mv_q8_0_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,               mul_mv_q2_K_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,               mul_mv_q3_K_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,               mul_mv_q4_K_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,               mul_mv_q5_K_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,               mul_mv_q6_K_f32,                has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,            mul_mv_iq2_xxs_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,             mul_mv_iq2_xs_f32,              has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,            mul_mv_iq3_xxs_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,              mul_mv_iq3_s_f32,               has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,              mul_mv_iq2_s_f32,               has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,              mul_mv_iq1_s_f32,               has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,              mul_mv_iq1_m_f32,               has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,             mul_mv_iq4_nl_f32,              has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,             mul_mv_iq4_xs_f32,              has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,             mul_mv_id_f32_f32,              has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,             mul_mv_id_f16_f32,              has_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,        mul_mv_id_f16_f32_1row,         has_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,          mul_mv_id_f16_f32_l4,           has_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,             mul_mv_id_f16_f16,              has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,            mul_mv_id_bf16_f32,             has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,            mul_mv_id_q4_0_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,            mul_mv_id_q4_1_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,            mul_mv_id_q5_0_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,            mul_mv_id_q5_1_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,            mul_mv_id_q8_0_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,            mul_mv_id_q2_K_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,            mul_mv_id_q3_K_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,            mul_mv_id_q4_K_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,            mul_mv_id_q5_K_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,            mul_mv_id_q6_K_f32,             has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,         mul_mv_id_iq2_xxs_f32,          has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,          mul_mv_id_iq2_xs_f32,           has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,         mul_mv_id_iq3_xxs_f32,          has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,           mul_mv_id_iq3_s_f32,            has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,           mul_mv_id_iq2_s_f32,            has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,           mul_mv_id_iq1_s_f32,            has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,           mul_mv_id_iq1_m_f32,            has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,          mul_mv_id_iq4_nl_f32,           has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,          mul_mv_id_iq4_xs_f32,           has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,                mul_mm_f32_f32,                 has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,                mul_mm_f16_f32,                 has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,               mul_mm_bf16_f32,                has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,               mul_mm_q4_0_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,               mul_mm_q4_1_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,               mul_mm_q5_0_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,               mul_mm_q5_1_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,               mul_mm_q8_0_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,               mul_mm_q2_K_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,               mul_mm_q3_K_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,               mul_mm_q4_K_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,               mul_mm_q5_K_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,               mul_mm_q6_K_f32,                has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,            mul_mm_iq2_xxs_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,             mul_mm_iq2_xs_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,            mul_mm_iq3_xxs_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,              mul_mm_iq3_s_f32,               has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,              mul_mm_iq2_s_f32,               has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,              mul_mm_iq1_s_f32,               has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,              mul_mm_iq1_m_f32,               has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,             mul_mm_iq4_nl_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,             mul_mm_iq4_xs_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,             mul_mm_id_f32_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,             mul_mm_id_f16_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,            mul_mm_id_bf16_f32,             has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,            mul_mm_id_q4_0_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,            mul_mm_id_q4_1_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,            mul_mm_id_q5_0_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,            mul_mm_id_q5_1_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,            mul_mm_id_q8_0_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,            mul_mm_id_q2_K_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,            mul_mm_id_q3_K_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,            mul_mm_id_q4_K_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,            mul_mm_id_q5_K_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,            mul_mm_id_q6_K_f32,             has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,         mul_mm_id_iq2_xxs_f32,          has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,          mul_mm_id_iq2_xs_f32,           has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,         mul_mm_id_iq3_xxs_f32,          has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,           mul_mm_id_iq3_s_f32,            has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,           mul_mm_id_iq2_s_f32,            has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,           mul_mm_id_iq1_s_f32,            has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,           mul_mm_id_iq1_m_f32,            has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,          mul_mm_id_iq4_nl_f32,           has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,                 rope_norm_f32,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                 rope_norm_f16,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                 rope_neox_f32,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                 rope_neox_f16,                  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,                im2col_ext_f16,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,                im2col_ext_f32,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                    arange_f32,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,          argsort_f32_i32_desc,           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,                leaky_relu_f32,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,        flash_attn_ext_f16_h64,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,        flash_attn_ext_f16_h80,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,       flash_attn_ext_bf16_h64,        has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,       flash_attn_ext_bf16_h80,        has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,       flash_attn_ext_bf16_h96,        has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,      flash_attn_ext_bf16_h112,       has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,      flash_attn_ext_bf16_h128,       has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,      flash_attn_ext_bf16_h256,       has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,       flash_attn_ext_q4_0_h64,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,       flash_attn_ext_q4_0_h80,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,       flash_attn_ext_q4_0_h96,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,      flash_attn_ext_q4_0_h112,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,      flash_attn_ext_q4_0_h128,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,      flash_attn_ext_q4_0_h256,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,       flash_attn_ext_q4_1_h64,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,       flash_attn_ext_q4_1_h80,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,       flash_attn_ext_q4_1_h96,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,      flash_attn_ext_q4_1_h112,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,      flash_attn_ext_q4_1_h128,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,      flash_attn_ext_q4_1_h256,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,       flash_attn_ext_q5_0_h64,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,       flash_attn_ext_q5_0_h80,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,       flash_attn_ext_q5_0_h96,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,      flash_attn_ext_q5_0_h112,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,      flash_attn_ext_q5_0_h128,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,      flash_attn_ext_q5_0_h256,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,       flash_attn_ext_q5_1_h64,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,       flash_attn_ext_q5_1_h80,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,       flash_attn_ext_q5_1_h96,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,      flash_attn_ext_q5_1_h112,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,      flash_attn_ext_q5_1_h128,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,      flash_attn_ext_q5_1_h256,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,       flash_attn_ext_q8_0_h64,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,       flash_attn_ext_q8_0_h80,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,       flash_attn_ext_q8_0_h96,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,      flash_attn_ext_q8_0_h112,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,      flash_attn_ext_q8_0_h128,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,      flash_attn_ext_q8_0_h256,       has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,  flash_attn_ext_vec_bf16_h128,   has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,  flash_attn_ext_vec_q4_0_h128,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,  flash_attn_ext_vec_q4_1_h128,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,  flash_attn_ext_vec_q5_0_h128,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,  flash_attn_ext_vec_q5_1_h128,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,  flash_attn_ext_vec_q8_0_h128,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,  flash_attn_ext_vec_bf16_h256,   has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,  flash_attn_ext_vec_q4_0_h256,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,  flash_attn_ext_vec_q4_1_h256,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,  flash_attn_ext_vec_q5_0_h256,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,  flash_attn_ext_vec_q5_1_h256,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,  flash_attn_ext_vec_q8_0_h256,   has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,                  cpy_f32_bf16,                   use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,                   cpy_f16_f32,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,                  cpy_bf16_f32,                   use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,                 cpy_bf16_bf16,                  use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                  cpy_f32_q8_0,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,                  cpy_f32_q4_0,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,                  cpy_f32_q4_1,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,                  cpy_f32_q5_0,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,                  cpy_f32_q5_1,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                          sqrt,                           true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                           sin,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                           cos,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,               pool_2d_avg_f32,                true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,               pool_2d_max_f32,                true);
-    }
-
-    [metal_library release];
-
-    return ctx;
-}
-
-static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
-    GGML_LOG_INFO("%s: deallocating\n", __func__);
-
-    for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
-        [ctx->kernels[i].pipeline release];
-    }
-
-    Block_release(ctx->encode_async);
-
-    [ctx->queue release];
-
-    dispatch_release(ctx->d_queue);
-
-    free(ctx);
-}
-
-// temporarily defined here for compatibility between ggml-backend and the old API
-
-struct ggml_backend_metal_buffer {
-    void   * data;
-    size_t   size;
-
-    id<MTLBuffer> metal;
-};
-
-struct ggml_backend_metal_buffer_context {
-    void * all_data;
-    size_t all_size;
-    bool owned;
-
-    // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
-    int n_buffers;
-    struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
-};
-
-// finds the Metal buffer that contains the tensor data on the GPU device
-// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
-// Metal buffer based on the host memory pointer
-//
-static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) {
-    //GGML_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
-
-    const int64_t tsize = ggml_nbytes(t);
-
-    ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
-
-    struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
-
-    // find the view that contains the tensor fully
-    for (int i = 0; i < buf_ctx->n_buffers; ++i) {
-        const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
-
-        //GGML_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
-        if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
-            *offs = (size_t) ioffs;
-
-            //GGML_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
-
-            return buf_ctx->buffers[i].metal;
-        }
-    }
-
-    GGML_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
-
-    return nil;
-}
-
-static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
-    const bool has_simdgroup_mm        = ctx_dev->has_simdgroup_mm;
-    const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
-    const bool use_bfloat              = ctx_dev->use_bfloat;
-
-    if (!use_bfloat) {
-        for (size_t i = 0, n = 3; i < n; ++i) {
-            if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
-                return false;
-            }
-        }
-    }
-
-    switch (op->op) {
-        case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_TANH:
-                case GGML_UNARY_OP_RELU:
-                case GGML_UNARY_OP_SIGMOID:
-                case GGML_UNARY_OP_GELU:
-                case GGML_UNARY_OP_GELU_QUICK:
-                case GGML_UNARY_OP_SILU:
-                    return ggml_is_contiguous(op->src[0]);
-                default:
-                    return false;
-            }
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_TRANSPOSE:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_CONCAT:
-        case GGML_OP_ADD:
-        case GGML_OP_SUB:
-        case GGML_OP_ACC:
-        case GGML_OP_MUL:
-        case GGML_OP_DIV:
-        case GGML_OP_REPEAT:
-        case GGML_OP_SCALE:
-        case GGML_OP_CLAMP:
-            return true;
-        case GGML_OP_SQR:
-        case GGML_OP_SQRT:
-        case GGML_OP_SIN:
-        case GGML_OP_COS:
-            return ggml_is_contiguous(op->src[0]);
-        case GGML_OP_SUM_ROWS:
-        case GGML_OP_SOFT_MAX:
-        case GGML_OP_RMS_NORM:
-        case GGML_OP_GROUP_NORM:
-            return has_simdgroup_reduction;
-        case GGML_OP_NORM:
-        case GGML_OP_ROPE:
-            return true;
-        case GGML_OP_IM2COL:
-            return op->src[0]->type == GGML_TYPE_F16;
-        case GGML_OP_POOL_1D:
-            return false;
-        case GGML_OP_POOL_2D:
-        case GGML_OP_UPSCALE:
-        case GGML_OP_PAD:
-        case GGML_OP_ARANGE:
-        case GGML_OP_TIMESTEP_EMBEDDING:
-        case GGML_OP_ARGSORT:
-        case GGML_OP_LEAKY_RELU:
-            return true;
-        case GGML_OP_FLASH_ATTN_EXT:
-            if (op->src[1]->type != op->src[2]->type) {
-                return false;
-            }
-            return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
-        case GGML_OP_SSM_CONV:
-        case GGML_OP_SSM_SCAN:
-            return true;
-        case GGML_OP_MUL_MAT:
-        case GGML_OP_MUL_MAT_ID:
-            return has_simdgroup_reduction &&
-                (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
-        case GGML_OP_CPY:
-        case GGML_OP_DUP:
-        case GGML_OP_CONT:
-            {
-                switch (op->src[0]->type) {
-                    case GGML_TYPE_F32:
-                        switch (op->type) {
-                           case GGML_TYPE_F32:
-                           case GGML_TYPE_F16:
-                           case GGML_TYPE_BF16:
-                           case GGML_TYPE_Q8_0:
-                           case GGML_TYPE_Q4_0:
-                           case GGML_TYPE_Q4_1:
-                           case GGML_TYPE_Q5_0:
-                           case GGML_TYPE_Q5_1:
-                           case GGML_TYPE_IQ4_NL:
-                                return true;
-                           default:
-                                return false;
-                        }
-                    case GGML_TYPE_F16:
-                        switch (op->type) {
-                            case GGML_TYPE_F32:
-                            case GGML_TYPE_F16:
-                                return true;
-                            default:
-                                return false;
-                        }
-                    case GGML_TYPE_BF16:
-                        switch (op->type) {
-                            case GGML_TYPE_F32:
-                            case GGML_TYPE_BF16:
-                                return true;
-                            default:
-                                return false;
-                        }
-                    default:
-                        return false;
-                };
-            }
-        case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_GET_ROWS:
-            {
-                return op->ne[3] == 1;
-            }
-        default:
-            return false;
-    }
-}
-
-static void ggml_metal_encode_node(
-                        ggml_backend_t   backend,
-                                   int   idx,
-          id<MTLComputeCommandEncoder>   encoder) {
-    struct ggml_backend_metal_context        * ctx     = backend->context;
-    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
-
-    struct ggml_cgraph * gf = ctx->gf;
-
-    struct ggml_tensor * node = ggml_graph_node(gf, idx);
-
-    //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
-
-    struct ggml_tensor * src0 = node->src[0];
-    struct ggml_tensor * src1 = node->src[1];
-    struct ggml_tensor * src2 = node->src[2];
-    struct ggml_tensor * dst  = node;
-
-    if (ggml_is_empty(dst)) {
-        return;
-    }
-
-    switch (dst->op) {
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_TRANSPOSE:
-        case GGML_OP_PERMUTE:
-            {
-                // noop -> next node
-            } return;
-        default:
-            {
-            } break;
-    }
-
-    if (!ggml_metal_supports_op(ctx_dev, dst)) {
-        GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
-        GGML_ABORT("unsupported op");
-    }
-
-    const int64_t  ne00 = src0 ? src0->ne[0] : 0;
-    const int64_t  ne01 = src0 ? src0->ne[1] : 0;
-    const int64_t  ne02 = src0 ? src0->ne[2] : 0;
-    const int64_t  ne03 = src0 ? src0->ne[3] : 0;
-
-    const uint64_t nb00 = src0 ? src0->nb[0] : 0;
-    const uint64_t nb01 = src0 ? src0->nb[1] : 0;
-    const uint64_t nb02 = src0 ? src0->nb[2] : 0;
-    const uint64_t nb03 = src0 ? src0->nb[3] : 0;
-
-    const int64_t  ne10 = src1 ? src1->ne[0] : 0;
-    const int64_t  ne11 = src1 ? src1->ne[1] : 0;
-    const int64_t  ne12 = src1 ? src1->ne[2] : 0;
-    const int64_t  ne13 = src1 ? src1->ne[3] : 0;
-
-    const uint64_t nb10 = src1 ? src1->nb[0] : 0;
-    const uint64_t nb11 = src1 ? src1->nb[1] : 0;
-    const uint64_t nb12 = src1 ? src1->nb[2] : 0;
-    const uint64_t nb13 = src1 ? src1->nb[3] : 0;
-
-    const int64_t  ne20 = src2 ? src2->ne[0] : 0;
-    const int64_t  ne21 = src2 ? src2->ne[1] : 0;
-    const int64_t  ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
-    const int64_t  ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
-
-    const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
-    const uint64_t nb21 = src2 ? src2->nb[1] : 0;
-    const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-    const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
-
-    const int64_t  ne0  =  dst ?  dst->ne[0] : 0;
-    const int64_t  ne1  =  dst ?  dst->ne[1] : 0;
-    const int64_t  ne2  =  dst ?  dst->ne[2] : 0;
-    const int64_t  ne3  =  dst ?  dst->ne[3] : 0;
-
-    const uint64_t nb0  =  dst ?  dst->nb[0] : 0;
-    const uint64_t nb1  =  dst ?  dst->nb[1] : 0;
-    const uint64_t nb2  =  dst ?  dst->nb[2] : 0;
-    const uint64_t nb3  =  dst ?  dst->nb[3] : 0;
-
-    const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
-    const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
-    const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
-
-    size_t offs_src0 = 0;
-    size_t offs_src1 = 0;
-    size_t offs_src2 = 0;
-    size_t offs_dst  = 0;
-
-    id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
-    id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
-    id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
-    id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
-
-#if 0
-    GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
-    if (src0) {
-        GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
-                ggml_is_contiguous(src0), src0->name);
-    }
-    if (src1) {
-        GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
-                ggml_is_contiguous(src1), src1->name);
-    }
-    if (dst) {
-        GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
-                dst->name);
-    }
-#endif
-
-    id<MTLDevice> device = ctx_dev->mtl_device;
-
-    switch (dst->op) {
-        case GGML_OP_CONCAT:
-            {
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
-
-                const int32_t dim = ((const int32_t *) dst->op_params)[0];
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                [encoder setBytes:&dim  length:sizeof(dim)  atIndex:27];
-
-                const int nth = MIN(1024, ne0);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
-        case GGML_OP_ADD:
-        case GGML_OP_SUB:
-        case GGML_OP_MUL:
-        case GGML_OP_DIV:
-            {
-                GGML_ASSERT(src0t == GGML_TYPE_F32);
-                GGML_ASSERT(src1t == GGML_TYPE_F32);
-
-                const size_t offs = 0;
-
-                bool bcast_row = false;
-
-                int64_t nb = ne00; // used by the "row" kernels
-
-                id<MTLComputePipelineState> pipeline = nil;
-
-                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-                    GGML_ASSERT(ggml_is_contiguous(src0));
-
-                    // src1 is a row
-                    GGML_ASSERT(ne11 == 1);
-
-                    nb = ne00 / 4;
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-
-                    bcast_row = true;
-                } else {
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-                }
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
-                [encoder setBytes:&nb   length:sizeof(nb)   atIndex:28];
-
-                if (bcast_row) {
-                    const int64_t n = ggml_nelements(dst)/4;
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } else {
-                    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                }
-            } break;
-        case GGML_OP_REPEAT:
-            {
-                id<MTLComputePipelineState> pipeline;
-
-                switch (src0t) {
-                    case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
-                    case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
-                    case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
-                    case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
-                    default: GGML_ABORT("fatal error");
-                }
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
-
-                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
-        case GGML_OP_ACC:
-            {
-                GGML_ASSERT(src0t == GGML_TYPE_F32);
-                GGML_ASSERT(src1t == GGML_TYPE_F32);
-                GGML_ASSERT(dstt  == GGML_TYPE_F32);
-
-                GGML_ASSERT(ggml_is_contiguous(src0));
-                GGML_ASSERT(ggml_is_contiguous(src1));
-
-                const size_t pnb1 = ((const int32_t *) dst->op_params)[0];
-                const size_t pnb2 = ((const int32_t *) dst->op_params)[1];
-                const size_t pnb3 = ((const int32_t *) dst->op_params)[2];
-                const size_t offs = ((const int32_t *) dst->op_params)[3];
-
-                const bool inplace = (bool) ((const int32_t *) dst->op_params)[4];
-
-                if (!inplace) {
-                    // run a separete kernel to cpy src->dst
-                    // not sure how to avoid this
-                    // TODO: make a simpler cpy_bytes kernel
-
-                    const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
-
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                    [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                    [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
-                    [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
-                    [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
-                    [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
-                    [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
-                    [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
-                    [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
-                    [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
-                    [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
-                    [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
-                    [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
-                    [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
-                    [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
-                    [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
-
-                    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                }
-
-                const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
-                [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
-                [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
-                [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
-                [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
-                [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
-                [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
-
-                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
-        case GGML_OP_SCALE:
-            {
-                GGML_ASSERT(ggml_is_contiguous(src0));
-
-                float scale;
-                memcpy(&scale, dst->op_params, sizeof(scale));
-
-                int64_t n = ggml_nelements(dst);
-
-                id<MTLComputePipelineState> pipeline = nil;
-
-                if (n % 4 == 0) {
-                    n /= 4;
-                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
-                } else {
-                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
-                }
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0   offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1];
-                [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-            } break;
-        case GGML_OP_CLAMP:
-            {
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
-
-                float min;
-                float max;
-                memcpy(&min, ((const int32_t *) dst->op_params) + 0, sizeof(float));
-                memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0   offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1];
-                [encoder setBytes:&min length:sizeof(min) atIndex:2];
-                [encoder setBytes:&max length:sizeof(max) atIndex:3];
-
-                const int64_t n = ggml_nelements(dst);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-            } break;
-        case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(node)) {
-                // we are not taking into account the strides, so for now require contiguous tensors
-                GGML_ASSERT(ggml_is_contiguous(src0));
-
-                case GGML_UNARY_OP_TANH:
-                {
-                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
-
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                    const int64_t n = ggml_nelements(dst);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-                case GGML_UNARY_OP_RELU:
-                {
-                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
-
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                    const int64_t n = ggml_nelements(dst);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-                case GGML_UNARY_OP_SIGMOID:
-                {
-                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
-
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                    const int64_t n = ggml_nelements(dst);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-                case GGML_UNARY_OP_GELU:
-                {
-                    int64_t n = ggml_nelements(dst);
-
-                    id<MTLComputePipelineState> pipeline = nil;
-
-                    if (n % 4 == 0) {
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
-                        n /= 4;
-                    } else {
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
-                    }
-
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-                case GGML_UNARY_OP_GELU_QUICK:
-                {
-                    int64_t n = ggml_nelements(dst);
-
-                    id<MTLComputePipelineState> pipeline = nil;
-
-                    if (n % 4 == 0) {
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
-                        n /= 4;
-                    } else {
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
-                    }
-
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-                case GGML_UNARY_OP_SILU:
-                {
-                    int64_t n = ggml_nelements(dst);
-
-                    id<MTLComputePipelineState> pipeline = nil;
-
-                    if (n % 4 == 0) {
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
-                        n /= 4;
-                    } else {
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
-                    }
-
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-                default:
-                {
-                    GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
-                    GGML_ABORT("fatal error");
-                }
-            } break;
-        case GGML_OP_SQR:
-            {
-                GGML_ASSERT(ggml_is_contiguous(src0));
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
-
-                const int64_t n = ggml_nelements(dst);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-            } break;
-        case GGML_OP_SQRT:
-            {
-                GGML_ASSERT(ggml_is_contiguous(src0));
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
-
-                const int64_t n = ggml_nelements(dst);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-            } break;
-        case GGML_OP_SIN:
-            {
-                GGML_ASSERT(ggml_is_contiguous(src0));
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
-
-                const int64_t n = ggml_nelements(dst);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-            } break;
-        case GGML_OP_COS:
-            {
-                GGML_ASSERT(ggml_is_contiguous(src0));
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
-
-                const int64_t n = ggml_nelements(dst);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-            } break;
-        case GGML_OP_SUM_ROWS:
-            {
-                GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
-                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:19];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:20];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:21];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:22];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:23];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:24];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:25];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-            } break;
-        case GGML_OP_SOFT_MAX:
-            {
-                GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
-
-                int nth = 32; // SIMD width
-
-                id<MTLComputePipelineState> pipeline = nil;
-
-                const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
-
-                if (ne00%4 == 0) {
-                    while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
-                        nth *= 2;
-                    }
-                    if (use_f16) {
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
-                    } else {
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
-                    }
-                } else {
-                    while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
-                        nth *= 2;
-                    }
-                    if (use_f16) {
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
-                    } else {
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
-                    }
-                }
-
-                float scale;
-                float max_bias;
-
-                memcpy(&scale,    ((const int32_t *) dst->op_params) + 0, sizeof(scale));
-                memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
-
-                const int64_t nrows_x = ggml_nrows(src0);
-                const int64_t nrows_y = src0->ne[1];
-
-                const uint32_t n_head      = nrows_x/nrows_y;
-                const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-                const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-                const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
-                if (id_src1) {
-                    [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
-                } else {
-                    [encoder setBuffer:id_src0 offset:offs_src0   atIndex:1];
-                }
-                [encoder setBuffer:id_dst      offset:offs_dst            atIndex:2];
-                [encoder setBytes:&ne00        length:sizeof(ne00)        atIndex:3];
-                [encoder setBytes:&ne01        length:sizeof(ne01)        atIndex:4];
-                [encoder setBytes:&ne02        length:sizeof(ne02)        atIndex:5];
-                [encoder setBytes:&scale       length:sizeof(scale)       atIndex:6];
-                [encoder setBytes:&max_bias    length:sizeof(max_bias)    atIndex:7];
-                [encoder setBytes:&m0          length:sizeof(m0)          atIndex:8];
-                [encoder setBytes:&m1          length:sizeof(m1)          atIndex:9];
-                [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
-                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
-        case GGML_OP_DIAG_MASK_INF:
-            {
-                const int n_past = ((const int32_t *)(dst->op_params))[0];
-
-                id<MTLComputePipelineState> pipeline = nil;
-
-                if (ne00%8 == 0) {
-                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
-                } else {
-                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
-                }
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01   length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&n_past length:sizeof(int)  atIndex:4];
-
-                if (ne00%8 == 0) {
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                }
-                else {
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                }
-            } break;
-        case GGML_OP_SSM_CONV:
-            {
-                GGML_ASSERT(src0t == GGML_TYPE_F32);
-                GGML_ASSERT(src1t == GGML_TYPE_F32);
-
-                GGML_ASSERT(ggml_is_contiguous(src0));
-                GGML_ASSERT(ggml_is_contiguous(src1));
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
-                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
-                [encoder setBytes:&ne01    length:sizeof(ne01) atIndex:4];
-                [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:5];
-                [encoder setBytes:&nb00    length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&ne10    length:sizeof(ne10) atIndex:9];
-                [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:10];
-                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:11];
-                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:12];
-                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
-                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
-                [encoder setBytes:&ne2     length:sizeof(ne2)  atIndex:15];
-                [encoder setBytes:&nb0     length:sizeof(nb0)  atIndex:16];
-                [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:17];
-                [encoder setBytes:&nb2     length:sizeof(nb2)  atIndex:18];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-            } break;
-        case GGML_OP_SSM_SCAN:
-            {
-                struct ggml_tensor * src3 = node->src[3];
-                struct ggml_tensor * src4 = node->src[4];
-                struct ggml_tensor * src5 = node->src[5];
-
-                GGML_ASSERT(src3);
-                GGML_ASSERT(src4);
-                GGML_ASSERT(src5);
-
-                size_t offs_src3 = 0;
-                size_t offs_src4 = 0;
-                size_t offs_src5 = 0;
-
-                id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
-                id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
-                id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
-
-                const int64_t  ne30 = src3->ne[0]; GGML_UNUSED(ne30);
-                const int64_t  ne31 = src3->ne[1]; GGML_UNUSED(ne31);
-
-                const uint64_t nb30 = src3->nb[0];
-                const uint64_t nb31 = src3->nb[1];
-
-                const int64_t  ne40 = src4->ne[0]; GGML_UNUSED(ne40);
-                const int64_t  ne41 = src4->ne[1]; GGML_UNUSED(ne41);
-                const int64_t  ne42 = src4->ne[2]; GGML_UNUSED(ne42);
-
-                const uint64_t nb40 = src4->nb[0];
-                const uint64_t nb41 = src4->nb[1];
-                const uint64_t nb42 = src4->nb[2];
-
-                const int64_t  ne50 = src5->ne[0]; GGML_UNUSED(ne50);
-                const int64_t  ne51 = src5->ne[1]; GGML_UNUSED(ne51);
-                const int64_t  ne52 = src5->ne[2]; GGML_UNUSED(ne52);
-
-                const uint64_t nb50 = src5->nb[0];
-                const uint64_t nb51 = src5->nb[1];
-                const uint64_t nb52 = src5->nb[2];
-
-                const int64_t d_state      = ne00;
-                const int64_t d_inner      = ne01;
-                const int64_t n_seq_tokens = ne11;
-                const int64_t n_seqs       = ne02;
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
-                [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
-                [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
-
-                [encoder setBytes:&d_state      length:sizeof(d_state)      atIndex:7];
-                [encoder setBytes:&d_inner      length:sizeof(d_inner)      atIndex:8];
-                [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
-                [encoder setBytes:&n_seqs       length:sizeof(n_seqs)       atIndex:10];
-
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
-                [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
-                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
-                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
-                [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
-                [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
-                [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
-                [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
-                [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
-                [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
-                [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
-                [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-            } break;
-        case GGML_OP_MUL_MAT:
-            {
-                GGML_ASSERT(ne00 == ne10);
-
-                GGML_ASSERT(ne12 % ne02 == 0);
-                GGML_ASSERT(ne13 % ne03 == 0);
-
-                const uint r2 = ne12/ne02;
-                const uint r3 = ne13/ne03;
-
-                // find the break-even point where the matrix-matrix kernel becomes more efficient compared
-                // to the matrix-vector kernel
-                int ne11_mm_min = 1;
-
-#if 0
-                // the numbers below are measured on M2 Ultra for 7B and 13B models
-                // these numbers do not translate to other devices or model sizes
-                // TODO: need to find a better approach
-                        if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
-                            switch (src0t) {
-                                case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
-                                case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
-                                case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
-                                case GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
-                                case GGML_TYPE_Q4_0:
-                                case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
-                                case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
-                                case GGML_TYPE_Q5_0:                          // not tested yet
-                                case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
-                                case GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
-                                case GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
-                                default:             ne11_mm_min = 1;  break;
-                            }
-                        }
-#endif
-
-                        // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
-                        // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                        if ([device supportsFamily:MTLGPUFamilyApple7] &&
-                                !ggml_is_transposed(src0) &&
-                                !ggml_is_transposed(src1) &&
-                                src1t == GGML_TYPE_F32 &&
-                                ne00 % 32 == 0 && ne00 >= 64 &&
-                                (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
-                            //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
-                            // some Metal matrix data types require aligned pointers
-                            // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
-                            switch (src0->type) {
-                                case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break;
-                                case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break;
-                                case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8  == 0); break;
-                                default: break;
-                            }
-
-                            id<MTLComputePipelineState> pipeline = nil;
-
-                            switch (src0->type) {
-                                case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32    ].pipeline; break;
-                                case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32    ].pipeline; break;
-                                case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32   ].pipeline; break;
-                                case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32   ].pipeline; break;
-                                case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32   ].pipeline; break;
-                                case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32   ].pipeline; break;
-                                case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32   ].pipeline; break;
-                                case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32   ].pipeline; break;
-                                case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32   ].pipeline; break;
-                                case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32   ].pipeline; break;
-                                case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32   ].pipeline; break;
-                                case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32   ].pipeline; break;
-                                case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32   ].pipeline; break;
-                                case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
-                                case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
-                                case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
-                                case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32  ].pipeline; break;
-                                case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32  ].pipeline; break;
-                                case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break;
-                                case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32  ].pipeline; break;
-                                case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
-                                case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
-                                default: GGML_ABORT("MUL MAT-MAT not implemented");
-                            }
-
-                            [encoder setComputePipelineState:pipeline];
-                            [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                            [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
-                            [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:4];
-                            [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5];
-                            [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6];
-                            [encoder setBytes:&nb03    length:sizeof(nb03) atIndex:7];
-                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:8];
-                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:9];
-                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:10];
-                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:11];
-                            [encoder setBytes:&nb13    length:sizeof(nb13) atIndex:12];
-                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
-                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
-                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:15];
-                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:16];
-                            [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                            [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
-                        } else {
-                            int nth0 = 32;
-                            int nth1 = 1;
-                            int nrows = 1;
-                            //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
-                            id<MTLComputePipelineState> pipeline = nil;
-
-                            // use custom matrix x vector kernel
-                            switch (src0t) {
-                                case GGML_TYPE_F32:
-                                    {
-                                        GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
-                                        nrows = 4;
-                                    } break;
-                                case GGML_TYPE_F16:
-                                    {
-                                        nth0 = 32;
-                                        nth1 = 1;
-                                        if (src1t == GGML_TYPE_F32) {
-                                            if (ne11 * ne12 < 4) {
-                                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
-                                            } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
-                                                nrows = ne11;
-                                            } else {
-                                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
-                                                nrows = 4;
-                                            }
-                                        } else {
-                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
-                                            nrows = 4;
-                                        }
-                                    } break;
-                                case GGML_TYPE_BF16:
-                                    {
-                                        nth0 = 32;
-                                        nth1 = 1;
-                                        if (src1t == GGML_TYPE_F32) {
-                                            if (ne11 * ne12 < 4) {
-                                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
-                                            } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
-                                                nrows = ne11;
-                                            } else {
-                                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
-                                                nrows = 4;
-                                            }
-                                        } else {
-                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
-                                            nrows = 4;
-                                        }
-                                    } break;
-                                case GGML_TYPE_Q4_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q4_1:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q5_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q5_1:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q8_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q2_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q3_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q4_K:
-                                    {
-                                        nth0 = 4; //1;
-                                        nth1 = 8; //32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q5_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q6_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ2_XXS:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ2_XS:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ3_XXS:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ3_S:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ2_S:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ1_S:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ1_M:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ4_NL:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ4_XS:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
-                                    } break;
-                                default:
-                                    {
-                                        GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
-                                        GGML_ABORT("not implemented");
-                                    }
-                            };
-
-                            [encoder setComputePipelineState:pipeline];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
-                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
-                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
-                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
-                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:18];
-                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:19];
-                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:20];
-
-                            if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                                src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                                src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
-                                const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
-                                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
-                                const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
-                                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
-                                const int mem_size = 32*sizeof(float);
-                                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q4_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q3_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q5_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q6_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            } else {
-                                const int64_t ny = (ne11 + nrows - 1)/nrows;
-                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                        }
-            } break;
-        case GGML_OP_MUL_MAT_ID:
-            {
-                const int n_as = src0->ne[2];
-
-                // src2 = ids
-                const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
-
-                GGML_ASSERT(src2t == GGML_TYPE_I32);
-
-                GGML_ASSERT(!ggml_is_transposed(src0));
-                GGML_ASSERT(!ggml_is_transposed(src1));
-
-                GGML_ASSERT(src1t == GGML_TYPE_F32);
-
-                GGML_ASSERT(ne03 == 1);
-                GGML_ASSERT(ne13 == 1);
-
-                // find the break-even point where the matrix-matrix kernel becomes more efficient compared
-                // to the matrix-vector kernel
-                // ne20 = n_used_experts
-                // ne21 = n_rows
-                const int dst_rows = ne20*ne21;
-                const int dst_rows_min = n_as;
-                const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;
-
-                // max size of the rowids array in the kernel shared buffer
-                GGML_ASSERT(dst_rows <= dst_rows_max);
-
-                // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
-                // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                // !!!
-                // TODO: for now, always use mat-vec kernels until we figure out how to improve the
-                //       indirect matrix multiplication
-                // !!!
-                if ([device supportsFamily:MTLGPUFamilyApple7] &&
-                        ne00 % 32 == 0 && ne00 >= 64 &&
-                        dst_rows > dst_rows_min) {
-                    // some Metal matrix data types require aligned pointers
-                    // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
-                    switch (src0->type) {
-                        case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break;
-                        case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break;
-                        case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8  == 0); break;
-                        default: break;
-                    }
-
-                    id<MTLComputePipelineState> pipeline = nil;
-
-                    switch (src0->type) {
-                        case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32    ].pipeline; break;
-                        case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32    ].pipeline; break;
-                        case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32   ].pipeline; break;
-                        case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32   ].pipeline; break;
-                        case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32   ].pipeline; break;
-                        case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32   ].pipeline; break;
-                        case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32   ].pipeline; break;
-                        case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32   ].pipeline; break;
-                        case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32   ].pipeline; break;
-                        case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32   ].pipeline; break;
-                        case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32   ].pipeline; break;
-                        case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32   ].pipeline; break;
-                        case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32   ].pipeline; break;
-                        case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
-                        case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
-                        case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
-                        case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32  ].pipeline; break;
-                        case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32  ].pipeline; break;
-                        case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
-                        case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32  ].pipeline; break;
-                        case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
-                        case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
-                        default: GGML_ABORT("MUL_MAT_ID not implemented");
-                    }
-
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
-                    [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
-                    [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                    [encoder setBuffer:id_src2 offset:offs_src2    atIndex:3];
-                    [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:4];
-                    [encoder setBytes:&ne21    length:sizeof(ne21) atIndex:5];
-                    [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:6];
-                    [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:7];
-                    [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:8];
-                    [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:9];
-                    [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:10];
-                    [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:11];
-                    [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:12];
-                    [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:13];
-                    [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:14];
-                    [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:15];
-                    [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:16];
-                    [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:17];
-                    [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:18];
-                    [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:19];
-
-                    [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
-
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
-                } else {
-                    int nth0 = 32;
-                    int nth1 = 1;
-                    int nrows = 1;
-                    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
-                    id<MTLComputePipelineState> pipeline = nil;
-
-                    // use custom matrix x vector kernel
-                    switch (src0t) {
-                        case GGML_TYPE_F32:
-                            {
-                                GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_F16:
-                            {
-                                GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                nth0 = 32;
-                                nth1 = 1;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_BF16:
-                            {
-                                GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                nth0 = 32;
-                                nth1 = 1;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_Q4_0:
-                            {
-                                nth0 = 8;
-                                nth1 = 8;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_Q4_1:
-                            {
-                                nth0 = 8;
-                                nth1 = 8;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_Q5_0:
-                            {
-                                nth0 = 8;
-                                nth1 = 8;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_Q5_1:
-                            {
-                                nth0 = 8;
-                                nth1 = 8;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_Q8_0:
-                            {
-                                nth0 = 8;
-                                nth1 = 8;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_Q2_K:
-                            {
-                                nth0 = 2;
-                                nth1 = 32;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_Q3_K:
-                            {
-                                nth0 = 2;
-                                nth1 = 32;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_Q4_K:
-                            {
-                                nth0 = 4; //1;
-                                nth1 = 8; //32;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_Q5_K:
-                            {
-                                nth0 = 2;
-                                nth1 = 32;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_Q6_K:
-                            {
-                                nth0 = 2;
-                                nth1 = 32;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_IQ2_XXS:
-                            {
-                                nth0 = 4;
-                                nth1 = 16;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_IQ2_XS:
-                            {
-                                nth0 = 4;
-                                nth1 = 16;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_IQ3_XXS:
-                            {
-                                nth0 = 4;
-                                nth1 = 16;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_IQ3_S:
-                            {
-                                nth0 = 4;
-                                nth1 = 16;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_IQ2_S:
-                            {
-                                nth0 = 4;
-                                nth1 = 16;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_IQ1_S:
-                            {
-                                nth0 = 4;
-                                nth1 = 16;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_IQ1_M:
-                            {
-                                nth0 = 4;
-                                nth1 = 16;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_IQ4_NL:
-                            {
-                                nth0 = 4;
-                                nth1 = 16;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
-                            } break;
-                        case GGML_TYPE_IQ4_XS:
-                            {
-                                nth0 = 4;
-                                nth1 = 16;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
-                            } break;
-                        default:
-                            {
-                                GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t);
-                                GGML_ABORT("not implemented");
-                            }
-                    };
-
-                    if (ggml_is_quantized(src0t)) {
-                        GGML_ASSERT(ne00 >= nth0*nth1);
-                    }
-
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
-                    [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
-                    [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
-                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
-                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
-                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
-                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
-                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
-                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
-                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
-                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
-                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
-                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
-                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
-                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
-                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
-                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
-                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:20];
-                    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:21];
-                    [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:22];
-
-                    const int64_t _ne1 = 1;
-                    const int tgz = dst_rows;
-
-                    if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                            src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                            src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
-                        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
-                        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
-                        const int mem_size = 32*sizeof(float);
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q4_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q3_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q5_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q6_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    } else {
-                        const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                }
-            } break;
-        case GGML_OP_GET_ROWS:
-            {
-                id<MTLComputePipelineState> pipeline = nil;
-
-                switch (src0->type) {
-                    case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32    ].pipeline; break;
-                    case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16    ].pipeline; break;
-                    case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16   ].pipeline; break;
-                    case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0   ].pipeline; break;
-                    case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1   ].pipeline; break;
-                    case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0   ].pipeline; break;
-                    case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1   ].pipeline; break;
-                    case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0   ].pipeline; break;
-                    case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K   ].pipeline; break;
-                    case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K   ].pipeline; break;
-                    case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K   ].pipeline; break;
-                    case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K   ].pipeline; break;
-                    case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K   ].pipeline; break;
-                    case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
-                    case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
-                    case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
-                    case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S  ].pipeline; break;
-                    case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S  ].pipeline; break;
-                    case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S  ].pipeline; break;
-                    case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M  ].pipeline; break;
-                    case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
-                    case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
-                    case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
-                    default: GGML_ABORT("not implemented");
-                }
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_dst      offset:offs_dst  atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
-                [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
-                [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
-                [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
-                [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
-                [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:9];
-                [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:10];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
-            } break;
-        case GGML_OP_RMS_NORM:
-            {
-                GGML_ASSERT(ne00 % 4 == 0);
-                GGML_ASSERT(ggml_is_contiguous_1(src0));
-
-                float eps;
-                memcpy(&eps, dst->op_params, sizeof(float));
-
-                int nth = 32; // SIMD width
-
-                while (nth < ne00/4 && nth < 1024) {
-                    nth *= 2;
-                }
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
-                [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
-                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
-                const int64_t nrows = ggml_nrows(src0);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
-        case GGML_OP_GROUP_NORM:
-            {
-                GGML_ASSERT(ne00 % 4 == 0);
-                GGML_ASSERT(ggml_is_contiguous(src0));
-
-                float eps;
-                memcpy(&eps, dst->op_params + 1, sizeof(float));
-
-                const int32_t n_groups = ((const int32_t *) dst->op_params)[0];
-
-                int nth = 32; // SIMD width
-
-                //while (nth < ne00/4 && nth < 1024) {
-                //    nth *= 2;
-                //}
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0  offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_dst   offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00     length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&ne01     length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&ne02     length:sizeof( int64_t) atIndex:4];
-                [encoder setBytes:&nb00     length:sizeof(uint64_t) atIndex:5];
-                [encoder setBytes:&nb01     length:sizeof(uint64_t) atIndex:6];
-                [encoder setBytes:&nb02     length:sizeof(uint64_t) atIndex:7];
-                [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
-                [encoder setBytes:&eps      length:sizeof(   float) atIndex:9];
-                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
-        case GGML_OP_NORM:
-            {
-                GGML_ASSERT(ggml_is_contiguous_1(src0));
-
-                float eps;
-                memcpy(&eps, dst->op_params, sizeof(float));
-
-                const int nth = MIN(256, ne00);
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
-                [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
-                [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
-
-                const int64_t nrows = ggml_nrows(src0);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
-        case GGML_OP_ROPE:
-            {
-                GGML_ASSERT(ne10 == ne02);
-
-                const int nth = MIN(1024, ne00);
-
-                const int n_past     = ((const int32_t *) dst->op_params)[0];
-                const int n_dims     = ((const int32_t *) dst->op_params)[1];
-                const int mode       = ((const int32_t *) dst->op_params)[2];
-                // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
-                const int n_ctx_orig = ((const int32_t *) dst->op_params)[4];
-
-                float freq_base;
-                float freq_scale;
-                float ext_factor;
-                float attn_factor;
-                float beta_fast;
-                float beta_slow;
-
-                memcpy(&freq_base,   (const int32_t *) dst->op_params +  5, sizeof(float));
-                memcpy(&freq_scale,  (const int32_t *) dst->op_params +  6, sizeof(float));
-                memcpy(&ext_factor,  (const int32_t *) dst->op_params +  7, sizeof(float));
-                memcpy(&attn_factor, (const int32_t *) dst->op_params +  8, sizeof(float));
-                memcpy(&beta_fast,   (const int32_t *) dst->op_params +  9, sizeof(float));
-                memcpy(&beta_slow,   (const int32_t *) dst->op_params + 10, sizeof(float));
-
-                const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-
-                id<MTLComputePipelineState> pipeline = nil;
-
-                if (!is_neox) {
-                    switch (src0->type) {
-                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
-                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    };
-                } else {
-                    switch (src0->type) {
-                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
-                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    };
-                }
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_src1     offset:offs_src1        atIndex:1];
-                if (id_src2 != nil) {
-                    [encoder setBuffer:id_src2 offset:offs_src2        atIndex:2];
-                } else {
-                    [encoder setBuffer:id_src0 offset:offs_src0        atIndex:2];
-                }
-                [encoder setBuffer:id_dst      offset:offs_dst         atIndex:3];
-                [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:4];
-                [encoder setBytes:&ne01        length:sizeof( int64_t) atIndex:5];
-                [encoder setBytes:&ne02        length:sizeof( int64_t) atIndex:6];
-                [encoder setBytes:&ne03        length:sizeof( int64_t) atIndex:7];
-                [encoder setBytes:&nb00        length:sizeof(uint64_t) atIndex:8];
-                [encoder setBytes:&nb01        length:sizeof(uint64_t) atIndex:9];
-                [encoder setBytes:&nb02        length:sizeof(uint64_t) atIndex:10];
-                [encoder setBytes:&nb03        length:sizeof(uint64_t) atIndex:11];
-                [encoder setBytes:&ne0         length:sizeof( int64_t) atIndex:12];
-                [encoder setBytes:&ne1         length:sizeof( int64_t) atIndex:13];
-                [encoder setBytes:&ne2         length:sizeof( int64_t) atIndex:14];
-                [encoder setBytes:&ne3         length:sizeof( int64_t) atIndex:15];
-                [encoder setBytes:&nb0         length:sizeof(uint64_t) atIndex:16];
-                [encoder setBytes:&nb1         length:sizeof(uint64_t) atIndex:17];
-                [encoder setBytes:&nb2         length:sizeof(uint64_t) atIndex:18];
-                [encoder setBytes:&nb3         length:sizeof(uint64_t) atIndex:19];
-                [encoder setBytes:&n_past      length:sizeof(     int) atIndex:20];
-                [encoder setBytes:&n_dims      length:sizeof(     int) atIndex:21];
-                [encoder setBytes:&n_ctx_orig  length:sizeof(     int) atIndex:22];
-                [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:23];
-                [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:24];
-                [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:25];
-                [encoder setBytes:&attn_factor length:sizeof(   float) atIndex:26];
-                [encoder setBytes:&beta_fast   length:sizeof(   float) atIndex:27];
-                [encoder setBytes:&beta_slow   length:sizeof(   float) atIndex:28];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
-        case GGML_OP_IM2COL:
-            {
-                GGML_ASSERT(ggml_is_contiguous(src0));
-                GGML_ASSERT(ggml_is_contiguous(src1));
-                GGML_ASSERT(src0->type == GGML_TYPE_F16);
-                GGML_ASSERT(src1->type == GGML_TYPE_F32);
-                GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
-
-                const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
-                const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
-                const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
-                const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
-                const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
-                const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
-
-                const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
-
-                const int32_t N  = src1->ne[is_2D ? 3 : 2];
-                const int32_t IC = src1->ne[is_2D ? 2 : 1];
-                const int32_t IH = is_2D ? src1->ne[1] : 1;
-                const int32_t IW =         src1->ne[0];
-
-                const int32_t KH = is_2D ? src0->ne[1] : 1;
-                const int32_t KW =         src0->ne[0];
-
-                const int32_t OH = is_2D ? dst->ne[2] : 1;
-                const int32_t OW =         dst->ne[1];
-
-                const int32_t CHW = IC * KH * KW;
-
-                const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
-                const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
-
-                const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
-
-                switch (dst->type) {
-                    case GGML_TYPE_F32: {
-                        pipeline = (is_gt_mttpt ?
-                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
-                                    :
-                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
-                    } break;
-                    case GGML_TYPE_F16: {
-                        pipeline = (is_gt_mttpt ?
-                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
-                                    :
-                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
-                    } break;
-                    default: GGML_ABORT("fatal error");
-                };
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src1 offset:offs_src1       atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
-                [encoder setBytes:&ofs0    length:sizeof(int32_t) atIndex:2];
-                [encoder setBytes:&ofs1    length:sizeof(int32_t) atIndex:3];
-                [encoder setBytes:&IW      length:sizeof(int32_t) atIndex:4];
-                [encoder setBytes:&IH      length:sizeof(int32_t) atIndex:5];
-                [encoder setBytes:&CHW     length:sizeof(int32_t) atIndex:6];
-                [encoder setBytes:&s0      length:sizeof(int32_t) atIndex:7];
-                [encoder setBytes:&s1      length:sizeof(int32_t) atIndex:8];
-                [encoder setBytes:&p0      length:sizeof(int32_t) atIndex:9];
-                [encoder setBytes:&p1      length:sizeof(int32_t) atIndex:10];
-                [encoder setBytes:&d0      length:sizeof(int32_t) atIndex:11];
-                [encoder setBytes:&d1      length:sizeof(int32_t) atIndex:12];
-
-                if (is_gt_mttpt) {
-                    [encoder setBytes:&N   length:sizeof(int32_t) atIndex:13];
-                    [encoder setBytes:&KH  length:sizeof(int32_t) atIndex:14];
-                    [encoder setBytes:&KW  length:sizeof(int32_t) atIndex:15];
-
-                    const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
-
-                    const int64_t  quotient  = N / n_threads + (N % n_threads > 0 ? 1 : 0);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
-                } else {
-                    [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
-                }
-            } break;
-        case GGML_OP_UPSCALE:
-            {
-                GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
-                const float sf0 = (float)ne0/src0->ne[0];
-                const float sf1 = (float)ne1/src0->ne[1];
-                const float sf2 = (float)ne2/src0->ne[2];
-                const float sf3 = (float)ne3/src0->ne[3];
-
-                const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
-                [encoder setBytes:&sf0  length:sizeof(sf0)  atIndex:18];
-                [encoder setBytes:&sf1  length:sizeof(sf1)  atIndex:19];
-                [encoder setBytes:&sf2  length:sizeof(sf2)  atIndex:20];
-                [encoder setBytes:&sf3  length:sizeof(sf3)  atIndex:21];
-
-                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
-        case GGML_OP_PAD:
-            {
-                GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
-
-                const int nth = MIN(1024, ne0);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
-        case GGML_OP_ARANGE:
-            {
-                GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-                float start;
-                float step;
-
-                memcpy(&start, ((const int32_t *) dst->op_params) + 0, sizeof(float));
-                memcpy(&step,  ((const int32_t *) dst->op_params) + 2, sizeof(float));
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_dst  offset:offs_dst    atIndex:0];
-                [encoder setBytes:&ne0   length:sizeof(ne0)   atIndex:1];
-                [encoder setBytes:&start length:sizeof(start) atIndex:2];
-                [encoder setBytes:&step  length:sizeof(step)  atIndex:3];
-
-                const int nth = MIN(1024, ne0);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
-        case GGML_OP_TIMESTEP_EMBEDDING:
-            {
-                GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
-                const int dim        = dst->op_params[0];
-                const int max_period = dst->op_params[1];
-
-                const int half = dim / 2;
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&nb1   length:sizeof(nb1) atIndex:2];
-                [encoder setBytes:&dim   length:sizeof(dim) atIndex:3];
-                [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
-
-                const int nth = MIN(1024, half);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
-        case GGML_OP_ARGSORT:
-            {
-                GGML_ASSERT(src0->type == GGML_TYPE_F32);
-                GGML_ASSERT( dst->type == GGML_TYPE_I32);
-
-                const int nrows = ggml_nrows(src0);
-
-                enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
-
-                // bitonic sort requires the number of elements to be power of 2
-                int64_t ne00_padded = 1;
-                while (ne00_padded < ne00) {
-                    ne00_padded *= 2;
-                }
-
-                // Metal kernels require the buffer size to be multiple of 16 bytes
-                // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
-                const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
-
-                id<MTLComputePipelineState> pipeline = nil;
-
-                switch (order) {
-                    case GGML_SORT_ORDER_ASC:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline;  break;
-                    case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
-                    default: GGML_ABORT("fatal error");
-                };
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_dst      offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
-                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
-            } break;
-        case GGML_OP_LEAKY_RELU:
-            {
-                GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
-                float slope;
-                memcpy(&slope, dst->op_params, sizeof(float));
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst    atIndex:1];
-                [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
-
-                const int64_t n = ggml_nelements(dst);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-            } break;
-        case GGML_OP_FLASH_ATTN_EXT:
-            {
-                GGML_ASSERT(ne00 % 4  == 0);
-                GGML_ASSERT(ne11 % 32 == 0);
-
-                GGML_ASSERT(src0->type == GGML_TYPE_F32);
-                GGML_ASSERT(src1->type == src2->type);
-
-                GGML_ASSERT(ggml_are_same_shape (src1, src2));
-
-                struct ggml_tensor * src3 = node->src[3];
-
-                size_t offs_src3 = 0;
-
-                id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
-
-                GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
-                GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
-                        "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
-
-                const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
-                //const int64_t  ne31 = src3 ? src3->ne[1] : 0;
-                const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
-                const int64_t  ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
-
-                const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
-                const uint64_t nb31 = src3 ? src3->nb[1] : 0;
-                const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
-                const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
-
-                const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
-
-                float scale;
-                float max_bias;
-                float logit_softcap;
-                memcpy(&scale,         ((const int32_t *) dst->op_params) + 0, sizeof(scale));
-                memcpy(&max_bias,      ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
-                memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
-
-                if (logit_softcap != 0.0f) {
-                    scale /= logit_softcap;
-                }
-
-                const uint32_t n_head      = src0->ne[2];
-                const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-                const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-                const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-                id<MTLComputePipelineState> pipeline = nil;
-
-                bool use_vec_kernel = false;
-
-                // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
-                //       for now avoiding mainly to keep the number of templates/kernels a bit lower
-                if (ne01 >= 4 || (ne00%128 != 0)) {
-                    switch (src1->type) {
-                        case GGML_TYPE_F16:
-                            {
-                                switch (ne00) {
-                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
-                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
-                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
-                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
-                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
-                                    default:
-                                              {
-                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                  GGML_LOG_ERROR("add template specialization for this size\n");
-                                                  GGML_ABORT("add template specialization for this size");
-                                              }
-                                }
-                            } break;
-                        case GGML_TYPE_BF16:
-                            {
-                                switch (ne00) {
-                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
-                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
-                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
-                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
-                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
-                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
-                                    default:
-                                              {
-                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                  GGML_LOG_ERROR("add template specialization for this size\n");
-                                                  GGML_ABORT("add template specialization for this size");
-                                              }
-                                }
-                            } break;
-                        case GGML_TYPE_Q4_0:
-                            {
-                                switch (ne00) {
-                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
-                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
-                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
-                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
-                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
-                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
-                                    default:
-                                              {
-                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                  GGML_LOG_ERROR("add template specialization for this size\n");
-                                                  GGML_ABORT("add template specialization for this size");
-                                              }
-                                }
-                            } break;
-                        case GGML_TYPE_Q4_1:
-                            {
-                                switch (ne00) {
-                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
-                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
-                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
-                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
-                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
-                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
-                                    default:
-                                              {
-                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                  GGML_LOG_ERROR("add template specialization for this size\n");
-                                                  GGML_ABORT("add template specialization for this size");
-                                              }
-                                }
-                            } break;
-                        case GGML_TYPE_Q5_0:
-                            {
-                                switch (ne00) {
-                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
-                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
-                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
-                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
-                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
-                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
-                                    default:
-                                              {
-                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                  GGML_LOG_ERROR("add template specialization for this size\n");
-                                                  GGML_ABORT("add template specialization for this size");
-                                              }
-                                }
-                            } break;
-                        case GGML_TYPE_Q5_1:
-                            {
-                                switch (ne00) {
-                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
-                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
-                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
-                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
-                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
-                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
-                                    default:
-                                              {
-                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                  GGML_LOG_ERROR("add template specialization for this size\n");
-                                                  GGML_ABORT("add template specialization for this size");
-                                              }
-                                }
-                            } break;
-                        case GGML_TYPE_Q8_0:
-                            {
-                                switch (ne00) {
-                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
-                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
-                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
-                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
-                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
-                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
-                                    default:
-                                              {
-                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                                  GGML_LOG_ERROR("add template specialization for this size\n");
-                                                  GGML_ABORT("add template specialization for this size");
-                                              }
-                                }
-                            } break;
-                        default:
-                            {
-                                GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                GGML_LOG_ERROR("add template specialization for this type\n");
-                                GGML_ABORT("add template specialization for this type");
-                            }
-                    }
-                } else {
-                    use_vec_kernel = true;
-
-                    switch (ne00) {
-                        case 128:
-                            {
-                                switch (src1->type) {
-                                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
-                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
-                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
-                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
-                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
-                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
-                                    default:
-                                        {
-                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                            GGML_LOG_ERROR("add template specialization for this type\n");
-                                            GGML_ABORT("add template specialization for this type");
-                                        }
-                                }
-                            } break;
-                        case 256:
-                            {
-                                switch (src1->type) {
-                                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
-                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
-                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
-                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
-                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
-                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
-                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
-                                    default:
-                                        {
-                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                            GGML_LOG_ERROR("add template specialization for this type\n");
-                                            GGML_ABORT("add template specialization for this type");
-                                        }
-                                }
-                            } break;
-                        default:
-                                  {
-                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                      GGML_LOG_ERROR("add template specialization for this size\n");
-                                      GGML_ABORT("add template specialization for this size");
-                                  }
-                    }
-                }
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0     offset:offs_src0           atIndex:0];
-                [encoder setBuffer:id_src1     offset:offs_src1           atIndex:1];
-                [encoder setBuffer:id_src2     offset:offs_src2           atIndex:2];
-                if (id_src3) {
-                    [encoder setBuffer:id_src3     offset:offs_src3           atIndex:3];
-                } else {
-                    [encoder setBuffer:id_src0     offset:offs_src0           atIndex:3];
-                }
-                [encoder setBuffer:id_dst        offset:offs_dst              atIndex:4];
-                [encoder setBytes:&ne01          length:sizeof( int64_t)      atIndex:5];
-                [encoder setBytes:&ne02          length:sizeof( int64_t)      atIndex:6];
-                [encoder setBytes:&ne03          length:sizeof( int64_t)      atIndex:7];
-                [encoder setBytes:&nb01          length:sizeof(uint64_t)      atIndex:8];
-                [encoder setBytes:&nb02          length:sizeof(uint64_t)      atIndex:9];
-                [encoder setBytes:&nb03          length:sizeof(uint64_t)      atIndex:10];
-                [encoder setBytes:&ne11          length:sizeof( int64_t)      atIndex:11];
-                [encoder setBytes:&ne12          length:sizeof( int64_t)      atIndex:12];
-                [encoder setBytes:&ne13          length:sizeof( int64_t)      atIndex:13];
-                [encoder setBytes:&nb11          length:sizeof(uint64_t)      atIndex:14];
-                [encoder setBytes:&nb12          length:sizeof(uint64_t)      atIndex:15];
-                [encoder setBytes:&nb13          length:sizeof(uint64_t)      atIndex:16];
-                [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:17];
-                [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:18];
-                [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:19];
-                [encoder setBytes:&scale         length:sizeof(   float)      atIndex:20];
-                [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:21];
-                [encoder setBytes:&m0            length:sizeof(m0)            atIndex:22];
-                [encoder setBytes:&m1            length:sizeof(m1)            atIndex:23];
-                [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:24];
-                [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
-
-                if (!use_vec_kernel) {
-                    // half8x8 kernel
-                    const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
-                    const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-
-                    GGML_ASSERT(nqptg <= 32);
-                    GGML_ASSERT(nqptg  % 8  == 0);
-                    GGML_ASSERT(ncpsg  % 32 == 0);
-
-                    // 2*(2*ncpsg + nqptg)*(nsg)
-                    // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
-                    //
-                    // 16*32*(nsg)
-                    // the shared memory needed for the simdgroups to load the KV cache
-                    // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
-                    //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
-
-                    int64_t nsgmax = 2;
-
-                    while (true) {
-                        const size_t smem = FATTN_SMEM(nsgmax);
-                        if (smem > device.maxThreadgroupMemoryLength) {
-                            break;
-                        }
-                        nsgmax *= 2;
-                    }
-                    nsgmax /= 2;
-
-                    // simdgroups per threadgroup (a.k.a. warps)
-                    const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
-
-                    const size_t smem = FATTN_SMEM(nsg);
-
-                    //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
-                    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
-#undef FATTN_SMEM
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
-                } else {
-                    // half4x4 kernel
-                    const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
-                    const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-
-                    GGML_ASSERT(nqptg <= 32);
-                    GGML_ASSERT(nqptg  % 1  == 0);
-                    GGML_ASSERT(ncpsg  % 32 == 0);
-
-                    // ne00 + 2*ncpsg*(nsg)
-                    // for each query, we load it as f16 in shared memory (ne00)
-                    // and store the soft_max values and the mask
-                    //
-                    // ne00*(nsg)
-                    // each simdgroup has a full f16 head vector in shared mem to accumulate results
-                    //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
-
-                    int64_t nsgmax = 2;
-
-                    while (true) {
-                        const size_t smem = FATTN_SMEM(nsgmax);
-                        if (smem > device.maxThreadgroupMemoryLength) {
-                            break;
-                        }
-                        nsgmax *= 2;
-                    }
-                    nsgmax /= 2;
-
-                    // simdgroups per threadgroup (a.k.a. warps)
-                    const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
-
-                    int64_t nsg = 1;
-                    while (nsg <= nsgt) {
-                        nsg *= 2;
-                    }
-                    nsg /= 2;
-
-                    const size_t smem = FATTN_SMEM(nsg);
-
-                    //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
-                    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
-#undef FATTN_SMEM
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
-                }
-            } break;
-        case GGML_OP_DUP:
-        case GGML_OP_CPY:
-        case GGML_OP_CONT:
-            {
-                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
-                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
-                id<MTLComputePipelineState> pipeline = nil;
-
-                switch (src0t) {
-                    case GGML_TYPE_F32:
-                        {
-                            GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
-
-                            switch (dstt) {
-                                case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
-                                case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
-                                case GGML_TYPE_BF16:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
-                                case GGML_TYPE_Q8_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
-                                case GGML_TYPE_Q4_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
-                                case GGML_TYPE_Q4_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
-                                case GGML_TYPE_Q5_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
-                                case GGML_TYPE_Q5_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
-                                case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
-                                default: GGML_ABORT("not implemented");
-                            };
-                        } break;
-                    case GGML_TYPE_F16:
-                        {
-                            switch (dstt) {
-                                case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
-                                case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
-                                default: GGML_ABORT("not implemented");
-                            };
-                        } break;
-                    case GGML_TYPE_BF16:
-                        {
-                            switch (dstt) {
-                                case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
-                                case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
-                                default: GGML_ASSERT(false && "not implemented");
-                            };
-                        } break;
-                    default: GGML_ABORT("not implemented");
-                }
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
-                [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
-                [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
-                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
-                [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
-                [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
-                [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
-                [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
-                [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
-                [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
-                [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
-                [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
-                [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
-                [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
-        case GGML_OP_POOL_2D:
-            {
-                GGML_ASSERT(ggml_is_contiguous(src0));
-                GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
-
-                const int32_t * opts = dst->op_params;
-                enum ggml_op_pool op = opts[0];
-
-                id<MTLComputePipelineState> pipeline = nil;
-                switch (src0t) {
-                    case GGML_TYPE_F32: {
-                        switch(op) {
-                            case GGML_OP_POOL_AVG:
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
-                            case GGML_OP_POOL_MAX:
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
-                            default: GGML_ASSERT(false && "not implemented");
-                        }
-                    } break;
-                    default: GGML_ASSERT(false && "not implemented");
-                }
-
-                const int32_t k0 = opts[1];
-                const int32_t k1 = opts[2];
-                const int32_t s0 = opts[3];
-                const int32_t s1 = opts[4];
-                const int32_t p0 = opts[5];
-                const int32_t p1 = opts[6];
-
-                const int64_t IH = src0->ne[1];
-                const int64_t IW = src0->ne[0];
-
-                const int64_t N  = dst->ne[3];
-                const int64_t OC = dst->ne[2];
-                const int64_t OH = dst->ne[1];
-                const int64_t OW = dst->ne[0];
-
-                const int64_t parallel_elements = N * OC * OH * OW;
-                const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
-                const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0       atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
-                [encoder setBytes:&k0      length:sizeof(int32_t) atIndex:2];
-                [encoder setBytes:&k1      length:sizeof(int32_t) atIndex:3];
-                [encoder setBytes:&s0      length:sizeof(int32_t) atIndex:4];
-                [encoder setBytes:&s1      length:sizeof(int32_t) atIndex:5];
-                [encoder setBytes:&p0      length:sizeof(int32_t) atIndex:6];
-                [encoder setBytes:&p1      length:sizeof(int32_t) atIndex:7];
-                [encoder setBytes:&IH      length:sizeof(int64_t) atIndex:8];
-                [encoder setBytes:&IW      length:sizeof(int64_t) atIndex:9];
-                [encoder setBytes:&OH      length:sizeof(int64_t) atIndex:10];
-                [encoder setBytes:&OW      length:sizeof(int64_t) atIndex:11];
-                [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
-            } break;
-       default:
-            {
-                GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-static enum ggml_status ggml_metal_graph_compute(
-            ggml_backend_t   backend,
-        struct ggml_cgraph * gf) {
-    struct ggml_backend_metal_context        * ctx     = backend->context;
-    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
-
-    // number of nodes encoded by the main thread (empirically determined)
-    const int n_main = 128;
-
-    // number of threads in addition to the main thread
-    const int n_cb = ctx->n_cb;
-
-    // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
-    // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
-    // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
-    // each thread creates it's own command buffer and enqueues the ops in parallel
-    //
-    // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2
-
-    @autoreleasepool {
-        ctx->gf = gf;
-
-        ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
-        ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;
-
-        ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
-
-        const bool should_capture = ctx->capture_next_compute;
-        if (should_capture) {
-            ctx->capture_next_compute = false;
-
-            if (!ctx->capture_started) {
-                // create capture scope
-                ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
-
-                MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
-                descriptor.captureObject = ctx->capture_scope;
-                descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
-                descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
-
-                NSError * error = nil;
-                if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
-                    GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
-                } else {
-                    [ctx->capture_scope beginScope];
-                    ctx->capture_started = true;
-                }
-            }
-        }
-
-        // the main thread commits the first few commands immediately
-        // command_buffer[n_cb]
-        {
-            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
-            ctx->command_buffers[n_cb] = command_buffer;
-
-            [command_buffer enqueue];
-            ctx->encode_async(n_cb);
-        }
-
-        // prepare the rest of the command buffers asynchronously
-        // command_buffer[0.. n_cb)
-        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
-            ctx->command_buffers[cb_idx] = command_buffer;
-
-            // always enqueue the first two command buffers
-            // enqueue all of the command buffers if we don't need to abort
-            if (cb_idx < 2 || ctx->abort_callback == NULL) {
-                [command_buffer enqueue];
-            }
-        }
-
-        dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
-
-        // wait for completion and check status of each command buffer
-        // needed to detect if the device ran out-of-memory for example (#1881)
-        {
-            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
-            [command_buffer waitUntilCompleted];
-
-            MTLCommandBufferStatus status = [command_buffer status];
-            if (status != MTLCommandBufferStatusCompleted) {
-                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
-                if (status == MTLCommandBufferStatusError) {
-                    GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
-                }
-
-                return GGML_STATUS_FAILED;
-            }
-        }
-
-        for (int i = 0; i < n_cb; ++i) {
-            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
-            [command_buffer waitUntilCompleted];
-
-            MTLCommandBufferStatus status = [command_buffer status];
-            if (status != MTLCommandBufferStatusCompleted) {
-                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
-                if (status == MTLCommandBufferStatusError) {
-                    GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
-                }
-
-                return GGML_STATUS_FAILED;
-            }
-
-            id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
-            if (!next_buffer) {
-                continue;
-            }
-
-            const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
-            if (next_queued) {
-                continue;
-            }
-
-            if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
-                GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
-                return GGML_STATUS_ABORTED;
-            }
-
-            [next_buffer commit];
-        }
-
-        if (!should_capture && ctx->capture_started) {
-            [ctx->capture_scope endScope];
-            [[MTLCaptureManager sharedCaptureManager] stopCapture];
-        }
-    }
-
-    return GGML_STATUS_SUCCESS;
-}
-
-////////////////////////////////////////////////////////////////////////////////
-
-// backend interface
-
-static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
-
-    for (int i = 0; i < ctx->n_buffers; i++) {
-        [ctx->buffers[i].metal release];
-    }
-    ggml_backend_metal_device_rel(buffer->buft->device->context);
-
-    if (ctx->owned) {
-#if TARGET_OS_OSX
-        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size);
-#else
-        free(ctx->all_data);
-#endif
-    }
-
-    free(ctx);
-}
-
-static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
-    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
-
-    return ctx->all_data;
-}
-
-static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    memcpy((char *)tensor->data + offset, data, size);
-
-    UNUSED(buffer);
-}
-
-static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    memcpy(data, (const char *)tensor->data + offset, size);
-
-    UNUSED(buffer);
-}
-
-static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
-    if (ggml_backend_buffer_is_host(src->buffer)) {
-        memcpy(dst->data, src->data, ggml_nbytes(src));
-        return true;
-    }
-    return false;
-
-    UNUSED(buffer);
-}
-
-static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
-    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
-
-    memset(ctx->all_data, value, ctx->all_size);
-}
-
-static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
-    /* .free_buffer     = */ ggml_backend_metal_buffer_free_buffer,
-    /* .get_base        = */ ggml_backend_metal_buffer_get_base,
-    /* .init_tensor     = */ NULL,
-    /* .memset_tensor   = */ NULL,
-    /* .set_tensor      = */ ggml_backend_metal_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_metal_buffer_get_tensor,
-    /* .cpy_tensor      = */ ggml_backend_metal_buffer_cpy_tensor,
-    /* .clear           = */ ggml_backend_metal_buffer_clear,
-    /* .reset           = */ NULL,
-};
-
-// default buffer type
-
-static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    return "Metal";
-
-    UNUSED(buft);
-}
-
-static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
-#ifndef GGML_METAL_NDEBUG
-#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
-    if (@available(macOS 10.12, iOS 16.0, *)) {
-        GGML_LOG_DEBUG("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n",
-                __func__,
-                size_aligned / 1024.0 / 1024.0,
-                device.currentAllocatedSize / 1024.0 / 1024.0,
-                device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-
-        if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
-            GGML_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
-        }
-    } else {
-        GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
-                __func__,
-                size_aligned / 1024.0 / 1024.0,
-                device.currentAllocatedSize / 1024.0 / 1024.0);
-    }
-#endif
-#endif
-    UNUSED(device);
-    UNUSED(size_aligned);
-}
-
-static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
-
-    const size_t size_page = sysconf(_SC_PAGESIZE);
-
-    size_t size_aligned = size;
-    if ((size_aligned % size_page) != 0) {
-        size_aligned += (size_page - (size_aligned % size_page));
-    }
-
-    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
-
-    ctx->all_data = ggml_metal_host_malloc(size_aligned);
-    ctx->all_size = size_aligned;
-    ctx->owned = true;
-    ctx->n_buffers = 1;
-
-    if (ctx->all_data != NULL) {
-        ctx->buffers[0].data  = ctx->all_data;
-        ctx->buffers[0].size  = size;
-        ctx->buffers[0].metal = nil;
-
-        if (size_aligned > 0) {
-            ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
-                                            length:size_aligned
-                                            options:MTLResourceStorageModeShared
-                                            deallocator:nil];
-        }
-    }
-
-    if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
-        GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
-        free(ctx);
-        ggml_backend_metal_device_rel(buft->device->context);
-        return NULL;
-    }
-
-    //ggml_backend_metal_log_allocated_size(device, size_aligned);
-
-    return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
-}
-
-static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    return 32;
-    UNUSED(buft);
-}
-
-static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
-    const size_t max_size = device.maxBufferLength;
-    ggml_backend_metal_device_rel(buft->device->context);
-
-    return max_size;
-
-    UNUSED(buft);
-}
-
-static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
-    return true;
-
-    UNUSED(buft);
-}
-
-ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
-    static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
-        /* .iface = */ {
-            /* .get_name         = */ ggml_backend_metal_buffer_type_get_name,
-            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_metal_buffer_type_get_alignment,
-            /* .get_max_size     = */ ggml_backend_metal_buffer_type_get_max_size,
-            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
-            /* .is_host          = */ ggml_backend_metal_buffer_type_is_host,
-        },
-        /* .device  = */ &g_ggml_backend_metal_device,
-        /* .context = */ NULL,
-    };
-
-    return &ggml_backend_buffer_type_metal;
-}
-
-static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
-    return "Metal_Mapped";
-
-    UNUSED(buft);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) {
-    static struct ggml_backend_buffer_type ggml_backend_buffer_from_ptr_type_metal = {
-        /* .iface = */ {
-            /* .get_name         = */ ggml_backend_metal_buffer_from_ptr_type_get_name,
-            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_metal_buffer_type_get_alignment,
-            /* .get_max_size     = */ ggml_backend_metal_buffer_type_get_max_size,
-            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
-            /* .is_host          = */ ggml_backend_metal_buffer_type_is_host,
-        },
-        /* .device  = */ &g_ggml_backend_metal_device,
-        /* .context = */ NULL,
-    };
-
-    return &ggml_backend_buffer_from_ptr_type_metal;
-}
-
-// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr
-ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
-    struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
-
-    ctx->all_data = data;
-    ctx->all_size = size;
-    ctx->owned = false;
-    ctx->n_buffers = 0;
-
-    const size_t size_page = sysconf(_SC_PAGESIZE);
-
-    // page-align the data ptr
-    {
-        const uintptr_t offs = (uintptr_t) data % size_page;
-        data  = (void *) ((char *) data - offs);
-        size += offs;
-    }
-
-    size_t size_aligned = size;
-    if ((size_aligned % size_page) != 0) {
-        size_aligned += (size_page - (size_aligned % size_page));
-    }
-
-    id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
-
-    // the buffer fits into the max buffer size allowed by the device
-    if (size_aligned <= device.maxBufferLength) {
-        ctx->buffers[ctx->n_buffers].data  = data;
-        ctx->buffers[ctx->n_buffers].size  = size;
-        ctx->buffers[ctx->n_buffers].metal = nil;
-
-        if (size_aligned > 0) {
-            ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
-            if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
-                return false;
-            }
-        }
-
-        ggml_backend_metal_log_allocated_size(device, size_aligned);
-
-        ++ctx->n_buffers;
-    } else {
-        // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
-        // one of the views
-        const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
-        const size_t size_step = device.maxBufferLength - size_ovlp;
-        const size_t size_view = device.maxBufferLength;
-
-        for (size_t i = 0; i < size; i += size_step) {
-            const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
-
-            ctx->buffers[ctx->n_buffers].data  = (void *) ((uint8_t *) data + i);
-            ctx->buffers[ctx->n_buffers].size  = size_step_aligned;
-            ctx->buffers[ctx->n_buffers].metal = nil;
-
-            if (size_step_aligned > 0) {
-                ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
-                if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                    GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
-                    return false;
-                }
-            }
-
-            ggml_backend_metal_log_allocated_size(device, size_step_aligned);
-
-            if (i + size_step < size) {
-                GGML_LOG_INFO("\n");
-            }
-
-            ++ctx->n_buffers;
-        }
-    }
-
-    return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
-}
-
-// backend
-
-static const char * ggml_backend_metal_name(ggml_backend_t backend) {
-    return "Metal";
-
-    UNUSED(backend);
-}
-
-static void ggml_backend_metal_free(ggml_backend_t backend) {
-    struct ggml_backend_metal_context        * ctx     = backend->context;
-    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
-
-    ggml_backend_metal_device_rel(ctx_dev);
-    ggml_metal_free(ctx);
-
-    free(backend);
-}
-
-static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-    return ggml_metal_graph_compute(backend, cgraph);
-}
-
-static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
-    GGML_ASSERT(ggml_backend_is_metal(backend));
-
-    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
-
-    if (ctx->n_cb != n_cb) {
-        ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);
-
-        if (ctx->n_cb > 2) {
-            GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
-        }
-    }
-
-    if (ctx->encode_async) {
-        Block_release(ctx->encode_async);
-    }
-
-    ctx->encode_async = Block_copy(^(size_t iter) {
-        const int cb_idx = iter;
-        const int n_cb_l = ctx->n_cb;
-
-        const int n_nodes_0 = ctx->n_nodes_0;
-        const int n_nodes_1 = ctx->n_nodes_1;
-
-        const int n_nodes_per_cb = ctx->n_nodes_per_cb;
-
-        id<MTLCommandBuffer> command_buffer  = ctx->command_buffers[cb_idx];
-        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
-
-        int node_start = 0;
-        int node_end   = n_nodes_0;
-
-        if (cb_idx < n_cb_l) {
-            node_start = n_nodes_0 + (                                         (cb_idx + 0) * n_nodes_per_cb);
-            node_end   = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
-        }
-
-        const bool should_capture = ctx->capture_next_compute;
-
-        for (int idx = node_start; idx < node_end; ++idx) {
-            if (should_capture) {
-                [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
-            }
-
-            ggml_metal_encode_node(backend, idx, encoder);
-
-            if (should_capture) {
-                [encoder popDebugGroup];
-            }
-        }
-
-        [encoder endEncoding];
-
-        if (cb_idx < 2 || ctx->abort_callback == NULL) {
-            [command_buffer commit];
-        }
-    });
-}
-
-static struct ggml_backend_i ggml_backend_metal_i = {
-    /* .get_name                = */ ggml_backend_metal_name,
-    /* .free                    = */ ggml_backend_metal_free,
-    /* .set_tensor_async        = */ NULL,
-    /* .get_tensor_async        = */ NULL,
-    /* .cpy_tensor_async        = */ NULL,
-    /* .synchronize             = */ NULL,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_metal_graph_compute,
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_metal_guid(void) {
-    static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
-    return &guid;
-}
-
-// TODO: remove in the future
-ggml_backend_t ggml_backend_metal_init(void) {
-    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
-
-    struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
-    if (ctx == NULL) {
-        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
-        return NULL;
-    }
-
-    ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
-
-    *backend = (struct ggml_backend) {
-        /* .guid      = */ ggml_backend_metal_guid(),
-        /* .interface = */ ggml_backend_metal_i,
-        /* .device    = */ dev,
-        /* .context   = */ ctx,
-    };
-
-    ggml_backend_metal_set_n_cb(backend, 1);
-
-    return backend;
-}
-
-bool ggml_backend_is_metal(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
-}
-
-void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
-    GGML_ASSERT(ggml_backend_is_metal(backend));
-
-    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
-
-    ctx->abort_callback = abort_callback;
-    ctx->abort_callback_data = user_data;
-}
-
-bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
-    GGML_ASSERT(ggml_backend_is_metal(backend));
-
-    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
-
-    return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
-}
-
-void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
-    GGML_ASSERT(ggml_backend_is_metal(backend));
-
-    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
-    ctx->capture_next_compute = true;
-}
-
-// backend device
-
-static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
-    return "Metal";
-
-    GGML_UNUSED(dev);
-}
-
-static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
-    // acq/rel just to populate ctx->name in case it hasn't been done yet
-    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
-    ggml_backend_metal_device_acq(ctx_dev);
-    ggml_backend_metal_device_rel(ctx_dev);
-
-    return ctx_dev->name;
-}
-
-static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
-    if (@available(macOS 10.12, iOS 16.0, *)) {
-        struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
-        id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
-
-        *total = device.recommendedMaxWorkingSetSize;
-        *free  = *total - device.currentAllocatedSize;
-
-        ggml_backend_metal_device_rel(ctx_dev);
-    } else {
-        *free = 1;
-        *total = 1;
-    }
-}
-
-static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
-    return GGML_BACKEND_DEVICE_TYPE_GPU;
-
-    GGML_UNUSED(dev);
-}
-
-static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
-    props->name        = ggml_backend_metal_device_get_name(dev);
-    props->description = ggml_backend_metal_device_get_description(dev);
-    props->type        = ggml_backend_metal_device_get_type(dev);
-    ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
-    props->caps = (struct ggml_backend_dev_caps) {
-        /* .async                 = */ false,
-        /* .host_buffer           = */ false,
-        /* .buffer_from_host_ptr  = */ true,
-        /* .events                = */ false,
-    };
-}
-
-static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
-    struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
-    if (ctx == NULL) {
-        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
-        return NULL;
-    }
-
-    ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
-
-    *backend = (struct ggml_backend) {
-        /* .guid      = */ ggml_backend_metal_guid(),
-        /* .interface = */ ggml_backend_metal_i,
-        /* .device    = */ dev,
-        /* .context   = */ ctx,
-    };
-
-    ggml_backend_metal_set_n_cb(backend, 1);
-
-    return backend;
-
-    GGML_UNUSED(params);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
-    return ggml_backend_metal_buffer_type();
-
-    GGML_UNUSED(dev);
-}
-
-static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
-    struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
-
-    ctx->all_data = ptr;
-    ctx->all_size = size;
-    ctx->owned = false;
-    ctx->n_buffers = 0;
-
-    const size_t size_page = sysconf(_SC_PAGESIZE);
-
-    // page-align the data ptr
-    {
-        const uintptr_t offs = (uintptr_t) ptr % size_page;
-        ptr  = (void *) ((char *) ptr - offs);
-        size += offs;
-    }
-
-    size_t size_aligned = size;
-    if ((size_aligned % size_page) != 0) {
-        size_aligned += (size_page - (size_aligned % size_page));
-    }
-
-    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
-    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
-
-    // the buffer fits into the max buffer size allowed by the device
-    if (size_aligned <= device.maxBufferLength) {
-        ctx->buffers[ctx->n_buffers].data  = ptr;
-        ctx->buffers[ctx->n_buffers].size  = size;
-        ctx->buffers[ctx->n_buffers].metal = nil;
-
-        if (size_aligned > 0) {
-            ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
-            if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
-                return false;
-            }
-        }
-
-        ggml_backend_metal_log_allocated_size(device, size_aligned);
-
-        ++ctx->n_buffers;
-    } else {
-        // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
-        // one of the views
-        const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
-        const size_t size_step = device.maxBufferLength - size_ovlp;
-        const size_t size_view = device.maxBufferLength;
-
-        for (size_t i = 0; i < size; i += size_step) {
-            const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
-
-            ctx->buffers[ctx->n_buffers].data  = (void *) ((uint8_t *) ptr + i);
-            ctx->buffers[ctx->n_buffers].size  = size_step_aligned;
-            ctx->buffers[ctx->n_buffers].metal = nil;
-
-            if (size_step_aligned > 0) {
-                ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
-
-                if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                    GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
-                    return false;
-                }
-            }
-
-            ggml_backend_metal_log_allocated_size(device, size_step_aligned);
-
-            if (i + size_step < size) {
-                GGML_LOG_INFO("\n");
-            }
-
-            ++ctx->n_buffers;
-        }
-    }
-
-    return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
-}
-
-static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-    struct ggml_backend_metal_device_context * ctx_dev = dev->context;
-
-    return ggml_metal_supports_op(ctx_dev, op);
-}
-
-static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
-            buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
-
-    UNUSED(dev);
-}
-
-static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-    return false;
-
-    GGML_UNUSED(dev);
-    GGML_UNUSED(op);
-}
-
-static struct ggml_backend_device_i ggml_backend_metal_device_i = {
-    /* .get_name             = */ ggml_backend_metal_device_get_name,
-    /* .get_description      = */ ggml_backend_metal_device_get_description,
-    /* .get_memory           = */ ggml_backend_metal_device_get_memory,
-    /* .get_type             = */ ggml_backend_metal_device_get_type,
-    /* .get_props            = */ ggml_backend_metal_device_get_props,
-    /* .init_backend         = */ ggml_backend_metal_device_init,
-    /* .get_buffer_type      = */ ggml_backend_metal_device_get_buffer_type,
-    /* .get_host_buffer_type = */ NULL,
-    /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr,
-    /* .supports_op          = */ ggml_backend_metal_device_supports_op,
-    /* .supports_buft        = */ ggml_backend_metal_device_supports_buft,
-    /* .offload_op           = */ ggml_backend_metal_device_offload_op,
-    /* .event_new            = */ NULL,
-    /* .event_free           = */ NULL,
-    /* .event_synchronize    = */ NULL,
-};
-
-// backend registry
-
-static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
-    return "Metal";
-
-    GGML_UNUSED(reg);
-}
-
-static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
-    return 1;
-
-    GGML_UNUSED(reg);
-}
-
-static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
-    GGML_ASSERT(index == 0);
-
-    return &g_ggml_backend_metal_device;
-
-    GGML_UNUSED(reg);
-    GGML_UNUSED(index);
-}
-
-static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
-    /* .get_name         = */ ggml_backend_metal_reg_get_name,
-    /* .device_count     = */ ggml_backend_metal_reg_device_count,
-    /* .device_get       = */ ggml_backend_metal_reg_device_get,
-    /* .get_proc_address = */ NULL,
-};
-
-ggml_backend_reg_t ggml_backend_metal_reg(void) {
-    // TODO: make this thread-safe somehow?
-    {
-        g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
-            /* .iface   = */ ggml_backend_metal_reg_i,
-            /* .context = */ NULL,
-        };
-
-        g_ggml_backend_metal_device = (struct ggml_backend_device) {
-            /* .iface   = */ ggml_backend_metal_device_i,
-            /* .reg     = */ &g_ggml_backend_metal_reg,
-            /* .context = */ &g_ggml_ctx_dev_main,
-        };
-    }
-
-    return &g_ggml_backend_metal_reg;
-}
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
deleted file mode 100644 (file)
index e8b71a9..0000000
+++ /dev/null
@@ -1,7064 +0,0 @@
-#define GGML_COMMON_DECL_METAL
-#define GGML_COMMON_IMPL_METAL
-#include "ggml-common.h"
-
-#include <metal_stdlib>
-
-using namespace metal;
-
-#define MAX(x, y) ((x) > (y) ? (x) : (y))
-#define MIN(x, y) ((x) < (y) ? (x) : (y))
-#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
-
-#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-
-// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
-//
-// cmd:
-//   .../usr/bin/metal -dM -E -c                             ggml/src/ggml-metal.metal
-//   .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal.metal
-//
-#if __METAL_VERSION__ < 310 && defined(GGML_METAL_USE_BF16)
-#undef GGML_METAL_USE_BF16
-#endif
-
-#if defined(GGML_METAL_USE_BF16)
-typedef matrix<bfloat, 4, 4> bfloat4x4;
-#endif
-
-constexpr constant static float kvalues_iq4nl_f[16] = {
-    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
-};
-
-// NOTE: this is not dequantizing - we are simply fitting the template
-template <typename type4x4>
-void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
-    reg = (type4x4)(*src);
-}
-
-template <typename type4x4>
-void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
-    reg = (type4x4)(*src);
-}
-
-#if defined(GGML_METAL_USE_BF16)
-template <typename type4x4>
-void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
-    reg = (type4x4)(*src);
-}
-#endif
-
-template <typename type4x4>
-void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
-    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
-    const float d1 = il ? (xb->d / 16.h) : xb->d;
-    const float d2 = d1 / 256.f;
-    const float md = -8.h * xb->d;
-    const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = mask0 << 8;
-
-    float4x4 reg_f;
-
-    for (int i = 0; i < 8; i++) {
-        reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
-        reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
-    }
-
-    reg = (type4x4) reg_f;
-}
-
-template <typename type4x4>
-void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
-    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
-    const float d1 = il ? (xb->d / 16.h) : xb->d;
-    const float d2 = d1 / 256.f;
-    const float  m = xb->m;
-    const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = mask0 << 8;
-
-    float4x4 reg_f;
-
-    for (int i = 0; i < 8; i++) {
-        reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
-        reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
-    }
-
-    reg = (type4x4) reg_f;
-}
-
-template <typename type4x4>
-void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
-    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
-    const float d = xb->d;
-    const float md = -16.h * xb->d;
-    const ushort mask = il ? 0x00F0 : 0x000F;
-
-    const uint32_t qh = *((device const uint32_t *)xb->qh);
-
-    const int x_mv = il ? 4 : 0;
-
-    const int gh_mv = il ? 12 : 0;
-    const int gh_bk = il ?  0 : 4;
-
-    float4x4 reg_f;
-
-    for (int i = 0; i < 8; i++) {
-        // extract the 5-th bits for x0 and x1
-        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
-        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
-
-        // combine the 4-bits from qs with the 5th bit
-        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
-        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
-
-        reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
-        reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
-    }
-
-    reg = (type4x4) reg_f;
-}
-
-template <typename type4x4>
-void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
-    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
-    const float d = xb->d;
-    const float m = xb->m;
-    const ushort mask = il ? 0x00F0 : 0x000F;
-
-    const uint32_t qh = *((device const uint32_t *)xb->qh);
-
-    const int x_mv = il ? 4 : 0;
-
-    const int gh_mv = il ? 12 : 0;
-    const int gh_bk = il ?  0 : 4;
-
-    float4x4 reg_f;
-
-    for (int i = 0; i < 8; i++) {
-        // extract the 5-th bits for x0 and x1
-        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
-        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
-
-        // combine the 4-bits from qs with the 5th bit
-        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
-        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
-
-        reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
-        reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
-    }
-
-    reg = (type4x4) reg_f;
-}
-
-template <typename type4x4>
-void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
-    device const int8_t * qs = ((device const int8_t *)xb->qs);
-    const half d = xb->d;
-
-    float4x4 reg_f;
-
-    for (int i = 0; i < 16; i++) {
-        reg_f[i/4][i%4] = (qs[i + 16*il] * d);
-    }
-
-    reg = (type4x4) reg_f;
-}
-
-template <typename type4x4>
-void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
-    const float d = xb->d;
-    const float min = xb->dmin;
-    device const uint8_t * q = (device const uint8_t *)xb->qs;
-    float dl, ml;
-    uint8_t sc = xb->scales[il];
-
-    q = q + 32*(il/8) + 16*(il&1);
-    il = (il/2)%4;
-
-    half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
-    uchar mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
-    dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
-    }
-}
-
-template <typename type4x4>
-void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
-    const half d_all = xb->d;
-    device const uint8_t * q = (device const uint8_t *)xb->qs;
-    device const uint8_t * h = (device const uint8_t *)xb->hmask;
-    device const int8_t * scales = (device const int8_t *)xb->scales;
-
-    q = q + 32 * (il/8) + 16 * (il&1);
-    h = h + 16 * (il&1);
-    uint8_t m = 1 << (il/2);
-    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
-                                 ((il/4)>0 ? 12  : 3);
-    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
-    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
-    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
-                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
-    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
-    const float ml = 4.f * dl;
-
-    il = (il/2) & 3;
-    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
-    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
-    dl *= coef;
-
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
-    }
-}
-
-static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
-    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
-                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
-}
-
-template <typename type4x4>
-void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
-    device const uchar * q = xb->qs;
-
-    short is = (il/4) * 2;
-    q = q + (il/4) * 32 + 16 * (il&1);
-    il = il & 3;
-    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
-    const float d   = il < 2 ? xb->d : xb->d / 16.h;
-    const float min = xb->dmin;
-    const float dl = d * sc[0];
-    const float ml = min * sc[1];
-
-    const ushort mask = il<2 ? 0x0F : 0xF0;
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
-    }
-}
-
-template <typename type4x4>
-void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * q  = xb->qs;
-    device const uint8_t * qh = xb->qh;
-
-    short is = (il/4) * 2;
-    q  = q + 32 * (il/4) + 16 * (il&1);
-    qh = qh + 16 * (il&1);
-    uint8_t ul = 1 << (il/2);
-    il = il & 3;
-    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
-    const float d = il < 2 ? xb->d : xb->d / 16.f;
-    const float min = xb->dmin;
-    const float dl = d * sc[0];
-    const float ml = min * sc[1];
-
-    const ushort mask  = il<2 ? 0x0F : 0xF0;
-    const float qh_val = il<2 ? 16.f : 256.f;
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
-    }
-}
-
-template <typename type4x4>
-void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
-    const half d_all = xb->d;
-    device const uint8_t * ql = (device const uint8_t *)xb->ql;
-    device const uint8_t * qh = (device const uint8_t *)xb->qh;
-    device const int8_t * scales = (device const int8_t *)xb->scales;
-
-    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
-    qh = qh + 32*(il/8) + 16*(il&1);
-    float sc = scales[(il%2) + 2 * ((il/2))];
-    il = (il/2) & 3;
-
-    const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-    const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
-    const float       coef = il>1 ? 1.f/16.f          : 1.f;
-    const float ml = d_all * sc * 32.f;
-    const float dl = d_all * sc * coef;
-    for (int i = 0; i < 16; ++i) {
-        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
-                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
-        reg[i/4][i%4] = dl * q - ml;
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const float d = xb->d;
-    const int ib32 = il/2;
-    il = il%2;
-    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
-    // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
-    device const uint16_t * q2 = xb->qs + 4*ib32;
-    const uint32_t aux32_g = q2[0] | (q2[1] << 16);
-    const uint32_t aux32_s = q2[2] | (q2[3] << 16);
-    thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
-    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
-    constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
-    uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
-    for (int i = 0; i < 8; ++i) {
-        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
-    }
-    grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
-    signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
-    for (int i = 0; i < 8; ++i) {
-        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const float d = xb->d;
-    const int ib32 = il/2;
-    il = il%2;
-    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
-    device const uint16_t * q2 = xb->qs + 4*ib32;
-    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
-    constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
-    uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
-    for (int i = 0; i < 8; ++i) {
-        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
-    }
-    grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
-    signs = ksigns_iq2xs[q2[2*il+1] >> 9];
-    for (int i = 0; i < 8; ++i) {
-        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const float d = xb->d;
-    const int ib32 = il/2;
-    il = il%2;
-    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
-    device const uint8_t * q3 = xb->qs + 8*ib32;
-    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
-    const uint32_t aux32 = gas[0] | (gas[1] << 16);
-    const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
-    constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
-    constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
-    uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
-    for (int i = 0; i < 4; ++i) {
-        reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
-        reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
-    }
-    grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
-    grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
-    signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
-    for (int i = 0; i < 4; ++i) {
-        reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
-        reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const float d = xb->d;
-    const int ib32 = il/2;
-    il = il%2;
-    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
-    device const uint8_t * qs = xb->qs + 8*ib32;
-    device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
-    const uint8_t qh = xb->qh[ib32] >> 4*il;
-    const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
-    constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
-    constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
-    for (int i = 0; i < 4; ++i) {
-        reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
-        reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
-    }
-    grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
-    grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
-    for (int i = 0; i < 4; ++i) {
-        reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
-        reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const float d = xb->d;
-    const int ib32 = il/2;
-    il = il%2;
-    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
-    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
-    device const uint8_t * signs = qs + QK_K/8;
-    const uint8_t qh = xb->qh[ib32] >> 4*il;
-    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
-    constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
-    constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
-    for (int i = 0; i < 8; ++i) {
-        reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
-        reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const int ib32 = il/2;
-    il = il%2;
-    const float d = xb->d;
-    device const uint8_t  * qs = xb->qs + 4*ib32 + 2*il;
-    device const uint16_t * qh = xb->qh;
-    const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
-    const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
-    const uint16_t h = qh[ib32] >> 6*il;
-    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
-    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
-    for (int i = 0; i < 4; ++i) {
-        reg[0][i] = dl * (grid1[i] & 0xf) + ml;
-        reg[1][i] = dl * (grid1[i] >>  4) + ml;
-        reg[2][i] = dl * (grid2[i] & 0xf) + ml;
-        reg[3][i] = dl * (grid2[i] >>  4) + ml;
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const int ib32 = il/2;
-    il = il%2;
-    device const uint16_t * sc = (device const uint16_t *)xb->scales;
-
-    iq1m_scale_t scale;
-    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
-    const float d = scale.f16;
-
-    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
-    device const uint8_t * qh = xb->qh + 2*ib32 + il;
-
-    const float dl  = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
-    const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
-    const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
-    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
-    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
-    for (int i = 0; i < 4; ++i) {
-        reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
-        reg[1][i] = dl * (grid1[i] >>  4) + ml1;
-        reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
-        reg[3][i] = dl * (grid2[i] >>  4) + ml2;
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
-    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
-    const float d = xb->d;
-    uint32_t aux32;
-    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
-    for (int i = 0; i < 4; ++i) {
-        aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
-        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
-        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
-        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
-        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
-    }
-}
-
-template <typename type4x4>
-void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
-    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
-    const int ib32 = il/2;
-    il = il%2;
-    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
-    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
-    const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
-    const float d = (float)xb->d * (ls - 32);
-    uint32_t aux32;
-    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
-    for (int i = 0; i < 4; ++i) {
-        aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
-        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
-        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
-        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
-        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
-    }
-}
-
-enum ggml_sort_order {
-    GGML_SORT_ORDER_ASC,
-    GGML_SORT_ORDER_DESC,
-};
-
-// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across all dims
-// cons: not very efficient
-kernel void kernel_add(
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        constant  int64_t & offs,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
-
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
-
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
-
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i10 = i0 % ne10;
-        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
-    }
-}
-
-kernel void kernel_sub(
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        constant  int64_t & offs,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
-
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
-
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
-
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i10 = i0 % ne10;
-        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
-    }
-}
-
-kernel void kernel_mul(
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
-
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
-
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
-
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i10 = i0 % ne10;
-        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
-    }
-}
-
-kernel void kernel_div(
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
-
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
-
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
-
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i10 = i0 % ne10;
-        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
-    }
-}
-
-template<typename T>
-kernel void kernel_repeat(
-        device const char * src0,
-        device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i3 = tgpig.z;
-    const int64_t i2 = tgpig.y;
-    const int64_t i1 = tgpig.x;
-
-    const int64_t i03 = i3 % ne03;
-    const int64_t i02 = i2 % ne02;
-    const int64_t i01 = i1 % ne01;
-
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
-    device       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1 ;
-
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i00 = i0 % ne00;
-        *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
-    }
-}
-
-typedef decltype(kernel_repeat<float>) kernel_repeat_t;
-
-template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
-template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
-template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
-template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
-
-// assumption: src1 is a row
-// broadcast src1 into src0
-kernel void kernel_add_row(
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
-        constant   uint64_t & nb [[buffer(28)]],
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] + src1[tpig % nb];
-}
-
-kernel void kernel_sub_row(
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
-        constant   uint64_t & nb [[buffer(28)]],
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] - src1[tpig % nb];
-}
-
-kernel void kernel_mul_row(
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
-        constant   uint64_t & nb  [[buffer(28)]],
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig % nb];
-}
-
-kernel void kernel_div_row(
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
-        constant   uint64_t & nb  [[buffer(28)]],
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] / src1[tpig % nb];
-}
-
-kernel void kernel_scale(
-        device const float * src0,
-        device       float * dst,
-        constant     float & scale,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * scale;
-}
-
-kernel void kernel_scale_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        constant     float  & scale,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * scale;
-}
-
-kernel void kernel_clamp(
-        device const float * src0,
-        device       float * dst,
-        constant     float & min,
-        constant     float & max,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
-}
-
-kernel void kernel_relu(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = max(0.0f, src0[tpig]);
-}
-
-kernel void kernel_sigmoid(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
-}
-
-kernel void kernel_tanh(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-    dst[tpig] = precise::tanh(x);
-}
-
-constant float GELU_COEF_A     = 0.044715f;
-constant float GELU_QUICK_COEF = -1.702f;
-constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
-
-kernel void kernel_gelu(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    // BEWARE !!!
-    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
-    // This was observed with Falcon 7B and 40B models
-    //
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_quick(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-
-    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
-
-kernel void kernel_gelu_quick_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
-
-kernel void kernel_silu(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
-}
-
-kernel void kernel_silu_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
-}
-
-kernel void kernel_sqr(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src0[tpig];
-}
-
-kernel void kernel_sqrt(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sqrt(src0[tpig]);
-}
-
-kernel void kernel_sin(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sin(src0[tpig]);
-}
-
-kernel void kernel_cos(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = cos(src0[tpig]);
-}
-
-kernel void kernel_sum_rows(
-        device const float * src0,
-        device       float * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        uint3 tpig[[thread_position_in_grid]]) {
-    int64_t i3 = tpig.z;
-    int64_t i2 = tpig.y;
-    int64_t i1 = tpig.x;
-
-    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
-        return;
-    }
-
-    device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
-    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
-
-    float row_sum = 0;
-
-    for (int64_t i0 = 0; i0 < ne00; i0++) {
-        row_sum += src_row[i0];
-    }
-
-    dst_row[0] = row_sum;
-}
-
-template<typename T>
-kernel void kernel_soft_max(
-        device const  char * src0,
-        device const  char * src1,
-        device        char * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant     float & scale,
-        constant     float & max_bias,
-        constant     float & m0,
-        constant     float & m1,
-        constant  uint32_t & n_head_log2,
-        threadgroup  float * buf [[threadgroup(0)]],
-        uint  tgpig[[threadgroup_position_in_grid]],
-        uint  tpitg[[thread_position_in_threadgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint    ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = (tgpig) / (ne02*ne01);
-    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
-    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
-
-    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*ne00 : nullptr;
-    device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const int64_t h = i02;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = pow(base, exp);
-    }
-
-    // parallel max
-    float lmax = -INFINITY;
-
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
-    }
-
-    // find the max value in the block
-    float max_val = simd_max(lmax);
-    if (ntg > N_SIMDWIDTH) {
-        if (sgitg == 0) {
-            buf[tiisg] = -INFINITY;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        if (tiisg == 0) {
-            buf[sgitg] = max_val;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        max_val = buf[tiisg];
-        max_val = simd_max(max_val);
-    }
-
-    // parallel sum
-    float lsum = 0.0f;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
-        lsum += exp_psrc0;
-        pdst[i00] = exp_psrc0;
-    }
-
-    // This barrier fixes a failing test
-    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
-    threadgroup_barrier(mem_flags::mem_none);
-
-    float sum = simd_sum(lsum);
-
-    if (ntg > N_SIMDWIDTH) {
-        if (sgitg == 0) {
-            buf[tiisg] = 0.0f;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        if (tiisg == 0) {
-            buf[sgitg] = sum;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        sum = buf[tiisg];
-        sum = simd_sum(sum);
-    }
-
-    const float inv_sum = 1.0f/sum;
-
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        pdst[i00] *= inv_sum;
-    }
-}
-
-template<typename T>
-kernel void kernel_soft_max_4(
-        device const  char * src0,
-        device const  char * src1,
-        device        char * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant     float & scale,
-        constant     float & max_bias,
-        constant     float & m0,
-        constant     float & m1,
-        constant  uint32_t & n_head_log2,
-        threadgroup  float * buf [[threadgroup(0)]],
-        uint  tgpig[[threadgroup_position_in_grid]],
-        uint  tpitg[[thread_position_in_threadgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint    ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = (tgpig) / (ne02*ne01);
-    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
-    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
-
-    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
-    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*ne00/4 : nullptr;
-    device       float4 * pdst4 = (device       float4 *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
-
-    float slope = 1.0f;
-
-    if (max_bias > 0.0f) {
-        const int64_t h = i02;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = pow(base, exp);
-    }
-
-    // parallel max
-    float4 lmax4 = -INFINITY;
-
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
-    }
-
-    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
-
-    float max_val = simd_max(lmax);
-    if (ntg > N_SIMDWIDTH) {
-        if (sgitg == 0) {
-            buf[tiisg] = -INFINITY;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        if (tiisg == 0) {
-            buf[sgitg] = max_val;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        max_val = buf[tiisg];
-        max_val = simd_max(max_val);
-    }
-
-    // parallel sum
-    float4 lsum4 = 0.0f;
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
-        lsum4 += exp_psrc4;
-        pdst4[i00] = exp_psrc4;
-    }
-
-    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
-
-    // This barrier fixes a failing test
-    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
-    threadgroup_barrier(mem_flags::mem_none);
-
-    float sum = simd_sum(lsum);
-
-    if (ntg > N_SIMDWIDTH) {
-        if (sgitg == 0) {
-            buf[tiisg] = 0.0f;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        if (tiisg == 0) {
-            buf[sgitg] = sum;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        sum = buf[tiisg];
-        sum = simd_sum(sum);
-    }
-
-    const float inv_sum = 1.0f/sum;
-
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        pdst4[i00] *= inv_sum;
-    }
-}
-
-typedef decltype(kernel_soft_max<float>)    kernel_soft_max_t;
-typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
-
-template [[host_name("kernel_soft_max_f16")]]   kernel kernel_soft_max_t   kernel_soft_max<half>;
-template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kernel_soft_max<float>;
-template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
-template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
-
-kernel void kernel_diag_mask_inf(
-        device const float * src0,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant       int & n_past,
-        uint3 tpig[[thread_position_in_grid]]) {
-    const int64_t i02 = tpig[2];
-    const int64_t i01 = tpig[1];
-    const int64_t i00 = tpig[0];
-
-    if (i00 > n_past + i01) {
-        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
-    } else {
-        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
-    }
-}
-
-kernel void kernel_diag_mask_inf_8(
-        device const float4 * src0,
-        device       float4 * dst,
-        constant    int64_t & ne00,
-        constant    int64_t & ne01,
-        constant        int & n_past,
-        uint3 tpig[[thread_position_in_grid]]) {
-
-    const int64_t i = 2*tpig[0];
-
-    dst[i+0] = src0[i+0];
-    dst[i+1] = src0[i+1];
-    int64_t i4 = 4*i;
-    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
-    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
-    const int64_t i00 = i4;
-    for (int k = 3; k >= 0; --k) {
-        if (i00 + 4 + k <= n_past + i01) {
-            break;
-        }
-        dst[i+1][k] = -INFINITY;
-        if (i00 + k > n_past + i01) {
-            dst[i][k] = -INFINITY;
-        }
-    }
-}
-
-// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
-// TODO: optimize
-kernel void kernel_ssm_conv_f32(
-        device const  void * src0,
-        device const  void * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t ir = tgpig.x;
-    const int64_t i2 = tgpig.y;
-    const int64_t i3 = tgpig.z;
-
-    const int64_t nc  = ne10;
-  //const int64_t ncs = ne00;
-  //const int64_t nr  = ne01;
-  //const int64_t n_t = ne1;
-  //const int64_t n_s = ne2;
-
-    device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
-    device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
-    device       float * x = (device       float *) ((device       char *) dst  + ir*nb0  + i2*nb1  + i3*nb2);
-
-    float sumf = 0.0f;
-
-    for (int64_t i0 = 0; i0 < nc; ++i0) {
-        sumf += s[i0] * c[i0];
-    }
-
-    x[0] = sumf;
-}
-
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
-// TODO: optimize
-kernel void kernel_ssm_scan_f32(
-        device const void * src0,
-        device const void * src1,
-        device const void * src2,
-        device const void * src3,
-        device const void * src4,
-        device const void * src5,
-        device      float * dst,
-        constant  int64_t & d_state,
-        constant  int64_t & d_inner,
-        constant  int64_t & n_seq_tokens,
-        constant  int64_t & n_seqs,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant uint64_t & nb20,
-        constant uint64_t & nb21,
-        constant uint64_t & nb22,
-        constant uint64_t & nb30,
-        constant uint64_t & nb31,
-        constant uint64_t & nb40,
-        constant uint64_t & nb41,
-        constant uint64_t & nb42,
-        constant uint64_t & nb50,
-        constant uint64_t & nb51,
-        constant uint64_t & nb52,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t ir = tgpig.x;
-    const int64_t i3 = tgpig.y;
-
-    const int64_t nc  = d_state;
-  //const int64_t nr  = d_inner;
-    const int64_t n_t = n_seq_tokens;
-  //const int64_t n_s = n_seqs;
-
-    for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
-        device const float * x  = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*nb31);
-        device const float * B  = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
-        device const float * C  = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
-        device       float * y  = (device       float *) ((device       char *) dst  + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
-        device       float * s  = (device       float *) ((device       char *) dst  + ir*nb01 + i3*nb02 +    nb13);
-
-        if (i2 > 0) {
-            s0 = s;
-        }
-
-        // i1 == 0
-        float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
-        float x_dt = x[0] * dt_soft_plus;
-        float sumf = 0.0f;
-
-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            int64_t i = i0;
-            float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
-        }
-
-        y[0] = sumf;
-    }
-}
-
-kernel void kernel_norm(
-        device const  void * src0,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant     float & eps,
-        threadgroup float  * sum [[threadgroup(0)]],
-        uint tgpig[[threadgroup_position_in_grid]],
-        uint tpitg[[thread_position_in_threadgroup]],
-        uint   ntg[[threads_per_threadgroup]]) {
-    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
-    // MEAN
-    // parallel sum
-    sum[tpitg] = 0.0f;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        sum[tpitg] += x[i00];
-    }
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg/2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            sum[tpitg] += sum[tpitg + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-    const float mean  = sum[0] / ne00;
-
-    // recenter and VARIANCE
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    device float * y = dst + tgpig*ne00;
-    sum[tpitg] = 0.0f;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        y[i00] = x[i00] - mean;
-        sum[tpitg] += y[i00] * y[i00];
-    }
-
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg/2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            sum[tpitg] += sum[tpitg + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-    const float variance = sum[0] / ne00;
-
-    const float scale = 1.0f/sqrt(variance + eps);
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        y[i00] = y[i00] * scale;
-    }
-}
-
-kernel void kernel_rms_norm(
-        device const  void * src0,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant     float & eps,
-        threadgroup float  * buf [[threadgroup(0)]],
-        uint tgpig[[threadgroup_position_in_grid]],
-        uint tpitg[[thread_position_in_threadgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint   ntg[[threads_per_threadgroup]]) {
-    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
-
-    float4 sumf = 0;
-    float all_sum = 0;
-
-    // parallel sum
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        sumf += x[i00] * x[i00];
-    }
-    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
-    all_sum = simd_sum(all_sum);
-    if (ntg > N_SIMDWIDTH) {
-        if (sgitg == 0) {
-            buf[tiisg] = 0.0f;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        if (tiisg == 0) {
-            buf[sgitg] = all_sum;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        all_sum = buf[tiisg];
-        all_sum = simd_sum(all_sum);
-    }
-
-    const float mean  = all_sum/ne00;
-    const float scale = 1.0f/sqrt(mean + eps);
-
-    device float4 * y = (device float4 *) (dst + tgpig*ne00);
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        y[i00] = x[i00] * scale;
-    }
-}
-
-kernel void kernel_group_norm(
-        device const float * src0,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int32_t & n_groups,
-        constant     float & eps,
-        threadgroup float  * buf [[threadgroup(0)]],
-        uint tgpig[[threadgroup_position_in_grid]],
-        uint tpitg[[thread_position_in_threadgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint   ntg[[threads_per_threadgroup]]) {
-    const int64_t ne = ne00*ne01*ne02;
-    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
-
-    int start = tgpig * gs;
-    int end   = start + gs;
-
-    start += tpitg;
-
-    if (end >= ne) {
-        end = ne;
-    }
-
-    float tmp = 0.0f; // partial sum for thread in warp
-
-    for (int j = start; j < end; j += ntg) {
-        tmp += src0[j];
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    tmp = simd_sum(tmp);
-    if (ntg > N_SIMDWIDTH) {
-        if (sgitg == 0) {
-            buf[tiisg] = 0.0f;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        if (tiisg == 0) {
-            buf[sgitg] = tmp;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        tmp = buf[tiisg];
-        tmp = simd_sum(tmp);
-    }
-
-    const float mean = tmp / gs;
-    tmp = 0.0f;
-
-    for (int j = start; j < end; j += ntg) {
-        float xi = src0[j] - mean;
-        dst[j] = xi;
-        tmp += xi * xi;
-    }
-
-    tmp = simd_sum(tmp);
-    if (ntg > N_SIMDWIDTH) {
-        if (sgitg == 0) {
-            buf[tiisg] = 0.0f;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        if (tiisg == 0) {
-            buf[sgitg] = tmp;
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        tmp = buf[tiisg];
-        tmp = simd_sum(tmp);
-    }
-
-    const float variance = tmp / gs;
-    const float scale = 1.0f/sqrt(variance + eps);
-    for (int j = start; j < end; j += ntg) {
-        dst[j] *= scale;
-    }
-}
-
-// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
-// il indicates where the q4 quants begin (0 or QK4_0/4)
-// we assume that the yl's have been multiplied with the appropriate scale factor
-// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
-inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
-    float d = qb_curr->d;
-
-    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
-
-    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
-
-    for (int i = 0; i < 8; i += 2) {
-        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
-        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
-        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
-    }
-
-    return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
-}
-
-// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
-// il indicates where the q4 quants begin (0 or QK4_0/4)
-// we assume that the yl's have been multiplied with the appropriate scale factor
-// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
-inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
-    float d = qb_curr->d;
-    float m = qb_curr->m;
-
-    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
-
-    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
-
-    for (int i = 0; i < 8; i+=2) {
-        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
-        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
-        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
-    }
-
-    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
-}
-
-// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
-// il indicates where the q5 quants begin (0 or QK5_0/4)
-// we assume that the yl's have been multiplied with the appropriate scale factor
-// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
-inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
-    float d = qb_curr->d;
-
-    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
-
-    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 3 + il/2);
-           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
-
-    for (int i = 0; i < 8; i+=2) {
-        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));
-        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
-        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
-        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
-    }
-
-    return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
-}
-
-// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
-// il indicates where the q5 quants begin (0 or QK5_1/4)
-// we assume that the yl's have been multiplied with the appropriate scale factor
-// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
-inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
-    float d = qb_curr->d;
-    float m = qb_curr->m;
-
-    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
-
-    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 4 + il/2);
-           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
-
-    for (int i = 0; i < 8; i+=2) {
-        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));
-        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
-        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
-        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
-    }
-
-    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
-}
-
-// putting them in the kernel cause a significant performance penalty
-#define N_DST 4        // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
-//Note: This is a template, but strictly speaking it only applies to
-//      quantizations where the block size is 32. It also does not
-//      guard against the number of rows not being divisible by
-//      N_DST, so this is another explicit assumption of the implementation.
-template<typename block_q_type, int nr, int nsg, int nw>
-void mul_vec_q_n_f32_impl(
-        device const void  * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                     uint3   tgpig,
-                     uint    tiisg,
-                     uint    sgitg) {
-    const int nb = ne00/QK4_0;
-
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-
-    const int first_row = (r0 * nsg + sgitg) * nr;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-  //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-  //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0);
-    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
-
-    // pointers to src0 rows
-    device const block_q_type * ax[nr];
-    for (int row = 0; row < nr; ++row) {
-        const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-
-        ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
-    }
-
-    float yl[16]; // src1 vector cache
-    float sumf[nr] = {0.f};
-
-    const int ix = (tiisg/2);
-    const int il = (tiisg%2)*8;
-
-    device const float * yb = y + ix * QK4_0 + il;
-
-    // each thread in a SIMD group deals with half a block.
-    for (int ib = ix; ib < nb; ib += nw/2) {
-        float sumy[2] = { 0.f, 0.f };
-
-#pragma unroll
-        for (int i = 0; i < 8; i += 2) {
-            sumy[0]  += yb[i +  0] + yb[i +  1];
-            yl[i + 0] = yb[i +  0];
-            yl[i + 1] = yb[i +  1]/256.f;
-
-            sumy[1]  += yb[i + 16] + yb[i + 17];
-            yl[i + 8] = yb[i + 16]/16.f;
-            yl[i + 9] = yb[i + 17]/4096.f;
-        }
-
-#pragma unroll
-        for (int row = 0; row < nr; row++) {
-            sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
-        }
-
-        yb += QK4_0 * 16;
-    }
-
-    for (int row = 0; row < nr; ++row) {
-        const float tot = simd_sum(sumf[row]);
-        if (tiisg == 0 && first_row + row < ne01) {
-            dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
-        }
-    }
-}
-
-kernel void kernel_mul_mv_q4_0_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
-}
-
-kernel void kernel_mul_mv_q4_1_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
-}
-
-kernel void kernel_mul_mv_q5_0_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
-}
-
-kernel void kernel_mul_mv_q5_1_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
-}
-
-
-#define NB_Q8_0 8
-
-void kernel_mul_mv_q8_0_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-    const int nr  = N_DST;
-    const int nsg = N_SIMDGROUP;
-    const int nw  = N_SIMDWIDTH;
-
-    const int nb = ne00/QK8_0;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-
-    const int first_row = (r0 * nsg + sgitg) * nr;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-  //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-  //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0);
-    device const float      * y = (device const float      *) ((device char *) src1 + offset1);
-
-    // pointers to src0 rows
-    device const block_q8_0 * ax[nr];
-    for (int row = 0; row < nr; ++row) {
-        const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-
-        ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
-    }
-
-    float yl[NB_Q8_0];
-    float sumf[nr]={0.f};
-
-    const int ix = tiisg/4;
-    const int il = tiisg%4;
-
-    device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
-
-    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
-    for (int ib = ix; ib < nb; ib += nw/4) {
-        for (int i = 0; i < NB_Q8_0; ++i) {
-            yl[i] = yb[i];
-        }
-
-        for (int row = 0; row < nr; row++) {
-            device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il;
-            float sumq = 0.f;
-            for (int iq = 0; iq < NB_Q8_0; ++iq) {
-                sumq += qs[iq] * yl[iq];
-            }
-            sumf[row] += sumq*ax[row][ib].d;
-        }
-
-        yb += NB_Q8_0 * nw;
-    }
-
-    for (int row = 0; row < nr; ++row) {
-        const float tot = simd_sum(sumf[row]);
-        if (tiisg == 0 && first_row + row < ne01) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
-        }
-    }
-}
-
-[[host_name("kernel_mul_mv_q8_0_f32")]]
-kernel void kernel_mul_mv_q8_0_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
-}
-
-#define N_MV_T_T 4
-
-template<typename T0, typename T04, typename T1, typename T14>
-void kernel_mul_mv_impl(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb00,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne11,
-                   int64_t   ne12,
-                  uint64_t   nb10,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-                   uint3     tgpig,
-                   uint      tiisg) {
-    const int64_t r0 = tgpig.x;
-    const int64_t rb = tgpig.y*N_MV_T_T;
-    const int64_t im = tgpig.z;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-
-    device const T0 * x = (device const T0 *) (src0 + offset0);
-
-    if (ne00 < 128) {
-        for (int row = 0; row < N_MV_T_T; ++row) {
-            int r1 = rb + row;
-            if (r1 >= ne11) {
-                break;
-            }
-
-            const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-            device const T1 * y = (device const T1 *) (src1 + offset1);
-
-            float sumf = 0;
-            for (int i = tiisg; i < ne00; i += 32) {
-                sumf += (T0) x[i] * (T1) y[i];
-            }
-
-            float all_sum = simd_sum(sumf);
-            if (tiisg == 0) {
-                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
-            }
-        }
-    } else {
-        device const T04 * x4 = (device const T04 *) x;
-        for (int row = 0; row < N_MV_T_T; ++row) {
-            int r1 = rb + row;
-            if (r1 >= ne11) {
-                break;
-            }
-
-            const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-            device const T1  * y  = (device const T1  *) (src1 + offset1);
-            device const T14 * y4 = (device const T14 *) y;
-
-            float sumf = 0;
-            for (int i = tiisg; i < ne00/4; i += 32) {
-                for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
-            }
-
-            float all_sum = simd_sum(sumf);
-            if (tiisg == 0) {
-                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
-                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
-            }
-        }
-    }
-}
-
-template<typename T0, typename T04, typename T1, typename T14>
-kernel void kernel_mul_mv(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]]) {
-    kernel_mul_mv_impl<T0, T04, T1, T14>(
-        src0,
-        src1,
-        dst,
-        ne00,
-        ne01,
-        ne02,
-        nb00,
-        nb01,
-        nb02,
-        nb03,
-        ne10,
-        ne11,
-        ne12,
-        nb10,
-        nb11,
-        nb12,
-        nb13,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        tgpig,
-        tiisg);
-}
-
-typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
-
-template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t kernel_mul_mv<float,  float4,  float,  float4>;
-template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   float,  float4>;
-template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   half,   half4>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float,  float4>;
-template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
-#endif
-
-template<typename T, typename T4>
-kernel void kernel_mul_mv_1row(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]]) {
-
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-    const int64_t im = tgpig.z;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const T     * x = (device const T     *) (src0 + offset0);
-    device const float * y = (device const float *) (src1 + offset1);
-
-    float sumf = 0;
-    if (ne00 < 128) {
-        for (int i = tiisg; i < ne00; i += 32) {
-            sumf += (float) x[i] * (float) y[i];
-        }
-        float all_sum = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
-        }
-    } else {
-        device const T4     * x4 = (device const T4     *) x;
-        device const float4 * y4 = (device const float4 *) y;
-
-        for (int i = tiisg; i < ne00/4; i += 32) {
-            for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
-        }
-
-        float all_sum = simd_sum(sumf);
-
-        if (tiisg == 0) {
-            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
-            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
-        }
-    }
-}
-
-typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
-
-template [[host_name("kernel_mul_mv_f16_f32_1row")]]  kernel mul_mv_1row_t kernel_mul_mv_1row<half,   half4>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
-#endif
-
-// Assumes row size (ne00) is a multiple of 4
-template<typename T, typename T4>
-kernel void kernel_mul_mv_l4(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
-
-    const int nrows = ne11;
-    const int64_t r0 = tgpig.x;
-    const int64_t im = tgpig.z;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-
-    device const T4 * x4 = (device const T4 *) (src0 + offset0);
-
-    for (int r1 = 0; r1 < nrows; ++r1) {
-        const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-        device const float4 * y4 = (device const float4 *) (src1 + offset1);
-
-        float sumf = 0;
-        for (int i = tiisg; i < ne00/4; i += 32) {
-            for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
-        }
-
-        float all_sum = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
-        }
-    }
-}
-
-typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
-
-template [[host_name("kernel_mul_mv_f16_f32_l4")]]  kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
-#endif
-
-static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / max(0.001f, high - low);
-    return 1.0f - min(1.0f, max(0.0f, y));
-}
-
-// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
-// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-static void rope_yarn(
-    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    thread float * cos_theta, thread float * sin_theta) {
-    // Get n-d rotational scaling corrected for extrapolation
-    float theta_interp = freq_scale * theta_extrap;
-    float theta = theta_interp;
-    if (ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
-        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
-        // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
-    }
-    *cos_theta = cos(theta) * mscale;
-    *sin_theta = sin(theta) * mscale;
-}
-
-// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
-// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
-    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
-}
-
-static void rope_yarn_corr_dims(
-    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
-) {
-    // start and end correction dims
-    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
-    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
-}
-
-template<typename T>
-kernel void kernel_rope_norm(
-        device const    void * src0,
-        device const int32_t * src1,
-        device const   float * src2,
-        device         float * dst,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant     int64_t & ne03,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant    uint64_t & nb03,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant     int64_t & ne2,
-        constant     int64_t & ne3,
-        constant    uint64_t & nb0,
-        constant    uint64_t & nb1,
-        constant    uint64_t & nb2,
-        constant    uint64_t & nb3,
-        constant         int & n_past,
-        constant         int & n_dims,
-        constant         int & n_ctx_orig,
-        constant       float & freq_base,
-        constant       float & freq_scale,
-        constant       float & ext_factor,
-        constant       float & attn_factor,
-        constant       float & beta_fast,
-        constant       float & beta_slow,
-        uint  tiitg[[thread_index_in_threadgroup]],
-        uint3 tptg[[threads_per_threadgroup]],
-        uint3 tgpig[[threadgroup_position_in_grid]]) {
-    const int64_t i3 = tgpig[2];
-    const int64_t i2 = tgpig[1];
-    const int64_t i1 = tgpig[0];
-
-    float corr_dims[2];
-    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
-    device const int32_t * pos = src1;
-
-    const float theta_base = (float) pos[i2];
-    const float inv_ndims = -1.f/n_dims;
-
-    float cos_theta;
-    float sin_theta;
-
-    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
-        if (i0 < n_dims) {
-            const int64_t ic = i0/2;
-
-            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
-
-            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
-
-            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-            const float x0 = src[0];
-            const float x1 = src[1];
-
-            dst_data[0] = x0*cos_theta - x1*sin_theta;
-            dst_data[1] = x0*sin_theta + x1*cos_theta;
-        } else {
-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-            dst_data[0] = src[0];
-            dst_data[1] = src[1];
-        }
-    }
-}
-
-template<typename T>
-kernel void kernel_rope_neox(
-        device const    void * src0,
-        device const int32_t * src1,
-        device const   float * src2,
-        device         float * dst,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant     int64_t & ne03,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant    uint64_t & nb03,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant     int64_t & ne2,
-        constant     int64_t & ne3,
-        constant    uint64_t & nb0,
-        constant    uint64_t & nb1,
-        constant    uint64_t & nb2,
-        constant    uint64_t & nb3,
-        constant         int & n_past,
-        constant         int & n_dims,
-        constant         int & n_ctx_orig,
-        constant       float & freq_base,
-        constant       float & freq_scale,
-        constant       float & ext_factor,
-        constant       float & attn_factor,
-        constant       float & beta_fast,
-        constant       float & beta_slow,
-        uint  tiitg[[thread_index_in_threadgroup]],
-        uint3 tptg[[threads_per_threadgroup]],
-        uint3 tgpig[[threadgroup_position_in_grid]]) {
-    const int64_t i3 = tgpig[2];
-    const int64_t i2 = tgpig[1];
-    const int64_t i1 = tgpig[0];
-
-    float corr_dims[2];
-    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
-    device const int32_t * pos = src1;
-
-    const float theta_base = (float) pos[i2];
-    const float inv_ndims = -1.f/n_dims;
-
-    float cos_theta;
-    float sin_theta;
-
-    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
-        if (i0 < n_dims) {
-            const int64_t ic = i0/2;
-
-            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
-
-            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
-
-            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-            const float x0 = src[0];
-            const float x1 = src[n_dims/2];
-
-            dst_data[0]        = x0*cos_theta - x1*sin_theta;
-            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-        } else {
-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-            dst_data[0] = src[0];
-            dst_data[1] = src[1];
-        }
-    }
-}
-
-typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
-typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
-
-template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
-template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
-
-template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
-template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
-
-typedef void (im2col_t)(
-        device const float * x,
-        device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3  tgpg[[threadgroups_per_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]);
-
-template <typename T>
-kernel void kernel_im2col(
-        device const float * x,
-        device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3  tgpg[[threadgroups_per_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
-    const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
-
-    const int32_t offset_dst =
-        (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
-        (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
-
-    device T * pdst = (device T *) (dst);
-
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-        pdst[offset_dst] = 0.0f;
-    } else {
-        const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
-        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
-    }
-}
-
-template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
-template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
-
-typedef void (im2col_ext_t)(
-        device const float * x,
-        device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
-        constant   int32_t & N,
-        constant   int32_t & KH,
-        constant   int32_t & KW,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3  tgpg[[threadgroups_per_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]);
-
-template <typename T>
-kernel void kernel_im2col_ext(
-        device const float * x,
-        device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
-        constant   int32_t & N,
-        constant   int32_t & KH,
-        constant   int32_t & KW,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
-    const int32_t KHW = KH * KW;             // KHW == ntg[1] * ntg[2], KW == ntg[2]
-
-    const int32_t d = tgpig[0] / CHW;
-    const int32_t chw = tgpig[0] % CHW;
-    const int32_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
-    const int32_t HW = tgpig[0] % KHW;
-
-    const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
-    if (tpitg_0 >= N) {
-        return;
-    }
-
-    const int32_t tpitg_1 = HW / KW;
-    const int32_t tpitg_2 = HW % KW;
-
-    const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
-    const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
-
-    const int32_t offset_dst =
-        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
-        (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
-
-    device T * pdst = (device T *) (dst);
-
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-        pdst[offset_dst] = 0.0f;
-    } else {
-        const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
-        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
-    }
-}
-
-template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
-template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
-
-kernel void kernel_upscale_f32(
-    device  const char * src0,
-    device        char * dst,
-    constant   int64_t & ne00,
-    constant   int64_t & ne01,
-    constant   int64_t & ne02,
-    constant   int64_t & ne03,
-    constant  uint64_t & nb00,
-    constant  uint64_t & nb01,
-    constant  uint64_t & nb02,
-    constant  uint64_t & nb03,
-    constant   int64_t & ne0,
-    constant   int64_t & ne1,
-    constant   int64_t & ne2,
-    constant   int64_t & ne3,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
-    constant  uint64_t & nb2,
-    constant  uint64_t & nb3,
-    constant     float & sf0,
-    constant     float & sf1,
-    constant     float & sf2,
-    constant     float & sf3,
-    uint3 tgpig[[threadgroup_position_in_grid]],
-    uint3 tpitg[[thread_position_in_threadgroup]],
-    uint3   ntg[[threads_per_threadgroup]]) {
-
-    const int64_t i3 = tgpig.z;
-    const int64_t i2 = tgpig.y;
-    const int64_t i1 = tgpig.x;
-
-    const int64_t i03 = i3/sf3;
-    const int64_t i02 = i2/sf2;
-    const int64_t i01 = i1/sf1;
-
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int64_t i00 = i0/sf0;
-
-        device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-        device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1  +  i0*nb0);
-
-        dst_ptr[0] = src0_ptr[0];
-    }
-}
-
-kernel void kernel_pad_f32(
-    device  const char * src0,
-    device        char * dst,
-    constant   int64_t & ne00,
-    constant   int64_t & ne01,
-    constant   int64_t & ne02,
-    constant   int64_t & ne03,
-    constant  uint64_t & nb00,
-    constant  uint64_t & nb01,
-    constant  uint64_t & nb02,
-    constant  uint64_t & nb03,
-    constant   int64_t & ne0,
-    constant   int64_t & ne1,
-    constant   int64_t & ne2,
-    constant   int64_t & ne3,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
-    constant  uint64_t & nb2,
-    constant  uint64_t & nb3,
-    uint3 tgpig[[threadgroup_position_in_grid]],
-    uint3 tpitg[[thread_position_in_threadgroup]],
-    uint3   ntg[[threads_per_threadgroup]]) {
-
-    const int64_t i3 = tgpig.z;
-    const int64_t i2 = tgpig.y;
-    const int64_t i1 = tgpig.x;
-
-    const int64_t i03 = i3;
-    const int64_t i02 = i2;
-    const int64_t i01 = i1;
-
-    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
-    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
-
-    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
-        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-            if (i0 < ne00) {
-                dst_ptr[i0] = src0_ptr[i0];
-            } else {
-                dst_ptr[i0] = 0.0f;
-            }
-        }
-
-        return;
-    }
-
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        dst_ptr[i0] = 0.0f;
-    }
-}
-
-kernel void kernel_arange_f32(
-    device        char * dst,
-    constant   int64_t & ne0,
-    constant   float   & start,
-    constant   float   & step,
-    uint3 tgpig[[threadgroup_position_in_grid]],
-    uint3 tpitg[[thread_position_in_threadgroup]],
-    uint3   ntg[[threads_per_threadgroup]]) {
-
-    device float * dst_ptr = (device float *) dst;
-
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        dst_ptr[i0] = start + step * i0;
-    }
-}
-
-kernel void kernel_timestep_embedding_f32(
-    device  const char * src0,
-    device        char * dst,
-    constant  uint64_t & nb1,
-    constant  int      & dim,
-    constant  int      & max_period,
-    uint3 tgpig[[threadgroup_position_in_grid]],
-    uint3 tpitg[[thread_position_in_threadgroup]],
-    uint3   ntg[[threads_per_threadgroup]]) {
-
-    int i = tgpig.x;
-    device float * embed_data = (device float *)(dst +  i*nb1);
-
-    int half_ = dim / 2;
-    for (int j = tpitg.x; j < half_; j += ntg.x) {
-        float timestep = ((device float *)src0)[i];
-        float freq = (float)exp(-log((float)max_period) * j / half_);
-        float arg = timestep * freq;
-        embed_data[j        ] = cos(arg);
-        embed_data[j + half_] = sin(arg);
-    }
-
-    if (dim % 2 != 0 && tpitg.x == 0) {
-        embed_data[dim] = 0.f;
-    }
-}
-
-// bitonic sort implementation following the CUDA kernels as reference
-typedef void (argsort_t)(
-        device const float  * x,
-        device     int32_t  * dst,
-        constant   int64_t  & ncols,
-        constant   int64_t  & ncols_pad,
-        threadgroup int32_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]]);
-
-template<ggml_sort_order order>
-kernel void kernel_argsort_f32_i32(
-        device const float   * x,
-        device       int32_t * dst,
-        constant     int64_t & ncols,
-        constant     int64_t & ncols_pad,
-        threadgroup int32_t  * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]]) {
-    // bitonic sort
-    int col = tpitg[0];
-    int row = tgpig[1];
-
-    if (col >= ncols_pad) return;
-
-    device const float   * x_row   = x + row * ncols;
-    threadgroup int32_t  * dst_row = shared_values;
-
-    // initialize indices
-    dst_row[col] = col;
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    for (int k = 2; k <= ncols_pad; k *= 2) {
-        for (int j = k / 2; j > 0; j /= 2) {
-            int ixj = col ^ j;
-            if (ixj > col) {
-                if ((col & k) == 0) {
-                    if (dst_row[col] >= ncols ||
-                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
-                    ) {
-                        SWAP(dst_row[col], dst_row[ixj]);
-                    }
-                } else {
-                    if (dst_row[ixj] >= ncols ||
-                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
-                    ) {
-                        SWAP(dst_row[col], dst_row[ixj]);
-                    }
-                }
-            }
-            threadgroup_barrier(mem_flags::mem_threadgroup);
-        }
-    }
-
-    // copy the result to dst without the padding
-    if (col < ncols) {
-        dst[row * ncols + col] = dst_row[col];
-    }
-}
-
-template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
-template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
-
-kernel void kernel_leaky_relu_f32(
-        device const float * src0,
-        device       float * dst,
-        constant     float & slope,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
-}
-
-// ref: https://arxiv.org/pdf/2307.08691.pdf
-template<
-    typename q_t,     // query types in shared memory
-    typename q4_t,
-    typename q8x8_t,
-    typename k_t,     // key types in shared memory
-    typename k4x4_t,
-    typename k8x8_t,
-    typename v_t,     // value types in shared memory
-    typename v4x4_t,
-    typename v8x8_t,
-    typename qk_t,    // Q*K types
-    typename qk8x8_t,
-    typename s_t,     // soft-max types
-    typename s8x8_t,
-    typename o_t,     // attention accumulation types
-    typename o4_t,
-    typename o8x8_t,
-    typename kd4x4_t, // key type in device memory
-    short nl_k,
-    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
-    typename vd4x4_t, // key type in device memory
-    short nl_v,
-    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
-    short D,         // head size
-    short Q  = 8,    // queries per threadgroup
-    short KV = 8,    // key/value processed per each simdgroup
-    short C  = 32>   // cache items per threadgroup
-kernel void kernel_flash_attn_ext(
-        device const  char * q,
-        device const  char * k,
-        device const  char * v,
-        device const  char * mask,
-        device       float * dst,
-        constant   int32_t & ne01,
-        constant   int32_t & ne02,
-        constant   int32_t & ne03,
-        constant  uint32_t & nb01,
-        constant  uint32_t & nb02,
-        constant  uint32_t & nb03,
-        constant   int32_t & ne11,
-        constant   int32_t & ne_12_2, // assume K and V are same shape
-        constant   int32_t & ne_12_3,
-        constant  uint32_t & nb_12_1,
-        constant  uint32_t & nb_12_2,
-        constant  uint32_t & nb_12_3,
-        constant  uint32_t & nb31,
-        constant   int32_t & ne1,
-        constant   int32_t & ne2,
-        constant     float & scale,
-        constant     float & max_bias,
-        constant     float & m0,
-        constant     float & m1,
-        constant  uint16_t & n_head_log2,
-        constant     float & logit_softcap,
-        threadgroup   half * shared [[threadgroup(0)]],
-        ushort3  tgpig[[threadgroup_position_in_grid]],
-        ushort3    ntg[[threads_per_threadgroup]],
-        ushort   tiisg[[thread_index_in_simdgroup]],
-        ushort   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const short nsg = ntg.y; // number of simdgroups
-
-    const int iq3 = tgpig[2];
-    const int iq2 = tgpig[1];
-    const int iq1 = tgpig[0]*Q;
-
-    const short D4  = D/4;
-    const short D8  = D/8;
-    const short D16 = D/16;
-    const short NW  = N_SIMDWIDTH;
-    const short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
-
-    const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
-    const short T  = D + 2*TS; // shared memory size per query in (half)
-
-    threadgroup q_t  * sq  = (threadgroup q_t  *) (shared +              0*D); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared +              0*D); // same as above but in q4_t
-    threadgroup o_t  * so  = (threadgroup o_t  *) (shared +              0*D); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shared +              0*D); // same as above but in o4_t
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
-
-    threadgroup k_t    * sk    = (threadgroup k_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
-    threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
-
-    threadgroup v_t    * sv    = (threadgroup v_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
-    threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
-
-    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    o8x8_t lo[D8];
-
-    // load heads from Q to shared memory
-    for (short j = sgitg; j < Q; j += nsg) {
-        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
-
-        for (short i = tiisg; i < D4; i += NW) {
-            if (iq1 + j < ne01) {
-                sq4[j*D4 + i] = (q4_t) q4[i];
-            } else {
-                sq4[j*D4 + i] = (q4_t) 0.0f;
-            }
-        }
-    }
-
-    // zero out lo
-    for (short i = 0; i < D8; ++i) {
-        lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
-    }
-
-    // zero out shared memory SH
-    for (short j = 0; j < Q; ++j) {
-        for (short i = tiisg; i < SH; i += NW) {
-            ss[j*TS + i] = 0.0f;
-        }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    {
-        half S[Q] = { [0 ... Q-1] = 0.0f };
-        half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
-
-        // thread indices inside the simdgroup
-        // TODO: see if we can utilize quad-group functions for better performance
-        //       https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3)
-        const short tx = tiisg%4;
-        const short ty = tiisg/4;
-
-        // broadcast kv
-        //const short rk2 = ne02/ne12;
-        //const short rk3 = ne03/ne13;
-
-        const short ikv2 = iq2/(ne02/ne_12_2);
-        const short ikv3 = iq3/(ne03/ne_12_3);
-
-        // load the queries from shared memory into local memory
-        q8x8_t mq[D8];
-
-        for (short i = 0; i < D8; ++i) {
-            simdgroup_load(mq[i], sq + i*8, D);
-        }
-
-        const bool has_mask = mask != q;
-
-        half slope = 1.0f;
-
-        // ALiBi
-        if (max_bias > 0.0f) {
-            const short h = iq2;
-
-            const half  base = h < n_head_log2 ? m0 : m1;
-            const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-            slope = pow(base, exph);
-        }
-
-        // loop over the KV cache
-        // each simdgroup handles blocks of Q rows and C columns
-        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
-            const int ic = ic0 + C*sgitg;
-            if (ic >= ne11) {
-                break;
-            }
-
-            if (has_mask) {
-                // used to detect blocks full of -INF
-                half smax = -INFINITY;
-
-                // load the mask in shared memory
-                #pragma unroll(Q)
-                for (short j = 0; j < Q; ++j) {
-                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
-
-                    const half m = pm[ic + tiisg];
-
-                    ss[j*TS + C + tiisg] = m;
-                    smax = max(smax, m);
-                }
-
-                smax = simd_max(smax);
-
-                if (smax == -INFINITY) {
-                    continue;
-                }
-            }
-
-            // Q*K^T
-            {
-                for (short cc = 0; cc < C/8; ++cc) {
-                    qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
-
-                    // this is compile-time check, so it does not have runtime overhead
-                    if (is_same<kd4x4_t, k4x4_t>::value) {
-                        // we can read directly from global memory
-                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
-
-                        #pragma unroll(D8)
-                        for (short i = 0; i < D8; ++i) {
-                            k8x8_t mk;
-                            simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
-
-                            simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
-                        }
-                    } else {
-                        for (short ii = 0; ii < D16; ii += 4) {
-                            device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
-
-                            if (D16%4 == 0) {
-                                // the head is evenly divisible by 4*16 = 64, so no need for bound checks
-                                {
-                                    k4x4_t tmp;
-                                    deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
-                                    sk4x4[4*ty + tx] = tmp;
-                                }
-
-                                simdgroup_barrier(mem_flags::mem_threadgroup);
-
-                                #pragma unroll(4)
-                                for (short k = 0; k < 4; ++k) {
-                                    k8x8_t mk;
-
-                                    simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
-
-                                    simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
-                                }
-                            } else {
-                                if (ii + tx < D16) {
-                                    k4x4_t tmp;
-                                    deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
-                                    sk4x4[4*ty + tx] = tmp;
-                                }
-
-                                simdgroup_barrier(mem_flags::mem_threadgroup);
-
-                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
-                                    k8x8_t mk;
-
-                                    simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
-
-                                    simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
-                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
-                                }
-                            }
-                        }
-                    }
-
-                    // cast qk_t -> s_t
-                    //s8x8_t mqks(1.0f);
-                    //simdgroup_multiply(mqks, mqk, mqks);
-                    //simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
-
-                    simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
-                }
-            }
-
-            // online softmax
-            {
-                for (ushort j = 0; j < Q; ++j) {
-                    const half m = M[j];
-
-                    // scale and apply the logitcap / mask
-                    half s = ss[j*TS + tiisg]*scale;
-
-                    if (logit_softcap != 0.0f) {
-                        s = logit_softcap*precise::tanh(s);
-                    }
-
-                    // mqk = mqk + mask*slope
-                    s += slope*ss[j*TS + C + tiisg];
-
-                    M[j] = simd_max(max(M[j], s));
-
-                    const half ms = exp(m - M[j]);
-                    const half vs = exp(s - M[j]);
-
-                    S[j] = S[j]*ms + simd_sum(vs);
-
-                    // the P matrix from the paper (Q rows, C columns)
-                    ss[j*TS + tiisg] = vs;
-
-                    // create a QxQ diagonal matrix for rescaling the output
-                    if (tiisg == j) {
-                        ss[j*TS + 2*C + j] = ms;
-                    }
-                }
-            }
-
-            // O = diag(ms)*O
-            {
-                s8x8_t mm;
-                simdgroup_load(mm, ss + 2*C, TS, 0, false);
-
-                #pragma unroll(D8)
-                for (short i = 0; i < D8; ++i) {
-                    simdgroup_multiply(lo[i], mm, lo[i]);
-                }
-            }
-
-            // O = O + (Q*K^T)*V
-            {
-                for (short cc = 0; cc < C/8; ++cc) {
-                    s8x8_t ms;
-                    simdgroup_load(ms, ss + 8*cc, TS, 0, false);
-
-                    if (is_same<vd4x4_t, v4x4_t>::value) {
-                        // we can read directly from global memory
-                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
-
-                        #pragma unroll(D8)
-                        for (short i = 0; i < D8; ++i) {
-                            v8x8_t mv;
-                            simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
-
-                            simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
-                        }
-                    } else {
-                        for (short ii = 0; ii < D16; ii += 4) {
-                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
-
-                            if (D16%4 == 0) {
-                                // no need for bound checks
-                                {
-                                    v4x4_t tmp;
-                                    deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
-                                    sv4x4[4*ty + tx] = tmp;
-                                }
-
-                                simdgroup_barrier(mem_flags::mem_threadgroup);
-
-                                #pragma unroll(4)
-                                for (short k = 0; k < 4; ++k) {
-                                    v8x8_t mv;
-
-                                    simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
-                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
-
-                                    simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
-                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
-                                }
-                            } else {
-                                if (ii + tx < D16) {
-                                    v4x4_t tmp;
-                                    deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
-                                    sv4x4[4*ty + tx] = tmp;
-                                }
-
-                                simdgroup_barrier(mem_flags::mem_threadgroup);
-
-                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
-                                    v8x8_t mv;
-
-                                    simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
-                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
-
-                                    simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
-                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-        for (short j = 0; j < Q; ++j) {
-            if (tiisg == 0) {
-                ss[j*TS + 0] = S[j];
-                ss[j*TS + 1] = M[j];
-            }
-        }
-    }
-
-    // reduce the warps sequentially
-    for (ushort sg = 1; sg < nsg; ++sg) {
-        half S = { 0.0f };
-        half M = { -__FLT16_MAX__/2 };
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        // each simdgroup stores its output to shared memory, reusing sq
-        if (sgitg == sg) {
-            for (short i = 0; i < D8; ++i) {
-                simdgroup_store(lo[i], so + i*8, D, 0, false);
-            }
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        // the first simdgroup accumulates the results from the other simdgroups
-        if (sgitg == 0) {
-            for (short j = 0; j < Q; ++j) {
-                const half S0 = ss[j*TS +         0];
-                const half S1 = ss[j*TS + sg*SH + 0];
-
-                const half M0 = ss[j*TS +         1];
-                const half M1 = ss[j*TS + sg*SH + 1];
-
-                M = max(M0, M1);
-
-                const half ms0 = exp(M0 - M);
-                const half ms1 = exp(M1 - M);
-
-                S = S0*ms0 + S1*ms1;
-
-                if (tiisg == 0) {
-                    ss[j*TS + 0] = S;
-                    ss[j*TS + 1] = M;
-
-                    ss[j*TS + 2*C + j        ] = ms0;
-                    ss[j*TS + 2*C + j + sg*SH] = ms1;
-                }
-            }
-
-            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-            {
-                s8x8_t ms0;
-                s8x8_t ms1;
-
-                simdgroup_load(ms0, ss + 2*C,         TS, 0, false);
-                simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
-
-                #pragma unroll(D8)
-                for (short i = 0; i < D8; ++i) {
-                    o8x8_t t;
-
-                    simdgroup_load    (t, so + i*8, D, 0, false);
-                    simdgroup_multiply(t, ms1, t);
-
-                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
-                }
-            }
-        }
-    }
-
-    // store result to shared memory (reuse sq)
-    if (sgitg == 0) {
-        for (short i = 0; i < D8; ++i) {
-            simdgroup_store(lo[i], so + i*8, D, 0, false);
-        }
-    }
-
-    device float4 * dst4 = (device float4 *) dst;
-
-    // final rescale with 1/S and store to global memory
-    if (sgitg == 0) {
-        for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
-            const float S = ss[j*TS + 0];
-
-            for (short i = tiisg; i < D4; i += NW) {
-                dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
-            }
-        }
-    }
-}
-
-// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as
-//       template to be able to explore different combinations
-//
-#define FA_TYPES \
-    half,  half4,   simdgroup_half8x8,  \
-    half,  half4x4, simdgroup_half8x8,  \
-    half,  half4x4, simdgroup_half8x8,  \
-    float,          simdgroup_float8x8, \
-    float,          simdgroup_float8x8, \
-    half,  half4,   simdgroup_half8x8
-
-typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
-
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  64>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  80>;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  96>;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  112>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  256>;
-
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 64>;
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 80>;
-template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 96>;
-template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 112>;
-template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 256>;
-#endif
-
-template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
-
-#undef FA_TYPES
-
-template<
-    typename q4_t,    // query types in shared memory
-    typename q4x4_t,
-    typename k4x4_t,  // key types in shared memory
-    typename v4x4_t,  // value types in shared memory
-    typename qk_t,    // Q*K types
-    typename s_t,     // soft-max types
-    typename s4_t,
-    typename s4x4_t,
-    typename o4x4_t,  // attention accumulation types
-    typename kd4x4_t, // key type in device memory
-    short nl_k,
-    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
-    typename vd4x4_t, // key type in device memory
-    short nl_v,
-    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
-    short D,         // head size
-    short Q  = 1,    // queries per threadgroup
-    short C  = 32>   // cache items per threadgroup
-kernel void kernel_flash_attn_ext_vec(
-        device const  char * q,
-        device const  char * k,
-        device const  char * v,
-        device const  char * mask,
-        device       float * dst,
-        constant   int32_t & ne01,
-        constant   int32_t & ne02,
-        constant   int32_t & ne03,
-        constant  uint32_t & nb01,
-        constant  uint32_t & nb02,
-        constant  uint32_t & nb03,
-        constant   int32_t & ne11,
-        constant   int32_t & ne_12_2, // assume K and V are same shape
-        constant   int32_t & ne_12_3,
-        constant  uint32_t & nb_12_1,
-        constant  uint32_t & nb_12_2,
-        constant  uint32_t & nb_12_3,
-        constant  uint32_t & nb31,
-        constant   int32_t & ne1,
-        constant   int32_t & ne2,
-        constant     float & scale,
-        constant     float & max_bias,
-        constant     float & m0,
-        constant     float & m1,
-        constant  uint16_t & n_head_log2,
-        constant     float & logit_softcap,
-        threadgroup   half * shared [[threadgroup(0)]],
-        ushort3  tgpig[[threadgroup_position_in_grid]],
-        ushort3  tpitg[[thread_position_in_threadgroup]],
-        ushort3    ntg[[threads_per_threadgroup]],
-        ushort   tiisg[[thread_index_in_simdgroup]],
-        ushort   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const short nsg = ntg.y; // number of simdgroups
-
-    const int iq3 = tgpig[2];
-    const int iq2 = tgpig[1];
-    const int iq1 = tgpig[0];
-
-    const short D4  = D/4;
-    const short D16 = D/16;
-    const short NW  = N_SIMDWIDTH;
-    const short NL  = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
-    const short SH  = 2*C;  // shared memory per simdgroup
-
-    const short T = D + nsg*SH; // shared memory size per query in (half)
-
-  //threadgroup q_t    * sq    = (threadgroup q_t    *) (shared +                0*D); // holds the query data
-    threadgroup q4_t   * sq4   = (threadgroup q4_t   *) (shared +                0*D); // same as above but in q4_t
-    threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared +                0*D); // same as above but in q4x4_t
-    threadgroup s_t    * ss    = (threadgroup s_t    *) (shared + sgitg*SH     + Q*D); // scratch buffer for attention
-    threadgroup s4_t   * ss4   = (threadgroup s4_t   *) (shared + sgitg*SH     + Q*D); // same as above but in s4_t
-    threadgroup half   * sm    = (threadgroup half   *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask
-    threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D      + Q*T); // scratch buffer for the results
-
-    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    o4x4_t lo[D16/NL];
-
-    // load heads from Q to shared memory
-    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
-
-    for (short i = tiisg; i < D4; i += NW) {
-        if (iq1 < ne01) {
-            sq4[i] = (q4_t) q4[i];
-        } else {
-            sq4[i] = (q4_t) 0.0f;
-        }
-    }
-
-    // zero out lo
-    for (short i = 0; i < D16/NL; ++i) {
-        lo[i] = (o4x4_t) 0.0f;
-    }
-
-    // zero out shared memory SH
-    for (short i = tiisg; i < SH/4; i += NW) {
-        ss4[i] = (s4_t) 0.0f;
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    {
-        half S = 0.0f;
-        half M = -__FLT16_MAX__/2;
-
-        // thread indices inside the simdgroup
-        const short tx = tiisg%NL;
-        const short ty = tiisg/NL;
-
-        // broadcast kv
-        //const short rk2 = ne02/ne12;
-        //const short rk3 = ne03/ne13;
-
-        const short ikv2 = iq2/(ne02/ne_12_2);
-        const short ikv3 = iq3/(ne03/ne_12_3);
-
-        // load the queries from shared memory into local memory
-        q4x4_t mq[D16/NL];
-
-        #pragma unroll(D16/NL)
-        for (short ii = 0; ii < D16; ii += NL) {
-            mq[ii/NL] = sq4x4[ii + tx];
-        }
-
-        const bool has_mask = mask != q;
-
-        // pointer to the mask
-        device const half * pm = (device const half *) (mask + iq1*nb31);
-
-        half slope = 1.0f;
-
-        // ALiBi
-        if (max_bias > 0.0f) {
-            const short h = iq2;
-
-            const half  base = h < n_head_log2 ? m0 : m1;
-            const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-            slope = pow(base, exph);
-        }
-
-        // loop over the KV cache
-        // each simdgroup handles blocks of Q rows and C columns
-        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
-            const int ic = ic0 + C*sgitg;
-            if (ic >= ne11) {
-                break;
-            }
-
-            if (has_mask) {
-                sm[tiisg] = pm[ic + tiisg];
-            }
-
-            // Q*K^T
-            {
-                // each simdgroup processes 1 query and 4 (NW/NL) keys
-                for (short cc = 0; cc < C/4; ++cc) {
-                    qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
-
-                    device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
-
-                    #pragma unroll(D16/NL)
-                    for (short ii = 0; ii < D16; ii += NL) {
-                        const short i = ii + tx;
-
-                        k4x4_t mk;
-                        deq_k(pk + i/nl_k, i%nl_k, mk);
-
-                        // note: this is less precise than the version below
-                        //mqka[0] += dot(mq[ii/NL][0], mk[0]);
-                        //mqka[1] += dot(mq[ii/NL][1], mk[1]);
-                        //mqka[2] += dot(mq[ii/NL][2], mk[2]);
-                        //mqka[3] += dot(mq[ii/NL][3], mk[3]);
-
-                        mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
-                        mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
-                        mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
-                        mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
-                    }
-
-                    qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
-
-                    // simdgroup reduce
-                    // [ 0 ..  7] -> [ 0]
-                    // [ 8 .. 15] -> [ 8]
-                    // [16 .. 23] -> [16]
-                    // [24 .. 31] -> [24]
-                  //mqk += simd_shuffle_down(mqk, 16);
-                  //mqk += simd_shuffle_down(mqk,  8);
-                    mqk += simd_shuffle_down(mqk,  4);
-                    mqk += simd_shuffle_down(mqk,  2);
-                    mqk += simd_shuffle_down(mqk,  1);
-
-                    // mqk = mqk*scale + mask*slope
-                    if (tx == 0) {
-                        mqk *= scale;
-
-                        if (logit_softcap != 0.0f) {
-                            mqk = logit_softcap*precise::tanh(mqk);
-                        }
-
-                        mqk += sm[4*cc + ty]*slope;
-
-                        ss[4*cc + ty] = mqk;
-                    }
-                }
-            }
-
-            simdgroup_barrier(mem_flags::mem_threadgroup);
-
-            // online softmax
-            {
-                const half m = M;
-                const half s = ss[tiisg];
-
-                M = simd_max(max(M, s));
-
-                const half ms = exp(m - M);
-                const half vs = exp(s - M);
-
-                S = S*ms + simd_sum(vs);
-
-                // the P matrix from the paper (Q rows, C columns)
-                ss[tiisg] = vs;
-
-                // O = diag(ms)*O
-                #pragma unroll(D16/NL)
-                for (short ii = 0; ii < D16; ii += NL) {
-                    lo[ii/NL] *= ms;
-                }
-            }
-
-            simdgroup_barrier(mem_flags::mem_threadgroup);
-
-            // O = O + (Q*K^T)*V
-            {
-                for (short cc = 0; cc < C/4; ++cc) {
-                    device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
-
-                    const s4x4_t ms(ss[4*cc + ty]);
-
-                    #pragma unroll(D16/NL)
-                    for (short ii = 0; ii < D16; ii += NL) {
-                        const short i = ii + tx;
-
-                        v4x4_t mv;
-                        deq_v(pv4 + i/nl_v, i%nl_v, mv);
-
-                        lo[ii/NL] += mv*ms;
-                    }
-                }
-            }
-        }
-
-        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-        if (tiisg == 0) {
-            ss[0] = (s_t) S;
-            ss[1] = (s_t) M;
-        }
-    }
-
-    // simdgroup reduce
-    // [ 0,  8, 16, 24] -> [ 0]
-    // [ 1,  9, 17, 25] -> [ 1]
-    // [ 2, 10, 18, 26] -> [ 2]
-    // [ 3, 11, 19, 27] -> [ 3]
-    // [ 4, 12, 20, 28] -> [ 4]
-    // [ 5, 13, 21, 29] -> [ 5]
-    // [ 6, 14, 22, 30] -> [ 6]
-    // [ 7, 15, 23, 31] -> [ 7]
-    for (short ii = 0; ii < D16; ii += NL) {
-        lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
-        lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  8);
-      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  4);
-      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  2);
-      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  1);
-
-        lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
-        lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  8);
-      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  4);
-      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  2);
-      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  1);
-
-        lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
-        lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  8);
-      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  4);
-      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  2);
-      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  1);
-
-        lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
-        lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  8);
-      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  4);
-      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  2);
-      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  1);
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // store results to shared memory
-    for (short i = tiisg; i < D16; i += NL) {
-        sr4x4[i] = lo[i/NL];
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // parallel reduce
-    for (short r = nsg/2; r > 0; r >>= 1) {
-        if (sgitg < r) {
-            const half S0 = ss[       0];
-            const half S1 = ss[r*SH + 0];
-
-            const half M0 = ss[       1];
-            const half M1 = ss[r*SH + 1];
-
-            const half M = max(M0, M1);
-
-            const half ms0 = exp(M0 - M);
-            const half ms1 = exp(M1 - M);
-
-            const half S = S0*ms0 + S1*ms1;
-
-            if (tiisg == 0) {
-                ss[0] = S;
-                ss[1] = M;
-            }
-
-            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-            for (short i = tiisg; i < D16; i += NW) {
-                sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
-            }
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    device float4x4 * dst44 = (device float4x4 *) dst;
-
-    // final rescale with 1/S and store to global memory
-    if (sgitg == 0) {
-        const float S = ss[0];
-
-        for (short i = tiisg; i < D16; i += NW) {
-            dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
-        }
-    }
-}
-
-// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
-//       in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
-//
-#define FA_TYPES \
-           half4,  half4x4, \
-                   half4x4, \
-                   half4x4, \
-    float,                  \
-    half,  half4,  half4x4, \
-                   half4x4
-
-typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
-
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,     1, dequantize_f16,  128>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,   1, dequantize_bf16, 128>;
-#endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0,  2, dequantize_q4_0, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1,  2, dequantize_q4_1, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0,  2, dequantize_q5_0, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1,  2, dequantize_q5_1, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0,  2, dequantize_q8_0, 128>;
-
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,     1, dequantize_f16,  256>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,   1, dequantize_bf16, 256>;
-#endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0,  2, dequantize_q4_0, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1,  2, dequantize_q4_1, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0,  2, dequantize_q5_0, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1,  2, dequantize_q5_1, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0,  2, dequantize_q8_0, 256>;
-
-#undef FA_TYPES
-
-template<typename T0, typename T1>
-kernel void kernel_cpy(
-        device  const void * src0,
-        device        void * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
-    device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-        dst_data[i00] = (T1) src[0];
-    }
-}
-
-typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
-
-template [[host_name("kernel_cpy_f32_f32")]]   kernel kernel_cpy_t kernel_cpy<float,  float>;
-template [[host_name("kernel_cpy_f32_f16")]]   kernel kernel_cpy_t kernel_cpy<float,  half>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_cpy_f32_bf16")]]  kernel kernel_cpy_t kernel_cpy<float,  bfloat>;
-#endif
-template [[host_name("kernel_cpy_f16_f32")]]   kernel kernel_cpy_t kernel_cpy<half,   float>;
-template [[host_name("kernel_cpy_f16_f16")]]   kernel kernel_cpy_t kernel_cpy<half,   half>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_cpy_bf16_f32")]]  kernel kernel_cpy_t kernel_cpy<bfloat, float>;
-template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
-#endif
-
-kernel void kernel_cpy_f32_q8_0(
-        device const float * src0,
-        device        void * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
-
-    device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-    for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
-        float amax = 0.0f; // absolute max
-
-        for (int j = 0; j < QK8_0; j++) {
-            const float v = src[j];
-            amax = MAX(amax, fabs(v));
-        }
-
-        const float d = amax / ((1 << 7) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK8_0].d = d;
-
-        for (int j = 0; j < QK8_0; ++j) {
-            const float x0 = src[j]*id;
-
-            dst_data[i00/QK8_0].qs[j] = round(x0);
-        }
-    }
-}
-
-kernel void kernel_cpy_f32_q4_0(
-        device const float * src0,
-        device        void * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
-
-    device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-    for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
-        float amax = 0.0f; // absolute max
-        float max  = 0.0f;
-
-        for (int j = 0; j < QK4_0; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max  = v;
-            }
-        }
-
-        const float d = max / -8;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK4_0].d = d;
-
-        for (int j = 0; j < QK4_0/2; ++j) {
-            const float x0 = src[0       + j]*id;
-            const float x1 = src[QK4_0/2 + j]*id;
-
-            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
-            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
-
-            dst_data[i00/QK4_0].qs[j]  = xi0;
-            dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
-        }
-    }
-}
-
-kernel void kernel_cpy_f32_q4_1(
-        device const float * src0,
-        device        void * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
-
-    device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-    for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
-        float min = FLT_MAX;
-        float max = -FLT_MAX;
-
-        for (int j = 0; j < QK4_1; j++) {
-            const float v = src[j];
-            if (min > v) min = v;
-            if (max < v) max = v;
-        }
-
-        const float d = (max - min) / ((1 << 4) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK4_1].d = d;
-        dst_data[i00/QK4_1].m = min;
-
-        for (int j = 0; j < QK4_1/2; ++j) {
-            const float x0 = (src[0       + j] - min)*id;
-            const float x1 = (src[QK4_1/2 + j] - min)*id;
-
-            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
-            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
-
-            dst_data[i00/QK4_1].qs[j]  = xi0;
-            dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
-        }
-    }
-}
-
-kernel void kernel_cpy_f32_q5_0(
-        device const float * src0,
-        device        void * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
-
-    device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-    for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
-        float amax = 0.0f; // absolute max
-        float max  = 0.0f;
-
-        for (int j = 0; j < QK5_0; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max  = v;
-            }
-        }
-
-        const float d = max / -16;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK5_0].d = d;
-
-        uint32_t qh = 0;
-        for (int j = 0; j < QK5_0/2; ++j) {
-            const float x0 = src[0       + j]*id;
-            const float x1 = src[QK5_0/2 + j]*id;
-
-            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
-            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
-
-            dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
-            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
-        }
-        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
-        for (int j = 0; j < 4; ++j) {
-            dst_data[i00/QK5_0].qh[j] = qh8[j];
-        }
-    }
-}
-
-kernel void kernel_cpy_f32_q5_1(
-        device const float * src0,
-        device        void * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
-
-    device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-    for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
-        float max = src[0];
-        float min = src[0];
-
-        for (int j = 1; j < QK5_1; j++) {
-            const float v = src[j];
-            min = v < min ? v : min;
-            max = v > max ? v : max;
-        }
-
-        const float d = (max - min) / 31;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK5_1].d = d;
-        dst_data[i00/QK5_1].m = min;
-
-        uint32_t qh = 0;
-        for (int j = 0; j < QK5_1/2; ++j) {
-            const float x0 = (src[0       + j] - min)*id;
-            const float x1 = (src[QK5_1/2 + j] - min)*id;
-
-            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
-            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
-
-            dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
-            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
-        }
-        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
-        for (int j = 0; j < 4; ++j) {
-            dst_data[i00/QK5_1].qh[j] = qh8[j];
-        }
-    }
-}
-
-static inline int best_index_int8(int n, constant float * val, float x) {
-    if (x <= val[0]) return 0;
-    if (x >= val[n-1]) return n-1;
-    int ml = 0, mu = n-1;
-    while (mu-ml > 1) {
-        int mav = (ml+mu)/2;
-        if (x < val[mav]) mu = mav; else ml = mav;
-    }
-    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
-}
-
-kernel void kernel_cpy_f32_iq4_nl(
-        device const float * src0,
-        device        void * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
-
-    device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-    for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
-        float amax = 0.0f; // absolute max
-        float max  = 0.0f;
-
-        for (int j = 0; j < QK4_0; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max  = v;
-            }
-        }
-
-        const float d = max / kvalues_iq4nl_f[0];
-        const float id = d ? 1.0f/d : 0.0f;
-
-        float sumqx = 0, sumq2 = 0;
-        for (int j = 0; j < QK4_NL/2; ++j) {
-            const float x0 = src[0        + j]*id;
-            const float x1 = src[QK4_NL/2 + j]*id;
-
-            const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
-            const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
-
-            dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
-
-            const float v0 = kvalues_iq4nl_f[xi0];
-            const float v1 = kvalues_iq4nl_f[xi1];
-            const float w0 = src[0        + j]*src[0        + j];
-            const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
-            sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
-            sumq2 += w0*v0*v0 + w1*v1*v1;
-
-        }
-
-        dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
-
-    }
-}
-
-kernel void kernel_concat(
-    device  const char * src0,
-    device  const char * src1,
-    device        char * dst,
-    constant   int64_t & ne00,
-    constant   int64_t & ne01,
-    constant   int64_t & ne02,
-    constant   int64_t & ne03,
-    constant  uint64_t & nb00,
-    constant  uint64_t & nb01,
-    constant  uint64_t & nb02,
-    constant  uint64_t & nb03,
-    constant   int64_t & ne10,
-    constant   int64_t & ne11,
-    constant   int64_t & ne12,
-    constant   int64_t & ne13,
-    constant  uint64_t & nb10,
-    constant  uint64_t & nb11,
-    constant  uint64_t & nb12,
-    constant  uint64_t & nb13,
-    constant   int64_t & ne0,
-    constant   int64_t & ne1,
-    constant   int64_t & ne2,
-    constant   int64_t & ne3,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
-    constant  uint64_t & nb2,
-    constant  uint64_t & nb3,
-    constant   int32_t & dim,
-    uint3 tgpig[[threadgroup_position_in_grid]],
-    uint3 tpitg[[thread_position_in_threadgroup]],
-    uint3   ntg[[threads_per_threadgroup]]) {
-
-    const int64_t i3 = tgpig.z;
-    const int64_t i2 = tgpig.y;
-    const int64_t i1 = tgpig.x;
-
-    int64_t o[4] = {0, 0, 0, 0};
-    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
-
-    device const float * x;
-
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
-            x = (device const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
-        } else {
-            x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
-        }
-
-        device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-        *y = *x;
-    }
-}
-
-void kernel_mul_mv_q2_K_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0);
-    device const float      * y = (device const float      *) ((device char *) src1 + offset1);
-
-    float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
-
-    const int ix = tiisg/8;  // 0...3
-    const int it = tiisg%8;  // 0...7
-    const int iq = it/4;     // 0 or 1
-    const int ir = it%4;     // 0...3
-    const int is = (8*ir)/16;// 0 or 1
-
-    device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
-
-    for (int ib = ix; ib < nb; ib += 4) {
-
-        float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
-            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
-            yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
-            yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
-            yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
-        }
-
-        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*iq + is;
-        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
-        device const half     * dh = &x[ib].d;
-
-        for (int row = 0; row < N_DST; row++) {
-
-            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
-            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
-            for (int i = 0; i < 8; i += 2) {
-                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
-                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
-                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
-                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
-                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
-                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
-                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
-                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
-            }
-            float dall = dh[0];
-            float dmin = dh[1] * 1.f/16.f;
-            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
-                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
-                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
-                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
-                         dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
-
-            qs += nb01/2;
-            sc += nb01;
-            dh += nb01/2;
-        }
-
-        y4 += 4 * QK_K;
-    }
-
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
-        }
-    }
-}
-
-[[host_name("kernel_mul_mv_q2_K_f32")]]
-kernel void kernel_mul_mv_q2_K_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_q3_K_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
-
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-    const int64_t im = tgpig.z;
-
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0);
-    device const float     * yy = (device const float      *) ((device char *) src1 + offset1);
-
-    float yl[32];
-
-    //const uint16_t kmask1 = 0x3030;
-    //const uint16_t kmask2 = 0x0f0f;
-
-    const int tid = tiisg/4;
-    const int ix  = tiisg%4;
-    const int ip  = tid/4;          // 0 or 1
-    const int il  = 2*((tid%4)/2);  // 0 or 2
-    const int ir  = tid%2;
-    const int n   = 8;
-    const int l0  = n*ir;
-
-    // One would think that the Metal compiler would figure out that ip and il can only have
-    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
-    // with these two tales.
-    //
-    // Possible masks for the high bit
-    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
-                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
-                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
-                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
-
-    // Possible masks for the low 2 bits
-    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
-
-    const ushort4 hm = mm[2*ip + il/2];
-
-    const int shift = 2*il;
-    const float    v1 = il == 0 ? 4.f : 64.f;
-    const float    v2 = 4.f * v1;
-
-    const uint16_t s_shift1 = 4*ip;
-    const uint16_t s_shift2 = s_shift1 + il;
-
-    const int q_offset = 32*ip + l0;
-    const int y_offset = 128*ip + 32*il + l0;
-
-    device const float * y1 = yy + ix*QK_K + y_offset;
-
-    uint32_t scales32, aux32;
-    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
-    thread const int8_t * scales = (thread const int8_t *)&scales32;
-
-    float sumf1[2] = {0.f};
-    float sumf2[2] = {0.f};
-    for (int i = ix; i < nb; i += 4) {
-        for (int l = 0; l < 8; ++l) {
-            yl[l+ 0] = y1[l+ 0];
-            yl[l+ 8] = y1[l+16];
-            yl[l+16] = y1[l+32];
-            yl[l+24] = y1[l+48];
-        }
-
-        device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
-        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
-        device const uint16_t * a = (device const uint16_t *)(x[i].scales);
-        device const half * dh = &x[i].d;
-
-        for (int row = 0; row < 2; ++row) {
-            const float d_all = (float)dh[0];
-
-            scales16[0] = a[4];
-            scales16[1] = a[5];
-            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
-            scales16[0] = a[il+0];
-            scales16[1] = a[il+1];
-            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
-
-            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
-            for (int l = 0; l < n; l += 2) {
-                const int32_t qs = q[l/2];
-                s1 += yl[l+0] * (qs & qm[il/2][0]);
-                s2 += yl[l+1] * (qs & qm[il/2][1]);
-                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
-                s4 += yl[l+16] * (qs & qm[il/2][2]);
-                s5 += yl[l+17] * (qs & qm[il/2][3]);
-                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
-            }
-            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
-            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
-            sumf1[row] += d1 * (scales[0] - 32);
-            sumf2[row] += d2 * (scales[2] - 32);
-
-            s1 = s2 = s3 = s4 = s5 = s6 = 0;
-            for (int l = 0; l < n; l += 2) {
-                const int32_t qs = q[l/2+8];
-                s1 += yl[l+8] * (qs & qm[il/2][0]);
-                s2 += yl[l+9] * (qs & qm[il/2][1]);
-                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
-                s4 += yl[l+24] * (qs & qm[il/2][2]);
-                s5 += yl[l+25] * (qs & qm[il/2][3]);
-                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
-            }
-            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
-            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
-            sumf1[row] += d1 * (scales[1] - 32);
-            sumf2[row] += d2 * (scales[3] - 32);
-
-            q  += nb01/2;
-            h  += nb01/2;
-            a  += nb01/2;
-            dh += nb01/2;
-        }
-
-        y1 += 4 * QK_K;
-    }
-
-    for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
-        sumf1[row] = simd_sum(sumf);
-    }
-    if (tiisg == 0) {
-        for (int row = 0; row < 2; ++row) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
-        }
-    }
-}
-
-[[host_name("kernel_mul_mv_q3_K_f32")]]
-kernel void kernel_mul_mv_q3_K_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_q4_K_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
-    const int ix = tiisg/8;  // 0...3
-    const int it = tiisg%8;  // 0...7
-    const int iq = it/4;     // 0 or 1
-    const int ir = it%4;     // 0...3
-
-    const int nb = ne00/QK_K;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int first_row = r0 * N_DST;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0);
-    device const float      * y = (device const float      *) ((device char *) src1 + offset1);
-
-    float yl[16];
-    float yh[16];
-    float sumf[N_DST]={0.f}, all_sum;
-
-    device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
-
-    uint16_t sc16[4];
-    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
-
-    for (int ib = ix; ib < nb; ib += 4) {
-        float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
-            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
-            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
-            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
-            yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
-        }
-
-        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
-        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
-        device const half     * dh = &x[ib].d;
-
-        for (int row = 0; row < N_DST; row++) {
-            sc16[0] = sc[0] & kmask1;
-            sc16[1] = sc[2] & kmask1;
-            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
-            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
-
-            device const uint16_t * q2 = q1 + 32;
-
-            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
-            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
-            for (int i = 0; i < 8; i += 2) {
-                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
-                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
-                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
-                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
-                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
-                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
-                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
-                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
-            }
-
-            float dall = dh[0];
-            float dmin = dh[1];
-            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
-                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
-                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
-                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
-                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
-
-            q1 += nb01/2;
-            sc += nb01/2;
-            dh += nb01/2;
-        }
-
-        y4 += 4 * QK_K;
-    }
-
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
-        }
-    }
-}
-
-[[host_name("kernel_mul_mv_q4_K_f32")]]
-kernel void kernel_mul_mv_q4_K_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_q5_K_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
-
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-    const int im = tgpig.z;
-
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0);
-    device const float     * yy = (device const float      *) ((device char *) src1 + offset1);
-
-    float sumf[2]={0.f};
-
-    float yl[16], yh[16];
-
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
-    const int tid = tiisg/4;
-    const int ix  = tiisg%4;
-    const int iq  = tid/4;
-    const int ir  = tid%4;
-    const int n   = 8;
-
-    const int l0 = n*ir;
-    const int q_offset = 32*iq + l0;
-    const int y_offset = 64*iq + l0;
-
-    const uint8_t hm1 = 1u << (2*iq);
-    const uint8_t hm2 = hm1 << 1;
-    const uint8_t hm3 = hm1 << 4;
-    const uint8_t hm4 = hm2 << 4;
-
-    uint16_t sc16[4];
-    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
-
-    device const float * y1 = yy + ix*QK_K + y_offset;
-
-    for (int i = ix; i < nb; i += 4) {
-        device const uint8_t * q1 = x[i].qs + q_offset;
-        device const uint8_t * qh = x[i].qh + l0;
-        device const half * dh = &x[i].d;
-        device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
-
-        device const float * y2 = y1 + 128;
-        float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int l = 0; l < 8; ++l) {
-            yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
-            yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
-            yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
-            yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
-        }
-
-        for (int row = 0; row < 2; ++row) {
-            device const uint8_t * q2 = q1 + 64;
-
-            sc16[0] = a[0] & kmask1;
-            sc16[1] = a[2] & kmask1;
-            sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
-            sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
-
-            float4 acc1 = {0.f};
-            float4 acc2 = {0.f};
-            for (int l = 0; l < n; ++l) {
-                uint8_t h = qh[l];
-                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
-                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
-                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
-                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
-                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
-                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
-                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
-                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
-            }
-            const float dall = dh[0];
-            const float dmin = dh[1];
-            sumf[row] += dall * (sc8[0] * (acc1[0] +  16.f*acc2[0]) +
-                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
-                                 sc8[4] * (acc1[2] +  16.f*acc2[2]) +
-                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
-                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
-
-            q1 += nb01;
-            qh += nb01;
-            dh += nb01/2;
-            a  += nb01/2;
-        }
-
-        y1 += 4 * QK_K;
-    }
-
-    for (int row = 0; row < 2; ++row) {
-        const float tot = simd_sum(sumf[row]);
-        if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
-        }
-    }
-}
-
-[[host_name("kernel_mul_mv_q5_K_f32")]]
-kernel void kernel_mul_mv_q5_K_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_q6_K_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const uint8_t kmask1 = 0x03;
-    const uint8_t kmask2 = 0x0C;
-    const uint8_t kmask3 = 0x30;
-    const uint8_t kmask4 = 0xC0;
-
-    const int nb = ne00/QK_K;
-
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-    const int     im = tgpig.z;
-
-    const int row = 2 * r0 + sgitg;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =  r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0);
-    device const float     * yy = (device const float      *) ((device char *) src1 + offset1);
-
-    float sumf = 0;
-
-    const int tid  = tiisg/2;
-    const int ix   = tiisg%2;
-    const int ip   = tid/8;         // 0 or 1
-    const int il   = tid%8;
-    const int n    = 4;
-    const int l0   = n*il;
-    const int is   = 8*ip + l0/16;
-
-    const int y_offset = 128*ip + l0;
-    const int q_offset_l = 64*ip + l0;
-    const int q_offset_h = 32*ip + l0;
-
-    for (int i = ix; i < nb; i += 2) {
-
-        device const uint8_t * q1 = x[i].ql + q_offset_l;
-        device const uint8_t * q2 = q1 + 32;
-        device const uint8_t * qh = x[i].qh + q_offset_h;
-        device const int8_t  * sc = x[i].scales + is;
-
-        device const float * y = yy + i * QK_K + y_offset;
-
-        const float dall = x[i].d;
-
-        float4 sums = {0.f, 0.f, 0.f, 0.f};
-        for (int l = 0; l < n; ++l) {
-            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
-            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
-            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
-            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
-        }
-
-        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
-
-    }
-
-    const float tot = simd_sum(sumf);
-    if (tiisg == 0) {
-        dst[r1*ne0 + im*ne0*ne1 + row] = tot;
-    }
-}
-
-[[host_name("kernel_mul_mv_q6_K_f32")]]
-kernel void kernel_mul_mv_q6_K_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
-}
-
-// ======================= "True" 2-bit
-
-void kernel_mul_mv_iq2_xxs_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0);
-    device const float         * y = (device const float         *) ((device char *) src1 + offset1);
-
-    float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
-
-    const int nb32 = nb * (QK_K / 32);
-
-    threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
-    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t *)(values + 256);
-    {
-        int nval = 4;
-        int pos  = (32*sgitg + tiisg)*nval;
-        for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
-        nval = 2;
-        pos  = (32*sgitg + tiisg)*nval;
-        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    const int ix = tiisg;
-
-    device const float * y4 = y + 32 * ix;
-
-    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
-            yl[i] = y4[i];
-        }
-
-        const int ibl = ib32 / (QK_K / 32);
-        const int ib  = ib32 % (QK_K / 32);
-
-        device const block_iq2_xxs * xr = x + ibl;
-        device const uint16_t * q2 = xr->qs + 4 * ib;
-        device const half * dh = &xr->d;
-
-        for (int row = 0; row < N_DST; row++) {
-
-            const float db = dh[0];
-            device const uint8_t * aux8 = (device const uint8_t *)q2;
-            const uint32_t aux32 = q2[2] | (q2[3] << 16);
-            const float d = db * (0.5f + (aux32 >> 28));
-
-            float sum = 0;
-            for (int l = 0; l < 4; ++l) {
-                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
-                const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 8; ++j) {
-                    sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
-                }
-            }
-            sumf[row] += d * sum;
-
-            dh += nb01/2;
-            q2 += nb01/2;
-        }
-
-        y4 += 32 * 32;
-    }
-
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
-        }
-    }
-}
-
-[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
-kernel void kernel_mul_mv_iq2_xxs_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_iq2_xs_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0);
-    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
-
-    float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
-
-    const int nb32 = nb * (QK_K / 32);
-
-    threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
-    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t *)(values + 512);
-    {
-        int nval = 8;
-        int pos  = (32*sgitg + tiisg)*nval;
-        for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
-        nval = 2;
-        pos  = (32*sgitg + tiisg)*nval;
-        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    const int ix = tiisg;
-
-    device const float * y4 = y + 32 * ix;
-
-    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
-            yl[i] = y4[i];
-        }
-
-        const int ibl = ib32 / (QK_K / 32);
-        const int ib  = ib32 % (QK_K / 32);
-
-        device const block_iq2_xs * xr = x + ibl;
-        device const uint16_t * q2 = xr->qs + 4 * ib;
-        device const uint8_t  * sc = xr->scales + ib;
-        device const half * dh = &xr->d;
-
-        for (int row = 0; row < N_DST; row++) {
-
-            const float db = dh[0];
-            const uint8_t ls1 = sc[0] & 0xf;
-            const uint8_t ls2 = sc[0] >>  4;
-            const float d1 = db * (0.5f + ls1);
-            const float d2 = db * (0.5f + ls2);
-
-            float sum1 = 0, sum2 = 0;
-            for (int l = 0; l < 2; ++l) {
-                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
-                const uint8_t signs = shared_signs[(q2[l] >> 9)];
-                for (int j = 0; j < 8; ++j) {
-                    sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
-                }
-            }
-            for (int l = 2; l < 4; ++l) {
-                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
-                const uint8_t signs = shared_signs[(q2[l] >> 9)];
-                for (int j = 0; j < 8; ++j) {
-                    sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
-                }
-            }
-            sumf[row] += d1 * sum1 + d2 * sum2;
-
-            dh += nb01/2;
-            q2 += nb01/2;
-            sc += nb01;
-        }
-
-        y4 += 32 * 32;
-    }
-
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
-        }
-    }
-}
-
-[[host_name("kernel_mul_mv_iq2_xs_f32")]]
-kernel void kernel_mul_mv_iq2_xs_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_iq3_xxs_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0);
-    device const float         * y = (device const float         *) ((device char *) src1 + offset1);
-
-    float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
-
-    const int nb32 = nb * (QK_K / 32);
-
-    threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
-    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t *)(values + 256);
-    {
-        int nval = 4;
-        int pos  = (32*sgitg + tiisg)*nval;
-        for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
-        nval = 2;
-        pos  = (32*sgitg + tiisg)*nval;
-        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    const int ix = tiisg;
-
-    device const float * y4 = y + 32 * ix;
-
-    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
-            yl[i] = y4[i];
-        }
-
-        const int ibl = ib32 / (QK_K / 32);
-        const int ib  = ib32 % (QK_K / 32);
-
-        device const block_iq3_xxs * xr = x + ibl;
-        device const uint8_t  * q3 = xr->qs + 8 * ib;
-        device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
-        device const half * dh = &xr->d;
-
-        for (int row = 0; row < N_DST; row++) {
-
-            const float db = dh[0];
-            const uint32_t aux32 = gas[0] | (gas[1] << 16);
-            const float d = db * (0.5f + (aux32 >> 28));
-
-            float2 sum = {0};
-            for (int l = 0; l < 4; ++l) {
-                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
-                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
-                const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 4; ++j) {
-                    sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
-                    sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
-                }
-            }
-            sumf[row] += d * (sum[0] + sum[1]);
-
-            dh  += nb01/2;
-            q3  += nb01;
-            gas += nb01/2;
-        }
-
-        y4 += 32 * 32;
-    }
-
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
-        }
-    }
-}
-
-[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
-kernel void kernel_mul_mv_iq3_xxs_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_iq3_s_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0);
-    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
-
-    float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
-
-    const int nb32 = nb * (QK_K / 32);
-
-    threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
-    {
-        int nval = 8;
-        int pos  = (32*sgitg + tiisg)*nval;
-        for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    const int ix = tiisg;
-
-    device const float * y4 = y + 32 * ix;
-
-    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
-            yl[i] = y4[i];
-        }
-
-        const int ibl = ib32 / (QK_K / 32);
-        const int ib  = ib32 % (QK_K / 32);
-
-        device const block_iq3_s * xr = x + ibl;
-        device const uint8_t * qs = xr->qs + 8 * ib;
-        device const uint8_t * qh = xr->qh + ib;
-        device const uint8_t * sc = xr->scales + (ib/2);
-        device const uint8_t * signs = xr->signs + 4 * ib;
-        device const half * dh = &xr->d;
-
-        for (int row = 0; row < N_DST; row++) {
-
-            const float db = dh[0];
-            const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
-
-            float2 sum = {0};
-            for (int l = 0; l < 4; ++l) {
-                const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
-                const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
-                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
-                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
-                for (int j = 0; j < 4; ++j) {
-                    sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
-                    sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
-                }
-            }
-            sumf[row] += d * (sum[0] + sum[1]);
-
-            dh    += nb01/2;
-            qs    += nb01;
-            qh    += nb01;
-            sc    += nb01;
-            signs += nb01;
-        }
-
-        y4 += 32 * 32;
-    }
-
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
-        }
-    }
-}
-
-[[host_name("kernel_mul_mv_iq3_s_f32")]]
-kernel void kernel_mul_mv_iq3_s_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_iq2_s_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0);
-    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
-
-    float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
-
-    const int nb32 = nb * (QK_K / 32);
-
-    //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
-    //{
-    //    int nval = 32;
-    //    int pos  = (32*sgitg + tiisg)*nval;
-    //    for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
-    //    threadgroup_barrier(mem_flags::mem_threadgroup);
-    //}
-
-    const int ix = tiisg;
-
-    device const float * y4 = y + 32 * ix;
-
-    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
-            yl[i] = y4[i];
-        }
-
-        const int ibl = ib32 / (QK_K / 32);
-        const int ib  = ib32 % (QK_K / 32);
-
-        device const block_iq2_s * xr = x + ibl;
-        device const uint8_t * qs = xr->qs + 4 * ib;
-        device const uint8_t * qh = xr->qh + ib;
-        device const uint8_t * sc = xr->scales + ib;
-        device const uint8_t * signs = qs + QK_K/8;
-        device const half * dh = &xr->d;
-
-        for (int row = 0; row < N_DST; row++) {
-
-            const float db = dh[0];
-            const float d1 = db * (0.5f + (sc[0] & 0xf));
-            const float d2 = db * (0.5f + (sc[0] >>  4));
-
-            float2 sum = {0};
-            for (int l = 0; l < 2; ++l) {
-                //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
-                //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
-                constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
-                constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
-                for (int j = 0; j < 8; ++j) {
-                    sum[0] += yl[8*l + j +  0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
-                    sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
-                }
-            }
-            sumf[row] += d1 * sum[0] + d2 * sum[1];
-
-            dh    += nb01/2;
-            qs    += nb01;
-            qh    += nb01;
-            sc    += nb01;
-            signs += nb01;
-        }
-
-        y4 += 32 * 32;
-    }
-
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
-        }
-    }
-}
-
-[[host_name("kernel_mul_mv_iq2_s_f32")]]
-kernel void kernel_mul_mv_iq2_s_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-}
-
-void kernel_mul_mv_iq1_s_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_value,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0);
-    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
-
-    float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
-
-    const int nb32 = nb * (QK_K / 32);
-
-    const int ix = tiisg;
-
-    device const float * y4 = y + 32 * ix;
-
-    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        float sumy = 0;
-        for (int i = 0; i < 32; ++i) {
-            yl[i] = y4[i];
-            sumy += yl[i];
-        }
-
-        const int ibl = ib32 / (QK_K / 32);
-        const int ib  = ib32 % (QK_K / 32);
-
-        device const block_iq1_s * xr = x + ibl;
-        device const uint8_t  * qs = xr->qs + 4 * ib;
-        device const uint16_t * qh = xr->qh + ib;
-        device const half     * dh = &xr->d;
-
-        for (int row = 0; row < N_DST; row++) {
-
-            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
-            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
-            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
-            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
-
-            float sum = 0;
-            for (int j = 0; j < 4; ++j) {
-                sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
-                     + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
-                     + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
-                     + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
-            }
-            sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
-
-            dh += nb01/2;
-            qs += nb01;
-            qh += nb01/2;
-        }
-
-        y4 += 32 * 32;
-    }
-
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
-        }
-    }
-}
-
-void kernel_mul_mv_iq1_m_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_value,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0);
-    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
-
-    float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
-
-    const int nb32 = nb * (QK_K / 32);
-
-    const int ix = tiisg;
-
-    device const float * y4 = y + 32 * ix;
-
-    iq1m_scale_t scale;
-
-    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        float4 sumy = {0.f};
-        for (int i = 0; i < 8; ++i) {
-            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
-            yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
-            yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
-            yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
-        }
-
-        const int ibl = ib32 / (QK_K / 32);
-        const int ib  = ib32 % (QK_K / 32);
-
-        device const block_iq1_m * xr = x + ibl;
-        device const uint8_t  * qs = xr->qs + 4 * ib;
-        device const uint8_t  * qh = xr->qh + 2 * ib;
-        device const uint16_t * sc = (device const uint16_t *)xr->scales;
-
-        for (int row = 0; row < N_DST; row++) {
-            scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
-
-            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
-            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
-            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
-            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
-
-            float2 sum = {0.f};
-            for (int j = 0; j < 4; ++j) {
-                sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
-                        + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
-                sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
-                        + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
-            }
-            const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
-            const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
-
-            sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
-                                             (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
-
-            sc += nb01/2;
-            qs += nb01;
-            qh += nb01;
-        }
-
-        y4 += 32 * 32;
-    }
-
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
-        }
-    }
-}
-
-void kernel_mul_mv_iq4_nl_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values_i8,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
-    const int nb = ne00/QK4_NL;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-    const int first_row = (r0 * 2 + sgitg) * 2;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0);
-    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
-
-    const int ix = tiisg/2;  // 0...15
-    const int it = tiisg%2;  // 0 or 1
-
-    shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    float4 yl[4];
-    float sumf[2]={0.f}, all_sum;
-
-    device const float * yb = y + ix * QK4_NL + it * 8;
-
-    uint32_t aux32[2];
-    thread const uint8_t * q8 = (thread const uint8_t *)aux32;
-
-    float4 qf1, qf2;
-
-    for (int ib = ix; ib < nb; ib += 16) {
-
-        device const float4 * y4 = (device const float4 *)yb;
-        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
-
-        for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
-
-            device const block_iq4_nl & xb = x[row*nb + ib];
-            device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
-
-            float4 acc1 = {0.f}, acc2 = {0.f};
-
-            aux32[0] = q4[0] | (q4[1] << 16);
-            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
-            aux32[0] &= 0x0f0f0f0f;
-            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
-            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
-            acc1 += yl[0] * qf1;
-            acc2 += yl[1] * qf2;
-
-            aux32[0] = q4[2] | (q4[3] << 16);
-            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
-            aux32[0] &= 0x0f0f0f0f;
-            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
-            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
-            acc1 += yl[2] * qf1;
-            acc2 += yl[3] * qf2;
-
-            acc1 += acc2;
-
-            sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
-        }
-
-        yb += 16 * QK4_NL;
-    }
-
-    for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
-        }
-    }
-}
-
-void kernel_mul_mv_iq4_xs_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values_i8,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
-    const int nb = ne00/QK_K;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-    const int first_row = (r0 * 2 + sgitg) * 2;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
-
-    device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0);
-    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
-
-    const int ix = tiisg/16;  // 0 or 1
-    const int it = tiisg%16;  // 0...15
-    const int ib = it/2;
-    const int il = it%2;
-
-    shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    float4 yl[4];
-    float sumf[2]={0.f}, all_sum;
-
-    device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
-
-    uint32_t aux32[2];
-    thread const uint8_t * q8 = (thread const uint8_t *)aux32;
-
-    float4 qf1, qf2;
-
-    for (int ibl = ix; ibl < nb; ibl += 2) {
-
-        device const float4 * y4 = (device const float4 *)yb;
-        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
-
-        for (int row = 0; row < 2; ++row) {
-
-            device const block_iq4_xs & xb = x[row*nb + ibl];
-            device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
-
-            float4 acc1 = {0.f}, acc2 = {0.f};
-
-            aux32[0] = q4[0] & 0x0f0f0f0f;
-            aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
-            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
-            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
-            acc1 += yl[0] * qf1;
-            acc2 += yl[1] * qf2;
-
-            aux32[0] = q4[1] & 0x0f0f0f0f;
-            aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
-            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
-            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
-            acc1 += yl[2] * qf1;
-            acc2 += yl[3] * qf2;
-
-            acc1 += acc2;
-
-            const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
-            sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
-        }
-
-        yb += 2 * QK_K;
-    }
-
-    for (int row = 0; row < 2; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
-        }
-    }
-}
-
-[[host_name("kernel_mul_mv_iq1_s_f32")]]
-kernel void kernel_mul_mv_iq1_s_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq1_m_f32")]]
-kernel void kernel_mul_mv_iq1_m_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq4_nl_f32")]]
-kernel void kernel_mul_mv_iq4_nl_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq4_xs_f32")]]
-kernel void kernel_mul_mv_iq4_xs_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-}
-
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
-kernel void kernel_get_rows_q(
-        device const  void * src0,
-        device const  void * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        uint3                tgpig[[threadgroup_position_in_grid]],
-        uint                 tiitg[[thread_index_in_threadgroup]],
-        uint3                tptg [[threads_per_threadgroup]]) {
-    const int64_t i10 = tgpig.x;
-    const int64_t i11 = tgpig.y;
-
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
-
-    const int64_t i02 = i11;
-
-    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
-        float4x4 temp;
-        dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
-        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
-    }
-}
-
-template<typename T>
-kernel void kernel_get_rows_f(
-        device const  void * src0,
-        device const  void * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        uint3                tgpig[[threadgroup_position_in_grid]],
-        uint                 tiitg[[thread_index_in_threadgroup]],
-        uint3                tptg [[threads_per_threadgroup]]) {
-    const int64_t i10 = tgpig.x;
-    const int64_t i11 = tgpig.y;
-
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
-
-    const int64_t i02 = i11;
-
-    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-        ((      device float *) ((      device char *)  dst + i11*nb2  + i10*nb1))[ind] =
-        ((const device T     *) ((const device char *) src0 + i02*nb02 +  r*nb01))[ind];
-    }
-}
-
-kernel void kernel_get_rows_i32(
-        device const  void * src0,
-        device const  void * src1,
-        device     int32_t * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        uint3                tgpig[[threadgroup_position_in_grid]],
-        uint                 tiitg[[thread_index_in_threadgroup]],
-        uint3                tptg [[threads_per_threadgroup]]) {
-    const int64_t i10 = tgpig.x;
-    const int64_t i11 = tgpig.y;
-
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
-
-    const int64_t i02 = i11;
-
-    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-        ((      device int32_t *) ((      device char *) dst  + i11*nb2 + i10*nb1))[ind] =
-        ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
-    }
-}
-
-
-#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
-#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
-#define BLOCK_SIZE_K 32
-#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
-#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
-#define THREAD_PER_BLOCK 128
-#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
-#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
-#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
-#define SG_MAT_ROW 8
-
-// each block_q contains 16*nl weights
-template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
-kernel void kernel_mul_mm(device const  uchar * src0,
-                          device const  uchar * src1,
-                          device        float * dst,
-                          constant    int64_t & ne00,
-                          constant    int64_t & ne02,
-                          constant   uint64_t & nb01,
-                          constant   uint64_t & nb02,
-                          constant   uint64_t & nb03,
-                          constant    int64_t & ne12,
-                          constant   uint64_t & nb10,
-                          constant   uint64_t & nb11,
-                          constant   uint64_t & nb12,
-                          constant   uint64_t & nb13,
-                          constant    int64_t & ne0,
-                          constant    int64_t & ne1,
-                          constant       uint & r2,
-                          constant       uint & r3,
-                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
-                          uint3                 tgpig[[threadgroup_position_in_grid]],
-                          uint                  tiitg[[thread_index_in_threadgroup]],
-                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    threadgroup T     * sa = (threadgroup T     *)(shared_memory);
-    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
-
-    const uint r0 = tgpig.y;
-    const uint r1 = tgpig.x;
-    const uint im = tgpig.z;
-
-    // if this block is of 64x32 shape or smaller
-    short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
-    short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
-
-    // a thread shouldn't load data outside of the matrix
-    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
-    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
-
-    simdgroup_T8x8     ma[4];
-    simdgroup_float8x8 mb[2];
-    simdgroup_float8x8 mc[8];
-
-    for (short i = 0; i < 8; i++){
-        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
-    }
-
-    short il = (tiitg % THREAD_PER_ROW);
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
-    ushort offset1 = il/nl;
-
-    device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*nb01 + offset0) + offset1;
-    device const float   * y = (device const float   *)(src1
-        + nb13 * i13
-        + nb12 * i12
-        + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
-        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
-
-    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
-        // load data and store to threadgroup memory
-        T4x4 temp_a;
-        dequantize_func(x, il, temp_a);
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        #pragma unroll(16)
-        for (short i = 0; i < 16; i++) {
-            *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
-            +                     (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
-            +                     (tiitg/THREAD_PER_ROW)%8  + (i&7)*8) = temp_a[i/4][i%4];
-        }
-
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL)*8*32 + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
-
-        il = (il + 2 < nl) ? il + 2 : il % 2;
-        x  = (il < 2) ? x + (2+nl-1)/nl : x;
-        y += BLOCK_SIZE_K;
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        // load matrices from threadgroup memory and conduct outer products
-        threadgroup T     * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
-        threadgroup float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
-
-        #pragma unroll(4)
-        for (short ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
-            #pragma unroll(4)
-            for (short i = 0; i < 4; i++) {
-                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
-            }
-            simdgroup_barrier(mem_flags::mem_none);
-            #pragma unroll(2)
-            for (short i = 0; i < 2; i++) {
-                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
-            }
-
-            lsma += BLOCK_SIZE_M/SG_MAT_ROW * SG_MAT_SIZE;
-            lsmb += BLOCK_SIZE_N/SG_MAT_ROW * SG_MAT_SIZE;
-
-            #pragma unroll(8)
-            for (short i = 0; i < 8; i++){
-                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
-            }
-        }
-    }
-
-    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) \
-                               + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
-        for (short i = 0; i < 8; i++) {
-            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
-        }
-    } else {
-        // block is smaller than 64x32, we should avoid writing data outside of the matrix
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float * temp_str = ((threadgroup float *) shared_memory) \
-                                      + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M;
-        for (short i = 0; i < 8; i++) {
-            simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        if (sgitg == 0) {
-            for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
-                device float  * D  = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
-                device float4 * D4 = (device float4 *) D;
-
-                threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
-                threadgroup float4 * C4 = (threadgroup float4 *) C;
-
-                int i = 0;
-                for (; i < n_rows/4; i++) {
-                    *(D4 + i) = *(C4 + i);
-                }
-
-                i *= 4;
-                for (; i < n_rows; i++) {
-                    *(D + i) = *(C + i);
-                }
-            }
-        }
-    }
-}
-
-// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-void kernel_mul_mm_id_impl(
-        device const  uchar * src0,
-        device const  uchar * src1,
-        threadgroup ushort2 * rowids,
-        device        float * dst,
-        constant    int64_t & ne00,
-        constant    int64_t & ne02,
-        constant   uint64_t & nb01,
-        constant   uint64_t & nb02,
-        constant    int64_t & ne11,
-        constant    int64_t & ne12,
-        constant   uint64_t & nb10,
-        constant   uint64_t & nb11,
-        constant   uint64_t & nb12,
-        constant    int64_t & ne0,
-                    int64_t   ne1,
-                    int64_t   ne0ne1,
-        threadgroup   uchar * shared_memory,
-        uint3                 tgpig[[threadgroup_position_in_grid]],
-        uint                  tiitg[[thread_index_in_threadgroup]],
-        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
-    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
-
-    const uint r0 = tgpig.y;
-    const uint r1 = tgpig.x;
-
-    if (r1 * BLOCK_SIZE_N >= ne1) return;
-
-    // if this block is of 64x32 shape or smaller
-    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
-    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
-
-    // a thread shouldn't load data outside of the matrix
-    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
-    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
-
-    simdgroup_half8x8  ma[4];
-    simdgroup_float8x8 mb[2];
-    simdgroup_float8x8 c_res[8];
-    for (int i = 0; i < 8; i++){
-        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
-    }
-    short il = (tiitg % THREAD_PER_ROW);
-
-    ushort offset1 = il/nl;
-
-    threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
-
-    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
-    device const float   * y = (device const float   *)(src1
-        + nb12 * id[1]
-        + nb11 * (id[0] % ne11)
-        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
-
-    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
-        // load data and store to threadgroup memory
-        half4x4 temp_a;
-        dequantize_func(x, il, temp_a);
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        for (int i = 0; i < 16; i++) {
-            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
-            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
-        }
-
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
-
-        il = (il + 2 < nl) ? il + 2 : il % 2;
-        x  = (il < 2) ? x + (2+nl-1)/nl : x;
-        y += BLOCK_SIZE_K;
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        // load matrices from threadgroup memory and conduct outer products
-        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
-        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
-
-        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
-            for (int i = 0; i < 4; i++) {
-                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
-            }
-            simdgroup_barrier(mem_flags::mem_none);
-            for (int i = 0; i < 2; i++) {
-                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
-            }
-
-            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
-            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
-
-            for (int i = 0; i < 8; i++){
-                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
-            }
-        }
-    }
-
-    {
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
-                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
-        for (int i = 0; i < 8; i++) {
-            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
-        }
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        device float * C = dst + (BLOCK_SIZE_M * r0);
-        if (sgitg == 0) {
-            for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
-                threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
-                int joff =  jid[0] * ne0 + jid[1] * ne0ne1;
-                for (int i = 0; i < n_rows; i++) {
-                    *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
-                }
-            }
-        }
-    }
-}
-
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-kernel void kernel_mul_mm_id(
-        device const   uchar * src0s,
-        device const   uchar * src1,
-        device         float * dst,
-        device const   uchar * ids,
-        constant     int64_t & nei0,
-        constant     int64_t & nei1,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        threadgroup    uchar * shared_memory [[threadgroup(0)]],
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    const int32_t i02 = tgpig.z;
-    tgpig.z = 0;
-
-    device const uchar * src0 = src0s + i02*nb02;
-
-    // row indices
-    threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
-
-    // TODO: parallelize this loop
-    int64_t _ne1 = 0;
-    for (ushort ii1 = 0; ii1 < nei1; ii1++) {
-        for (ushort ii0 = 0; ii0 < nei0; ii0++) {
-            int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
-            if (id == i02) {
-                //if (tiitg == 0) {
-                    rowids[_ne1] = ushort2(ii0, ii1);
-                //}
-                _ne1++;
-            }
-        }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
-        src0,
-        src1,
-        rowids,
-        dst,
-        ne00,
-        ne02,
-        nb01,
-        nb02,
-        ne11,
-        ne12,
-        nb10,
-        nb11,
-        nb12,
-        ne0,
-        _ne1,
-        ne0*ne1,
-        shared_memory,
-        tgpig,
-        tiitg,
-        sgitg);
-}
-
-#define QK_NL 16
-
-//
-// get rows
-//
-
-typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
-
-template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_f_t kernel_get_rows_f<float>;
-template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_f_t kernel_get_rows_f<half>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
-#endif
-
-typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
-
-template [[host_name("kernel_get_rows_q4_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_0,    2, dequantize_q4_0>;
-template [[host_name("kernel_get_rows_q4_1")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_1,    2, dequantize_q4_1>;
-template [[host_name("kernel_get_rows_q5_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_0,    2, dequantize_q5_0>;
-template [[host_name("kernel_get_rows_q5_1")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_1,    2, dequantize_q5_1>;
-template [[host_name("kernel_get_rows_q8_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q8_0,    2, dequantize_q8_0>;
-template [[host_name("kernel_get_rows_q2_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q2_K,    QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_get_rows_q3_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q3_K,    QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_get_rows_q4_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_K,    QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_get_rows_q5_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_K,    QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_get_rows_q6_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q6_K,    QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_get_rows_iq3_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq3_s,   QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_get_rows_iq2_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq2_s,   QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq1_s,   QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_get_rows_iq1_m")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq1_m,   QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl,  2,     dequantize_iq4_nl>;
-template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
-
-//
-// matrix-matrix multiplication
-//
-
-typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
-
-template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32>;
-template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mm_bf16_f32")]]    kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4,     1,     dequantize_bf16>;
-#endif
-template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_K,    QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q6_K,    QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq3_s,   QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_s,   QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq1_s,   QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq1_m,   QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq4_nl,  2,     dequantize_iq4_nl>;
-template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
-
-//
-// indirect matrix-matrix multiplication
-//
-
-typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
-
-template [[host_name("kernel_mul_mm_id_f32_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<float4x4,      1,     dequantize_f32>;
-template [[host_name("kernel_mul_mm_id_f16_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<half4x4,       1,     dequantize_f16>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mm_id_bf16_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4,     1,     dequantize_bf16>;
-#endif
-template [[host_name("kernel_mul_mm_id_q4_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0,    2,     dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_id_q4_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1,    2,     dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_id_q5_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0,    2,     dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_id_q5_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1,    2,     dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_id_q8_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0,    2,     dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_id_q2_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K,    QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_id_q3_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K,    QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_id_q4_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K,    QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_id_q5_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K,    QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_id_q6_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K,    QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_mul_mm_id_iq3_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s,   QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_mul_mm_id_iq2_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s,   QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s,   QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_id_iq1_m_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m,   QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl,  2,     dequantize_iq4_nl>;
-template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
-
-//
-// matrix-vector multiplication
-//
-
-typedef void (kernel_mul_mv_impl_t)(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb00,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne11,
-                   int64_t   ne12,
-                  uint64_t   nb10,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-                   uint3     tgpig,
-                   uint      tiisg);
-
-typedef void (kernel_mul_mv2_impl_t)(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg);
-
-template<kernel_mul_mv_impl_t impl_fn>
-void mmv_fn(
-        device const    char * src0,
-        device const    char * src1,
-        device         float * dst,
-                     int64_t   ne00,
-                     int64_t   ne01,
-                     int64_t   ne02,
-                    uint64_t   nb00,
-                    uint64_t   nb01,
-                    uint64_t   nb02,
-                    uint64_t   nb03,
-                     int64_t   ne10,
-                     int64_t   ne11,
-                     int64_t   ne12,
-                     int64_t   ne13,
-                    uint64_t   nb10,
-                    uint64_t   nb11,
-                    uint64_t   nb12,
-                    uint64_t   nb13,
-                     int64_t   ne0,
-                     int64_t   ne1,
-                    uint64_t   nb1,
-                        uint   r2,
-                        uint   r3,
-        threadgroup int8_t   * shared_values,
-        uint3                  tgpig,
-        uint                   tiitg,
-        uint                   tiisg,
-        uint                   sgitg) {
-    impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg);
-}
-
-template<kernel_mul_mv2_impl_t impl_fn>
-void mmv_fn(
-        device const    char * src0,
-        device const    char * src1,
-        device         float * dst,
-                     int64_t   ne00,
-                     int64_t   ne01,
-                     int64_t   ne02,
-                    uint64_t   nb00,
-                    uint64_t   nb01,
-                    uint64_t   nb02,
-                    uint64_t   nb03,
-                     int64_t   ne10,
-                     int64_t   ne11,
-                     int64_t   ne12,
-                     int64_t   ne13,
-                    uint64_t   nb10,
-                    uint64_t   nb11,
-                    uint64_t   nb12,
-                    uint64_t   nb13,
-                     int64_t   ne0,
-                     int64_t   ne1,
-                    uint64_t   nb1,
-                        uint   r2,
-                        uint   r3,
-        threadgroup int8_t   * shared_values,
-        uint3                  tgpig,
-        uint                   tiitg,
-        uint                   tiisg,
-        uint                   sgitg) {
-    impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
-}
-
-typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
-
-template<mul_mv_impl_fn_t impl_fn>
-kernel void kernel_mul_mv_id(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant     int64_t & nei0,
-        constant     int64_t & nei1,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        threadgroup int8_t   * shared_values [[threadgroup(0)]],
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int iid1 = tgpig.z/nei0;
-    const int idx = tgpig.z%nei0;
-
-    tgpig.z = 0;
-
-    const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx];
-
-    const int64_t i11 = idx % ne11;
-    const int64_t i12 = iid1;
-
-    const int64_t i1 = idx;
-    const int64_t i2 = i12;
-
-    device const char * src0_cur = src0s + i02*nb02;
-    device const char * src1_cur = src1  + i11*nb11 + i12*nb12;
-    device      float *  dst_cur = dst   + i1*ne0   + i2*ne1*ne0;
-
-    impl_fn(
-        /* src0 */ src0_cur,
-        /* src1 */ src1_cur,
-        /* dst  */ dst_cur,
-        /* ne00 */ ne00,
-        /* ne01 */ ne01,
-        /* ne02 */ 1, // ne02,
-        /* nb00 */ nb00,
-        /* nb01 */ nb01,
-        /* nb02 */ nb02,
-        /* nb03 */ nb02, // ne02 == 1
-        /* ne10 */ ne10,
-        /* ne11 */ 1, // ne11,
-        /* ne12 */ 1, // ne12,
-        /* ne13 */ 1, // ne13,
-        /* nb10 */ nb10,
-        /* nb11 */ nb11,
-        /* nb12 */ nb12,
-        /* ne13 */ nb12, // ne12 == 1
-        /* ne0  */ ne0,
-        /* ne1  */ 1, // ne1,
-        /* nb1  */ nb1,
-        /* r2   */ 1,
-        /* r3   */ 1,
-        shared_values,
-        tgpig,
-        tiitg,
-        tiisg,
-        sgitg);
-}
-
-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
-
-template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
-template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
-#endif
-template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
-
-kernel void kernel_pool_2d_max_f32(
-        device  const float * src0,
-        device        float * dst,
-        constant    int32_t & k0,
-        constant    int32_t & k1,
-        constant    int32_t & s0,
-        constant    int32_t & s1,
-        constant    int32_t & p0,
-        constant    int32_t & p1,
-        constant    int64_t & IH,
-        constant    int64_t & IW,
-        constant    int64_t & OH,
-        constant    int64_t & OW,
-        constant    int64_t & parallel_elements,
-        uint        gid[[thread_position_in_grid]]) {
-
-    if (gid >= parallel_elements) {
-        return;
-    }
-
-    const int idx = gid;
-    const int I_HW = IH * IW;
-    const int O_HW = OH * OW;
-    const int nc = idx / O_HW;
-    const int cur_oh = idx % O_HW / OW;
-    const int cur_ow = idx % O_HW % OW;
-
-    device const float * i_ptr = src0 + nc * I_HW;
-    device       float * o_ptr = dst  + nc * O_HW;
-
-    const int start_h = cur_oh * s1 - p1;
-    const int bh = MAX(0,  start_h);
-    const int eh = MIN(IH, start_h + k1);
-    const int start_w = cur_ow * s0 - p0;
-    const int bw = MAX(0,  start_w);
-    const int ew = MIN(IW, start_w + k0);
-
-    float res = -INFINITY;
-
-    for (int i = bh; i < eh; i += 1) {
-        for (int j = bw; j < ew; j += 1) {
-            res = MAX(res, i_ptr[i * IW + j]);
-        }
-    }
-
-    o_ptr[cur_oh * OW + cur_ow] = res;
-}
-
-kernel void kernel_pool_2d_avg_f32(
-        device  const float * src0,
-        device        float * dst,
-        constant    int32_t & k0,
-        constant    int32_t & k1,
-        constant    int32_t & s0,
-        constant    int32_t & s1,
-        constant    int32_t & p0,
-        constant    int32_t & p1,
-        constant    int64_t & IH,
-        constant    int64_t & IW,
-        constant    int64_t & OH,
-        constant    int64_t & OW,
-        constant    int64_t & parallel_elements,
-        uint        gid[[thread_position_in_grid]]) {
-
-    if (gid >= parallel_elements) {
-        return;
-    }
-
-    const int idx = gid;
-    const int I_HW = IH * IW;
-    const int O_HW = OH * OW;
-    const int nc = idx / O_HW;
-    const int cur_oh = idx % O_HW / OW;
-    const int cur_ow = idx % O_HW % OW;
-
-    device const float * i_ptr = src0 + nc * I_HW;
-    device       float * o_ptr = dst  + nc * O_HW;
-
-    const int start_h = cur_oh * s1 - p1;
-    const int bh = MAX(0,  start_h);
-    const int eh = MIN(IH, start_h + k1);
-    const int start_w = cur_ow * s0 - p0;
-    const int bw = MAX(0,  start_w);
-    const int ew = MIN(IW, start_w + k0);
-    // const float scale = 1. / ((eh - bh) * (ew - bw));
-    const float scale = 1. / (k0 * k1);
-
-    float res = 0;
-
-    for (int i = bh; i < eh; i += 1) {
-        for (int j = bw; j < ew; j += 1) {
-            float cur = i_ptr[i * IW + j];
-            res += cur * scale;
-        }
-    }
-
-    o_ptr[cur_oh * OW + cur_ow] = res;
-}
index ee47630256f5cbb985f37fa9077ca0430527e335..d1abb3cef0ec4a375755e3ae1cf7d3e0d98f0846 100644 (file)
@@ -3651,7 +3651,7 @@ static enum ggml_status ggml_metal_graph_compute(
         dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
 
         // wait for completion and check status of each command buffer
-        // needed to detect if the device ran out-of-memory for example (ggml/1881)
+        // needed to detect if the device ran out-of-memory for example (#1881)
         {
             id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
             [command_buffer waitUntilCompleted];
diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt
new file mode 100644 (file)
index 0000000..f3c0136
--- /dev/null
@@ -0,0 +1,100 @@
+if (NOT EXISTS $ENV{MUSA_PATH})
+    if (NOT EXISTS /opt/musa)
+        set(MUSA_PATH /usr/local/musa)
+    else()
+        set(MUSA_PATH /opt/musa)
+    endif()
+else()
+    set(MUSA_PATH $ENV{MUSA_PATH})
+endif()
+
+set(CMAKE_C_COMPILER "${MUSA_PATH}/bin/clang")
+set(CMAKE_C_EXTENSIONS OFF)
+set(CMAKE_CXX_COMPILER "${MUSA_PATH}/bin/clang++")
+set(CMAKE_CXX_EXTENSIONS OFF)
+
+list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
+
+find_package(MUSAToolkit)
+
+if (MUSAToolkit_FOUND)
+    message(STATUS "MUSA Toolkit found")
+
+    file(GLOB   GGML_HEADERS_MUSA "../ggml-cuda/*.cuh")
+    list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
+
+    file(GLOB   GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
+    file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
+    list(APPEND GGML_SOURCES_MUSA ${SRCS})
+    file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
+    list(APPEND GGML_SOURCES_MUSA ${SRCS})
+
+    if (GGML_CUDA_FA_ALL_QUANTS)
+        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
+        list(APPEND GGML_SOURCES_MUSA ${SRCS})
+        add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+    else()
+        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
+        list(APPEND GGML_SOURCES_MUSA ${SRCS})
+        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
+        list(APPEND GGML_SOURCES_MUSA ${SRCS})
+        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
+        list(APPEND GGML_SOURCES_MUSA ${SRCS})
+    endif()
+
+    set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)
+    foreach(SOURCE ${GGML_SOURCES_MUSA})
+        set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22")
+    endforeach()
+
+    add_library(ggml-musa
+                ${GGML_HEADERS_MUSA}
+                ${GGML_SOURCES_MUSA})
+
+    target_link_libraries(ggml-musa PRIVATE ggml-base)
+    target_include_directories(ggml-musa PRIVATE . ..)
+
+    # TODO: do not use CUDA definitions for MUSA
+    target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
+
+    add_compile_definitions(GGML_USE_MUSA)
+    add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
+
+    if (GGML_CUDA_GRAPHS)
+        add_compile_definitions(GGML_CUDA_USE_GRAPHS)
+    endif()
+
+    if (GGML_CUDA_FORCE_MMQ)
+        add_compile_definitions(GGML_CUDA_FORCE_MMQ)
+    endif()
+
+    if (GGML_CUDA_FORCE_CUBLAS)
+        add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
+    endif()
+
+    if (GGML_CUDA_NO_VMM)
+        add_compile_definitions(GGML_CUDA_NO_VMM)
+    endif()
+
+    if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
+        add_compile_definitions(GGML_CUDA_F16)
+    endif()
+
+    if (GGML_CUDA_NO_PEER_COPY)
+        add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
+    endif()
+
+    if (GGML_STATIC)
+        target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
+    else()
+        target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
+    endif()
+
+    if (GGML_CUDA_NO_VMM)
+        # No VMM requested, no need to link directly with the musa driver lib (libmusa.so)
+    else()
+        target_link_libraries(ggml-musa PRIVATE MUSA::musa_driver)
+    endif()
+else()
+    message(FATAL_ERROR "MUSA Toolkit not found")
+endif()
diff --git a/ggml/src/ggml-rpc.cpp b/ggml/src/ggml-rpc.cpp
deleted file mode 100644 (file)
index 8a772f2..0000000
+++ /dev/null
@@ -1,1403 +0,0 @@
-#include "ggml-rpc.h"
-#include "ggml-impl.h"
-#include "ggml-backend-impl.h"
-
-#include <cinttypes>
-#include <string>
-#include <vector>
-#include <memory>
-#include <mutex>
-#include <unordered_map>
-#include <unordered_set>
-#ifdef _WIN32
-#  define WIN32_LEAN_AND_MEAN
-#  ifndef NOMINMAX
-#     define NOMINMAX
-#  endif
-#  include <windows.h>
-#  include <winsock2.h>
-#else
-#  include <arpa/inet.h>
-#  include <sys/socket.h>
-#  include <sys/types.h>
-#  include <netinet/in.h>
-#  include <netinet/tcp.h>
-#  include <netdb.h>
-#  include <unistd.h>
-#endif
-#include <cstring>
-
-#define UNUSED GGML_UNUSED
-
-#define GGML_DEBUG 0
-#if (GGML_DEBUG >= 1)
-#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG(...)
-#endif
-
-#ifdef _WIN32
-typedef SOCKET sockfd_t;
-using ssize_t = __int64;
-#else
-typedef int sockfd_t;
-#endif
-
-// cross-platform socket
-struct socket_t {
-    sockfd_t fd;
-    socket_t(sockfd_t fd) : fd(fd) {}
-    ~socket_t() {
-        GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
-#ifdef _WIN32
-        closesocket(this->fd);
-#else
-        close(this->fd);
-#endif
-    }
-};
-
-// all RPC structures must be packed
-#pragma pack(push, 1)
-// ggml_tensor is serialized into rpc_tensor
-struct rpc_tensor {
-    uint64_t id;
-    uint32_t type;
-    uint64_t buffer;
-    uint32_t ne[GGML_MAX_DIMS];
-    uint32_t nb[GGML_MAX_DIMS];
-    uint32_t op;
-    int32_t  op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
-    int32_t  flags;
-    uint64_t src[GGML_MAX_SRC];
-    uint64_t view_src;
-    uint64_t view_offs;
-    uint64_t data;
-    char name[GGML_MAX_NAME];
-
-    char padding[4];
-};
-
-static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
-
-// RPC commands
-enum rpc_cmd {
-    RPC_CMD_ALLOC_BUFFER = 0,
-    RPC_CMD_GET_ALIGNMENT,
-    RPC_CMD_GET_MAX_SIZE,
-    RPC_CMD_BUFFER_GET_BASE,
-    RPC_CMD_FREE_BUFFER,
-    RPC_CMD_BUFFER_CLEAR,
-    RPC_CMD_SET_TENSOR,
-    RPC_CMD_GET_TENSOR,
-    RPC_CMD_COPY_TENSOR,
-    RPC_CMD_GRAPH_COMPUTE,
-    RPC_CMD_GET_DEVICE_MEMORY,
-    RPC_CMD_COUNT,
-};
-
-struct rpc_msg_alloc_buffer_req {
-    uint64_t size;
-};
-
-struct rpc_msg_alloc_buffer_rsp {
-    uint64_t remote_ptr;
-    uint64_t remote_size;
-};
-
-struct rpc_msg_get_alignment_rsp {
-    uint64_t alignment;
-};
-
-struct rpc_msg_get_max_size_rsp {
-    uint64_t max_size;
-};
-
-struct rpc_msg_buffer_get_base_req {
-    uint64_t remote_ptr;
-};
-
-struct rpc_msg_buffer_get_base_rsp {
-    uint64_t base_ptr;
-};
-
-struct rpc_msg_free_buffer_req {
-    uint64_t remote_ptr;
-};
-
-struct rpc_msg_buffer_clear_req {
-    uint64_t remote_ptr;
-    uint8_t value;
-};
-
-struct rpc_msg_get_tensor_req {
-    rpc_tensor tensor;
-    uint64_t offset;
-    uint64_t size;
-};
-
-struct rpc_msg_copy_tensor_req {
-    rpc_tensor src;
-    rpc_tensor dst;
-};
-
-struct rpc_msg_copy_tensor_rsp {
-    uint8_t result;
-};
-
-struct rpc_msg_graph_compute_rsp {
-    uint8_t result;
-};
-
-struct rpc_msg_get_device_memory_rsp {
-    uint64_t free_mem;
-    uint64_t total_mem;
-};
-#pragma pack(pop)
-
-// RPC data structures
-
-static ggml_guid_t ggml_backend_rpc_guid() {
-    static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
-    return &guid;
-}
-
-struct ggml_backend_rpc_buffer_type_context {
-    std::string endpoint;
-    std::string name;
-    size_t alignment;
-    size_t max_size;
-};
-
-struct ggml_backend_rpc_context {
-    std::string endpoint;
-    std::string name;
-};
-
-struct ggml_backend_rpc_buffer_context {
-    std::shared_ptr<socket_t> sock;
-    std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
-    uint64_t remote_ptr;
-};
-
-// RPC helper functions
-
-static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
-#ifdef _WIN32
-    if (fd == INVALID_SOCKET) {
-        return nullptr;
-    }
-#else
-    if (fd < 0) {
-        return nullptr;
-    }
-#endif
-    return std::make_shared<socket_t>(fd);
-}
-
-static bool set_no_delay(sockfd_t sockfd) {
-    int flag = 1;
-    // set TCP_NODELAY to disable Nagle's algorithm
-    int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
-    return ret == 0;
-}
-
-static bool set_reuse_addr(sockfd_t sockfd) {
-    int flag = 1;
-    int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
-    return ret == 0;
-}
-
-static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
-    struct sockaddr_in addr;
-    auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
-    auto sock_ptr = make_socket(sockfd);
-    if (sock_ptr == nullptr) {
-        return nullptr;
-    }
-    if (!set_no_delay(sockfd)) {
-        fprintf(stderr, "Failed to set TCP_NODELAY\n");
-        return nullptr;
-    }
-    addr.sin_family = AF_INET;
-    addr.sin_port = htons(port);
-    struct hostent * server = gethostbyname(host);
-    if (server == NULL) {
-        fprintf(stderr, "Cannot resolve host '%s'\n", host);
-        return nullptr;
-    }
-    memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
-    if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
-        return nullptr;
-    }
-    return sock_ptr;
-}
-
-static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
-    auto client_socket_fd = accept(srv_sockfd, NULL, NULL);
-    auto client_socket = make_socket(client_socket_fd);
-    if (client_socket == nullptr) {
-        return nullptr;
-    }
-    if (!set_no_delay(client_socket_fd)) {
-        fprintf(stderr, "Failed to set TCP_NODELAY\n");
-        return nullptr;
-    }
-    return client_socket;
-}
-
-static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
-    auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
-    auto sock = make_socket(sockfd);
-    if (sock == nullptr) {
-        return nullptr;
-    }
-    if (!set_reuse_addr(sockfd)) {
-        fprintf(stderr, "Failed to set SO_REUSEADDR\n");
-        return nullptr;
-    }
-    if (inet_addr(host) == INADDR_NONE) {
-        fprintf(stderr, "Invalid host address: %s\n", host);
-        return nullptr;
-    }
-    struct sockaddr_in serv_addr;
-    serv_addr.sin_family = AF_INET;
-    serv_addr.sin_addr.s_addr = inet_addr(host);
-    serv_addr.sin_port = htons(port);
-
-    if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
-        return nullptr;
-    }
-    if (listen(sockfd, 1) < 0) {
-        return nullptr;
-    }
-    return sock;
-}
-
-static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
-    size_t bytes_sent = 0;
-    while (bytes_sent < size) {
-        ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0);
-        if (n < 0) {
-            return false;
-        }
-        bytes_sent += n;
-    }
-    return true;
-}
-
-static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
-    size_t bytes_recv = 0;
-    while (bytes_recv < size) {
-        ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0);
-        if (n <= 0) {
-            return false;
-        }
-        bytes_recv += n;
-    }
-    return true;
-}
-
-static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
-    if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
-        return false;
-    }
-    return send_data(sockfd, msg, msg_size);
-}
-
-static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
-    uint64_t size;
-    if (!recv_data(sockfd, &size, sizeof(size))) {
-        return false;
-    }
-    if (size != msg_size) {
-        return false;
-    }
-    return recv_data(sockfd, msg, msg_size);
-}
-
-static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
-    uint64_t size;
-    if (!recv_data(sockfd, &size, sizeof(size))) {
-        return false;
-    }
-    try {
-        input.resize(size);
-    } catch (const std::bad_alloc & e) {
-        fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
-        return false;
-    }
-    return recv_data(sockfd, input.data(), size);
-}
-
-static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
-    size_t pos = endpoint.find(':');
-    if (pos == std::string::npos) {
-        return false;
-    }
-    host = endpoint.substr(0, pos);
-    port = std::stoi(endpoint.substr(pos + 1));
-    return true;
-}
-
-// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
-// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
-static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
-    uint8_t cmd_byte = cmd;
-    if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
-        return false;
-    }
-    if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
-        return false;
-    }
-    if (!send_data(sock->fd, input, input_size)) {
-        return false;
-    }
-    // TODO: currently the output_size is always known, do we need support for commands with variable output size?
-    // even if we do, we can skip sending output_size from the server for commands with known output size
-    uint64_t out_size;
-    if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
-        return false;
-    }
-    if (out_size != output_size) {
-        return false;
-    }
-    if (!recv_data(sock->fd, output, output_size)) {
-        return false;
-    }
-    return true;
-}
-
-// RPC client-side implementation
-
-static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
-    static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
-    static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
-    static bool initialized = false;
-
-    auto it = sockets.find(endpoint);
-    if (it != sockets.end()) {
-        if (auto sock = it->second.lock()) {
-            return sock;
-        }
-    }
-    std::string host;
-    int port;
-    if (!parse_endpoint(endpoint, host, port)) {
-        return nullptr;
-    }
-#ifdef _WIN32
-    if (!initialized) {
-        WSADATA wsaData;
-        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
-        if (res != 0) {
-            return nullptr;
-        }
-        initialized = true;
-    }
-#else
-    UNUSED(initialized);
-#endif
-    auto sock = socket_connect(host.c_str(), port);
-    if (sock == nullptr) {
-        return nullptr;
-    }
-    GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
-    sockets[endpoint] = sock;
-    return sock;
-}
-
-static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    rpc_msg_free_buffer_req request = {ctx->remote_ptr};
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
-    GGML_ASSERT(status);
-    delete ctx;
-}
-
-static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
-    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
-        return ctx->base_cache[buffer];
-    }
-    rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
-    rpc_msg_buffer_get_base_rsp response;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
-    GGML_ASSERT(status);
-    void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
-    ctx->base_cache[buffer] = base_ptr;
-    return base_ptr;
-}
-
-static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
-    rpc_tensor result;
-    result.id = reinterpret_cast<uint64_t>(tensor);
-    result.type = tensor->type;
-    if (tensor->buffer) {
-        ggml_backend_buffer_t buffer = tensor->buffer;
-        ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-        result.buffer = ctx->remote_ptr;
-    } else {
-        result.buffer = 0;
-    }
-    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
-        result.ne[i] = tensor->ne[i];
-        result.nb[i] = tensor->nb[i];
-    }
-    result.op = tensor->op;
-    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
-        result.op_params[i] = tensor->op_params[i];
-    }
-    result.flags = tensor->flags;
-    for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
-        result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
-    }
-    result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
-    result.view_offs = tensor->view_offs;
-    result.data = reinterpret_cast<uint64_t>(tensor->data);
-    snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
-    return result;
-}
-
-static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
-    UNUSED(buffer);
-    if (ggml_is_quantized(tensor->type)) {
-        // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
-        GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
-    }
-}
-
-static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
-    size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
-    std::vector<uint8_t> input(input_size, 0);
-    rpc_tensor rpc_tensor = serialize_tensor(tensor);
-    memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
-    memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
-    memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
-    GGML_ASSERT(status);
-}
-
-static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    rpc_msg_get_tensor_req request;
-    request.tensor = serialize_tensor(tensor);
-    request.offset = offset;
-    request.size = size;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
-    GGML_ASSERT(status);
-}
-
-static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
-    // check if src and dst are on the same server
-    ggml_backend_buffer_t src_buffer = src->buffer;
-    ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
-    ggml_backend_buffer_t dst_buffer = dst->buffer;
-    ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
-    if (src_ctx->sock != dst_ctx->sock) {
-        return false;
-    }
-    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    rpc_msg_copy_tensor_req request;
-    request.src = serialize_tensor(src);
-    request.dst = serialize_tensor(dst);
-    rpc_msg_copy_tensor_rsp response;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
-    GGML_ASSERT(status);
-    return response.result;
-}
-
-static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
-    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
-    GGML_ASSERT(status);
-}
-
-static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
-    /* .free_buffer     = */ ggml_backend_rpc_buffer_free_buffer,
-    /* .get_base        = */ ggml_backend_rpc_buffer_get_base,
-    /* .init_tensor     = */ ggml_backend_rpc_buffer_init_tensor,
-    /* .memset_tensor   = */ NULL,
-    /* .set_tensor      = */ ggml_backend_rpc_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_rpc_buffer_get_tensor,
-    /* .cpy_tensor      = */ ggml_backend_rpc_buffer_cpy_tensor,
-    /* .clear           = */ ggml_backend_rpc_buffer_clear,
-    /* .reset           = */ NULL,
-};
-
-static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
-    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    return buft_ctx->name.c_str();
-}
-
-static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    rpc_msg_alloc_buffer_req request = {size};
-    rpc_msg_alloc_buffer_rsp response;
-    auto sock = get_socket(buft_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
-    GGML_ASSERT(status);
-    if (response.remote_ptr != 0) {
-        ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
-            ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
-            response.remote_size);
-        return buffer;
-    } else {
-        return nullptr;
-    }
-}
-
-static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
-    rpc_msg_get_alignment_rsp response;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
-    GGML_ASSERT(status);
-    return response.alignment;
-}
-
-static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    return buft_ctx->alignment;
-}
-
-static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
-    rpc_msg_get_max_size_rsp response;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
-    GGML_ASSERT(status);
-    return response.max_size;
-}
-
-static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
-    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    return buft_ctx->max_size;
-}
-
-static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
-    UNUSED(buft);
-    return ggml_nbytes(tensor);
-}
-
-static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
-    /* .get_name         = */ ggml_backend_rpc_buffer_type_name,
-    /* .alloc_buffer     = */ ggml_backend_rpc_buffer_type_alloc_buffer,
-    /* .get_alignment    = */ ggml_backend_rpc_buffer_type_get_alignment,
-    /* .get_max_size     = */ ggml_backend_rpc_get_max_size,
-    /* .get_alloc_size   = */ ggml_backend_rpc_buffer_type_get_alloc_size,
-    /* .is_host          = */ NULL,
-};
-
-static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
-    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-
-    return rpc_ctx->name.c_str();
-}
-
-static void ggml_backend_rpc_free(ggml_backend_t backend) {
-    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    delete rpc_ctx;
-    delete backend;
-}
-
-static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
-    UNUSED(backend);
-    // this is no-op because we don't have any async operations
-}
-
-static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
-    if (tensor == nullptr) {
-        return;
-    }
-    if (visited.find(tensor) != visited.end()) {
-        return;
-    }
-    visited.insert(tensor);
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        add_tensor(tensor->src[i], tensors, visited);
-    }
-    add_tensor(tensor->view_src, tensors, visited);
-    tensors.push_back(serialize_tensor(tensor));
-}
-
-static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
-    uint32_t n_nodes = cgraph->n_nodes;
-    std::vector<rpc_tensor> tensors;
-    std::unordered_set<ggml_tensor*> visited;
-    for (uint32_t i = 0; i < n_nodes; i++) {
-        add_tensor(cgraph->nodes[i], tensors, visited);
-    }
-    // serialization format:
-    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
-    uint32_t n_tensors = tensors.size();
-    int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
-    output.resize(output_size, 0);
-    memcpy(output.data(), &n_nodes, sizeof(n_nodes));
-    for (uint32_t i = 0; i < n_nodes; i++) {
-        memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
-    }
-    uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
-    *out_ntensors = n_tensors;
-    rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
-    memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
-}
-
-static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    std::vector<uint8_t> input;
-    serialize_graph(cgraph, input);
-    rpc_msg_graph_compute_rsp response;
-    auto sock = get_socket(rpc_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
-    GGML_ASSERT(status);
-    return (enum ggml_status)response.result;
-}
-
-static ggml_backend_i ggml_backend_rpc_interface = {
-    /* .get_name                = */ ggml_backend_rpc_name,
-    /* .free                    = */ ggml_backend_rpc_free,
-    /* .set_tensor_async        = */ NULL,
-    /* .get_tensor_async        = */ NULL,
-    /* .cpy_tensor_async        = */ NULL,
-    /* .synchronize             = */ ggml_backend_rpc_synchronize,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_rpc_graph_compute,
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
-};
-
-GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
-    static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
-    // NOTE: buffer types are allocated and never freed; this is by design
-    static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
-    auto it = buft_map.find(endpoint);
-    if (it != buft_map.end()) {
-        return it->second;
-    }
-    auto sock = get_socket(endpoint);
-    if (sock == nullptr) {
-        fprintf(stderr, "Failed to connect to %s\n", endpoint);
-        return nullptr;
-    }
-    size_t alignment = get_alignment(sock);
-    size_t max_size = get_max_size(sock);
-    ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
-        /* .endpoint  = */ endpoint,
-        /* .name      = */ "RPC[" + std::string(endpoint) + "]",
-        /* .alignment = */ alignment,
-        /* .max_size  = */ max_size
-    };
-
-    ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
-        /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
-        /* .device  = */ ggml_backend_rpc_add_device(endpoint),
-        /* .context = */ buft_ctx
-    };
-    buft_map[endpoint] = buft;
-    return buft;
-}
-
-ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
-    ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
-        /* .endpoint  = */ endpoint,
-        /* .name      = */ "RPC[" + std::string(endpoint) + "]",
-    };
-
-    ggml_backend_t backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_rpc_guid(),
-        /* .interface = */ ggml_backend_rpc_interface,
-        /* .device    = */ ggml_backend_rpc_add_device(endpoint),
-        /* .context   = */ ctx
-    };
-    return backend;
-}
-
-GGML_API bool ggml_backend_is_rpc(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
-}
-
-static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
-    rpc_msg_get_device_memory_rsp response;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
-    GGML_ASSERT(status);
-    *free = response.free_mem;
-    *total = response.total_mem;
-}
-
-GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
-    auto sock = get_socket(endpoint);
-    if (sock == nullptr) {
-        *free = 0;
-        *total = 0;
-        return;
-    }
-    get_device_memory(sock, free, total);
-}
-
-// RPC server-side implementation
-
-class rpc_server {
-public:
-    rpc_server(ggml_backend_t backend) : backend(backend) {}
-    ~rpc_server();
-
-    void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
-    void get_alignment(rpc_msg_get_alignment_rsp & response);
-    void get_max_size(rpc_msg_get_max_size_rsp & response);
-    bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
-    bool free_buffer(const rpc_msg_free_buffer_req & request);
-    bool buffer_clear(const rpc_msg_buffer_clear_req & request);
-    bool set_tensor(const std::vector<uint8_t> & input);
-    bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
-    bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
-    bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
-
-private:
-    ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
-    ggml_tensor * create_node(uint64_t id,
-                              struct ggml_context * ctx,
-                              const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
-                              std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
-
-
-    ggml_backend_t backend;
-    std::unordered_set<ggml_backend_buffer_t> buffers;
-};
-
-void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
-    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
-    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
-    response.remote_ptr = 0;
-    response.remote_size = 0;
-    if (buffer != nullptr) {
-        response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
-        response.remote_size = buffer->size;
-        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
-        buffers.insert(buffer);
-    } else {
-        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
-    }
-}
-
-void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
-    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
-    size_t alignment = ggml_backend_buft_get_alignment(buft);
-    GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
-    response.alignment = alignment;
-}
-
-void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
-    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
-    size_t max_size = ggml_backend_buft_get_max_size(buft);
-    GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
-    response.max_size = max_size;
-}
-
-bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
-    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
-    if (buffers.find(buffer) == buffers.end()) {
-        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
-        return false;
-    }
-    void * base = ggml_backend_buffer_get_base(buffer);
-    response.base_ptr = reinterpret_cast<uint64_t>(base);
-    return true;
-}
-
-bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
-    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
-    if (buffers.find(buffer) == buffers.end()) {
-        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
-        return false;
-    }
-    ggml_backend_buffer_free(buffer);
-    buffers.erase(buffer);
-    return true;
-}
-
-bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
-    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
-    if (buffers.find(buffer) == buffers.end()) {
-        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
-        return false;
-    }
-    ggml_backend_buffer_clear(buffer, request.value);
-    return true;
-}
-
-ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
-    ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
-        tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
-    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
-        result->nb[i] = tensor->nb[i];
-    }
-    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
-    if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
-        result->buffer = nullptr;
-    }
-
-    if (result->buffer) {
-        // require that the tensor data does not go beyond the buffer end
-        uint64_t tensor_size = (uint64_t) ggml_nbytes(result);
-        uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
-        uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
-        GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow
-        GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);
-    }
-
-    result->op = (ggml_op) tensor->op;
-    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
-        result->op_params[i] = tensor->op_params[i];
-    }
-    result->flags = tensor->flags;
-    result->data = reinterpret_cast<void *>(tensor->data);
-    ggml_set_name(result, tensor->name);
-    return result;
-}
-
-
-bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
-    // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
-    if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
-        return false;
-    }
-    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
-    uint64_t offset;
-    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
-    const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
-
-    struct ggml_init_params params {
-        /*.mem_size   =*/ ggml_tensor_overhead(),
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true,
-    };
-    struct ggml_context * ctx = ggml_init(params);
-    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
-    if (tensor == nullptr) {
-        GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
-        ggml_free(ctx);
-        return false;
-    }
-    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
-
-    // sanitize tensor->data
-    {
-        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
-        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
-
-        if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
-            GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
-        }
-    }
-
-    const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
-    ggml_backend_tensor_set(tensor, data, offset, size);
-    ggml_free(ctx);
-    return true;
-}
-
-bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
-    struct ggml_init_params params {
-        /*.mem_size   =*/ ggml_tensor_overhead(),
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true,
-    };
-    struct ggml_context * ctx = ggml_init(params);
-    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
-    if (tensor == nullptr) {
-        GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
-        ggml_free(ctx);
-        return false;
-    }
-    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
-
-    // sanitize tensor->data
-    {
-        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
-        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
-
-        if (request.tensor.data + request.offset < p0 ||
-            request.tensor.data + request.offset >= p1 ||
-            request.size > (p1 - request.tensor.data - request.offset)) {
-                GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
-        }
-    }
-
-    response.resize(request.size, 0);
-    ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
-    ggml_free(ctx);
-    return true;
-}
-
-bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
-    struct ggml_init_params params {
-        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true,
-    };
-    struct ggml_context * ctx = ggml_init(params);
-    ggml_tensor * src = deserialize_tensor(ctx, &request.src);
-    ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
-    if (src == nullptr || dst == nullptr) {
-        GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
-        ggml_free(ctx);
-        return false;
-    }
-    GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
-    response.result = ggml_backend_buffer_copy_tensor(src, dst);
-    ggml_free(ctx);
-    return true;
-}
-
-ggml_tensor * rpc_server::create_node(uint64_t id,
-                                      struct ggml_context * ctx,
-                                      const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
-                                      std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
-    if (id == 0) {
-        return nullptr;
-    }
-    if (tensor_map.find(id) != tensor_map.end()) {
-        return tensor_map[id];
-    }
-    const rpc_tensor * tensor = tensor_ptrs.at(id);
-    struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
-    if (result == nullptr) {
-        return nullptr;
-    }
-    tensor_map[id] = result;
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
-    }
-    result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
-    result->view_offs = tensor->view_offs;
-    return result;
-}
-
-bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
-    // serialization format:
-    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
-    if (input.size() < sizeof(uint32_t)) {
-        return false;
-    }
-    uint32_t n_nodes;
-    memcpy(&n_nodes, input.data(), sizeof(n_nodes));
-    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
-        return false;
-    }
-    const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
-    uint32_t n_tensors;
-    memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
-    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
-        return false;
-    }
-    const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
-    GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
-
-    size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true,
-    };
-    struct ggml_context * ctx = ggml_init(params);
-    struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
-    graph->n_nodes = n_nodes;
-    std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
-    for (uint32_t i = 0; i < n_tensors; i++) {
-        tensor_ptrs[tensors[i].id] = &tensors[i];
-    }
-    std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
-    for (uint32_t i = 0; i < n_nodes; i++) {
-        int64_t id;
-        memcpy(&id, &nodes[i], sizeof(id));
-        graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
-    }
-    ggml_status status = ggml_backend_graph_compute(backend, graph);
-    response.result = status;
-    ggml_free(ctx);
-    return true;
-}
-
-rpc_server::~rpc_server() {
-    for (auto buffer : buffers) {
-        ggml_backend_buffer_free(buffer);
-    }
-}
-
-static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
-    rpc_server server(backend);
-    while (true) {
-        uint8_t cmd;
-        if (!recv_data(sockfd, &cmd, 1)) {
-            break;
-        }
-        if (cmd >= RPC_CMD_COUNT) {
-            // fail fast if the command is invalid
-            fprintf(stderr, "Unknown command: %d\n", cmd);
-            break;
-        }
-        switch (cmd) {
-            case RPC_CMD_ALLOC_BUFFER: {
-                rpc_msg_alloc_buffer_req request;
-                if (!recv_msg(sockfd, &request, sizeof(request))) {
-                    return;
-                }
-                rpc_msg_alloc_buffer_rsp response;
-                server.alloc_buffer(request, response);
-                if (!send_msg(sockfd, &response, sizeof(response))) {
-                    return;
-                }
-                break;
-            }
-            case RPC_CMD_GET_ALIGNMENT: {
-                if (!recv_msg(sockfd, nullptr, 0)) {
-                    return;
-                }
-                rpc_msg_get_alignment_rsp response;
-                server.get_alignment(response);
-                if (!send_msg(sockfd, &response, sizeof(response))) {
-                    return;
-                }
-                break;
-            }
-            case RPC_CMD_GET_MAX_SIZE: {
-                if (!recv_msg(sockfd, nullptr, 0)) {
-                    return;
-                }
-                rpc_msg_get_max_size_rsp response;
-                server.get_max_size(response);
-                if (!send_msg(sockfd, &response, sizeof(response))) {
-                    return;
-                }
-                break;
-            }
-            case RPC_CMD_BUFFER_GET_BASE: {
-                rpc_msg_buffer_get_base_req request;
-                if (!recv_msg(sockfd, &request, sizeof(request))) {
-                    return;
-                }
-                rpc_msg_buffer_get_base_rsp response;
-                if (!server.buffer_get_base(request, response)) {
-                    return;
-                }
-                if (!send_msg(sockfd, &response, sizeof(response))) {
-                    return;
-                }
-                break;
-            }
-            case RPC_CMD_FREE_BUFFER: {
-                rpc_msg_free_buffer_req request;
-                if (!recv_msg(sockfd, &request, sizeof(request))) {
-                    return;
-                }
-                if (!server.free_buffer(request)) {
-                    return;
-                }
-                if (!send_msg(sockfd, nullptr, 0)) {
-                    return;
-                }
-                break;
-            }
-            case RPC_CMD_BUFFER_CLEAR: {
-                rpc_msg_buffer_clear_req request;
-                if (!recv_msg(sockfd, &request, sizeof(request))) {
-                    return;
-                }
-                if (!server.buffer_clear(request)) {
-                    return;
-                }
-                if (!send_msg(sockfd, nullptr, 0)) {
-                    return;
-                }
-                break;
-            }
-            case RPC_CMD_SET_TENSOR: {
-                std::vector<uint8_t> input;
-                if (!recv_msg(sockfd, input)) {
-                    return;
-                }
-                if (!server.set_tensor(input)) {
-                    return;
-                }
-                if (!send_msg(sockfd, nullptr, 0)) {
-                    return;
-                }
-                break;
-            }
-            case RPC_CMD_GET_TENSOR: {
-                rpc_msg_get_tensor_req request;
-                if (!recv_msg(sockfd, &request, sizeof(request))) {
-                    return;
-                }
-                std::vector<uint8_t> response;
-                if (!server.get_tensor(request, response)) {
-                    return;
-                }
-                if (!send_msg(sockfd, response.data(), response.size())) {
-                    return;
-                }
-                break;
-            }
-            case RPC_CMD_COPY_TENSOR: {
-                rpc_msg_copy_tensor_req request;
-                if (!recv_msg(sockfd, &request, sizeof(request))) {
-                    return;
-                }
-                rpc_msg_copy_tensor_rsp response;
-                if (!server.copy_tensor(request, response)) {
-                    return;
-                }
-                if (!send_msg(sockfd, &response, sizeof(response))) {
-                    return;
-                }
-                break;
-            }
-            case RPC_CMD_GRAPH_COMPUTE: {
-                std::vector<uint8_t> input;
-                if (!recv_msg(sockfd, input)) {
-                    return;
-                }
-                rpc_msg_graph_compute_rsp response;
-                if (!server.graph_compute(input, response)) {
-                    return;
-                }
-                if (!send_msg(sockfd, &response, sizeof(response))) {
-                    return;
-                }
-                break;
-            }
-            case RPC_CMD_GET_DEVICE_MEMORY: {
-                if (!recv_msg(sockfd, nullptr, 0)) {
-                    return;
-                }
-                rpc_msg_get_device_memory_rsp response;
-                response.free_mem = free_mem;
-                response.total_mem = total_mem;
-                if (!send_msg(sockfd, &response, sizeof(response))) {
-                    return;
-                }
-                break;
-            }
-            default: {
-                fprintf(stderr, "Unknown command: %d\n", cmd);
-                return;
-            }
-        }
-    }
-}
-
-void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
-    std::string host;
-    int port;
-    if (!parse_endpoint(endpoint, host, port)) {
-        return;
-    }
-#ifdef _WIN32
-    {
-        WSADATA wsaData;
-        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
-        if (res != 0) {
-            fprintf(stderr, "WSAStartup failed: %d\n", res);
-            return;
-        }
-    }
-#endif
-    auto server_socket = create_server_socket(host.c_str(), port);
-    if (server_socket == nullptr) {
-        fprintf(stderr, "Failed to create server socket\n");
-        return;
-    }
-    while (true) {
-        auto client_socket = socket_accept(server_socket->fd);
-        if (client_socket == nullptr) {
-            fprintf(stderr, "Failed to accept client connection\n");
-            return;
-        }
-        printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
-        fflush(stdout);
-        rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
-        printf("Client connection closed\n");
-        fflush(stdout);
-    }
-#ifdef _WIN32
-    WSACleanup();
-#endif
-}
-
-// device interface
-
-struct ggml_backend_rpc_device_context {
-    std::string endpoint;
-    std::string name;
-};
-
-static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
-    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
-
-    return ctx->name.c_str();
-}
-
-static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
-    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
-
-    return ctx->name.c_str();
-}
-
-static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
-    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
-
-    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
-
-    UNUSED(dev);
-}
-
-static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
-    // TODO: obtain value from the server
-    return GGML_BACKEND_DEVICE_TYPE_GPU;
-
-    UNUSED(dev);
-}
-
-static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
-    props->name        = ggml_backend_rpc_device_get_name(dev);
-    props->description = ggml_backend_rpc_device_get_description(dev);
-    props->type        = ggml_backend_rpc_device_get_type(dev);
-    ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
-    props->caps = {
-        /* .async                 = */ false,
-        /* .host_buffer           = */ false,
-        /* .buffer_from_host_ptr  = */ false,
-        /* .events                = */ false,
-    };
-}
-
-static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
-    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
-
-    return ggml_backend_rpc_init(ctx->endpoint.c_str());
-
-    UNUSED(params);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
-    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
-
-    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
-
-    UNUSED(dev);
-}
-
-static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-    UNUSED(dev);
-    UNUSED(op);
-    //TODO: call the remote backend and cache the results
-    return true;
-}
-
-static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
-        return false;
-    }
-    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
-    return buft_ctx->endpoint == dev_ctx->endpoint;
-}
-
-static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
-    /* .get_name             = */ ggml_backend_rpc_device_get_name,
-    /* .get_description      = */ ggml_backend_rpc_device_get_description,
-    /* .get_memory           = */ ggml_backend_rpc_device_get_memory,
-    /* .get_type             = */ ggml_backend_rpc_device_get_type,
-    /* .get_props            = */ ggml_backend_rpc_device_get_props,
-    /* .init_backend         = */ ggml_backend_rpc_device_init,
-    /* .get_buffer_type      = */ ggml_backend_rpc_device_get_buffer_type,
-    /* .get_host_buffer_type = */ NULL,
-    /* .buffer_from_host_ptr = */ NULL,
-    /* .supports_op          = */ ggml_backend_rpc_device_supports_op,
-    /* .supports_buft        = */ ggml_backend_rpc_device_supports_buft,
-    /* .offload_op           = */ NULL,
-    /* .event_new            = */ NULL,
-    /* .event_free           = */ NULL,
-    /* .event_synchronize    = */ NULL,
-};
-
-// backend reg interface
-
-static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
-    return "RPC";
-
-    UNUSED(reg);
-}
-
-static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
-    return 0;
-
-    UNUSED(reg);
-}
-
-static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
-    GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
-
-    UNUSED(reg);
-    UNUSED(index);
-}
-
-static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
-    if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
-        return (void *)ggml_backend_rpc_add_device;
-    }
-    return NULL;
-
-    UNUSED(reg);
-}
-
-static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
-    /* .get_name         = */ ggml_backend_rpc_reg_get_name,
-    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
-    /* .get_device       = */ ggml_backend_rpc_reg_get_device,
-    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
-};
-
-ggml_backend_reg_t ggml_backend_rpc_reg(void) {
-    static struct ggml_backend_reg ggml_backend_rpc_reg = {
-        /* .iface   = */ ggml_backend_rpc_reg_i,
-        /* .context = */ NULL,
-    };
-
-    return &ggml_backend_rpc_reg;
-}
-
-ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
-    static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
-
-    static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
-
-    if (dev_map.find(endpoint) != dev_map.end()) {
-        return dev_map[endpoint];
-    }
-
-    ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
-        /* .endpoint = */ endpoint,
-        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
-    };
-
-    ggml_backend_dev_t dev = new ggml_backend_device {
-        /* .iface   = */ ggml_backend_rpc_device_i,
-        /* .reg     = */ ggml_backend_rpc_reg(),
-        /* .context = */ ctx,
-    };
-
-    dev_map[endpoint] = dev;
-
-    return dev;
-}
diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp
deleted file mode 100644 (file)
index 2dba15d..0000000
+++ /dev/null
@@ -1,4684 +0,0 @@
-//
-// MIT license
-// Copyright (C) 2024 Intel Corporation
-// SPDX-License-Identifier: MIT
-//
-
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-
-#include <algorithm>
-#include <assert.h>
-#include <atomic>
-#include <cinttypes>
-#include <cstddef>
-#include <cstdint>
-#include <cstdlib>
-#include <float.h>
-#include <limits>
-#include <stdint.h>
-#include <stdio.h>
-#include <vector>
-#include <cmath>
-#include <iostream>
-#include <fstream>
-#include <stdio.h>
-#include <stdlib.h>
-#include <regex>
-
-#include <sycl/sycl.hpp>
-#include <sycl/half_type.hpp>
-
-#include "ggml-sycl.h"
-#include "ggml-impl.h"
-#include "ggml-backend-impl.h"
-
-#include "ggml-sycl/backend.hpp"
-#include "ggml-sycl/presets.hpp"
-#include "ggml-sycl/gemm.hpp"
-
-static bool g_sycl_loaded = false;
-
-static ggml_sycl_device_info ggml_sycl_init() {
-    ggml_sycl_device_info info = {};
-
-    info.device_count = dpct::dev_mgr::instance().device_count();
-    if (info.device_count == 0) {
-        fprintf(stderr, "%s: failed to initialize " GGML_SYCL_NAME ": %s\n", __func__);
-        return info;
-    }
-
-    GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);
-
-    int64_t total_vram = 0;
-#if defined(GGML_SYCL_FORCE_MMQ)
-    fprintf(stderr, "%s: GGML_SYCL_FORCE_MMQ:   yes\n", __func__);
-#else
-    fprintf(stderr, "%s: GGML_SYCL_FORCE_MMQ:   no\n", __func__);
-#endif
-#if defined(SYCL_USE_XMX)
-    fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__);
-#else
-    fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
-#endif
-    fprintf(stderr, "%s: found %d " GGML_SYCL_NAME " devices:\n", __func__, info.device_count);
-
-    for (int i = 0; i < info.device_count; ++i) {
-        info.devices[i].vmm = 0;
-        dpct::device_info prop;
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
-            prop, dpct::dev_mgr::instance().get_device(i))));
-
-        info.default_tensor_split[i] = total_vram;
-        total_vram += prop.get_global_mem_size();
-
-        info.devices[i].cc =
-            100 * prop.get_major_version() + 10 * prop.get_minor_version();
-
-        info.max_work_group_sizes[i] = prop.get_max_work_group_size();
-    }
-
-    for (int id = 0; id < info.device_count; ++id) {
-        info.default_tensor_split[id] /= total_vram;
-    }
-    return info;
-}
-
-const ggml_sycl_device_info & ggml_sycl_info() {
-    static ggml_sycl_device_info info = ggml_sycl_init();
-    return info;
-}
-
-void print_device_detail(int id, sycl::device &device, std::string device_type) {
-
-    dpct::device_info prop;
-    SYCL_CHECK(CHECK_TRY_ERROR(
-        dpct::get_device_info(prop, device)));
-
-    std::string version;
-    version += std::to_string(prop.get_major_version());
-    version += ".";
-    version += std::to_string(prop.get_minor_version());
-
-    device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), "");
-    std::string name = std::string(prop.get_name());
-    name = std::regex_replace(name, std::regex("\\(R\\)"), "");
-    name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
-
-    auto global_mem_size = prop.get_global_mem_size()/1000000;
-
-    fprintf(stderr, "|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
-            name.c_str(), version.c_str(), prop.get_max_compute_units(),
-            prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
-            global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
-}
-
-void ggml_backend_sycl_print_sycl_devices() {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
-    int device_count = dpct::dev_mgr::instance().device_count();
-    std::map<std::string, size_t> DeviceNums;
-    fprintf(stderr, "found %d SYCL devices:\n", device_count);
-    fprintf(stderr, "|  |                   |                                       |       |Max    |        |Max  |Global |                     |\n");
-    fprintf(stderr, "|  |                   |                                       |       |compute|Max work|sub  |mem    |                     |\n");
-    fprintf(stderr, "|ID|        Device Type|                                   Name|Version|units  |group   |group|size   |       Driver version|\n");
-    fprintf(stderr, "|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|\n");
-    for (int id = 0; id < device_count; ++id) {
-        sycl::device device = dpct::dev_mgr::instance().get_device(id);
-        sycl::backend backend = device.get_backend();
-        std::string backend_type = get_device_backend_and_type(device);
-        int type_id=DeviceNums[backend_type]++;
-        std::stringstream device_type;
-        device_type << "[" <<  backend_type << ":" << std::to_string(type_id) << "]";
-        print_device_detail(id, device, device_type.str());
-    }
-}
-
-static inline int get_sycl_env(const char *env_name, int default_val) {
-    char *user_device_string = getenv(env_name);
-    int user_number = default_val;
-
-    unsigned n;
-    if (user_device_string != NULL &&
-        sscanf(user_device_string, " %u", &n) == 1) {
-        user_number = (int)n;
-    } else {
-        user_number = default_val;
-    }
-    return user_number;
-}
-
-static void ggml_check_sycl() try {
-    static bool initialized = false;
-
-    if (!initialized) {
-        fprintf(stderr, "[SYCL] call ggml_check_sycl\n");
-        g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
-
-        fprintf(stderr, "%s: GGML_SYCL_DEBUG: %d\n", __func__, g_ggml_sycl_debug);
-
-#if defined(GGML_SYCL_F16)
-        fprintf(stderr, "%s: GGML_SYCL_F16: yes\n", __func__);
-#else
-        fprintf(stderr, "%s: GGML_SYCL_F16: no\n", __func__);
-#endif
-
-/* NOT REMOVE, keep it for next optimize for XMX.
-#if defined(SYCL_USE_XMX)
-        fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__);
-#else
-        fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
-#endif
-*/
-
-        if (CHECK_TRY_ERROR(g_all_sycl_device_count =
-                            dpct::dev_mgr::instance().device_count()) != 0) {
-            initialized = true;
-            g_sycl_loaded = false;
-            return;
-        }
-        GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES);
-        ggml_backend_sycl_print_sycl_devices();
-        initialized = true;
-        g_sycl_loaded = true;
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-/*
-device_index: device index from 0 to n (continue numbers).
-    It is used for device select/set in SYCL backend internal data structure.
-*/
-inline void check_allow_gpu_index(const int device_index) {
-  if (device_index >= ggml_sycl_info().device_count) {
-    char error_buf[256];
-    snprintf(
-        error_buf,
-        sizeof(error_buf),
-        "%s error: device_index:%d is out of range: [0-%d]",
-        __func__,
-        device_index,
-        ggml_sycl_info().device_count - 1);
-    fprintf(stderr, "%s\n", error_buf);
-    assert(false);
-  }
-}
-
-GGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len) try {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_gpu_list\n");
-    for(int i=0;i<max_len;i++) id_list[i] = -1;
-
-    for (int i=0;i< ggml_sycl_info().device_count;i++){
-        if (i>=max_len) break;
-        id_list[i] = i;
-    }
-    return;
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-// sycl buffer
-
-struct ggml_backend_sycl_buffer_context {
-    int device;
-    void * dev_ptr = nullptr;
-    queue_ptr stream;
-    std::string name;
-
-     ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
-        device(device), dev_ptr(dev_ptr), stream(stream) {
-            check_allow_gpu_index(device);
-            name = (GGML_SYCL_NAME + std::to_string(device));
-        }
-
-
-    ~ggml_backend_sycl_buffer_context() {
-        if (dev_ptr != nullptr) {
-            ggml_sycl_set_device(device);
-            SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
-        }
-    }
-};
-
-static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft);
-
-static bool ggml_backend_buffer_is_sycl(ggml_backend_buffer_t buffer) {
-    return buffer->buft->iface.get_name == ggml_backend_sycl_buffer_type_get_name;
-}
-
-static void
-ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
-    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-    ggml_sycl_set_device(ctx->device);
-
-    delete ctx;
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {
-    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-    return ctx->dev_ptr;
-}
-
-static void
-ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
-                                     ggml_tensor *tensor) try {
-    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
-
-    if (tensor->view_src != NULL && tensor->view_offs == 0) {
-        assert(tensor->view_src->buffer->buft == buffer->buft);
-        tensor->backend = tensor->view_src->backend;
-        tensor->extra = tensor->view_src->extra;
-        return;
-    }
-
-
-    if (ggml_is_quantized(tensor->type)) {
-        // initialize padding to 0 to avoid possible NaN values
-        size_t original_size = ggml_nbytes(tensor);
-        size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
-
-        if (padded_size > original_size && tensor->view_src == nullptr) {
-            SYCL_CHECK(CHECK_TRY_ERROR(ctx->stream->memset(
-                (char *)tensor->data + original_size, 0,
-                padded_size - original_size).wait()));
-        }
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
-                                                ggml_tensor *tensor,
-                                                const void *data, size_t offset,
-                                                size_t size) try {
-
-    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-
-    ggml_sycl_set_device(ctx->device);
-    auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
-    char* host_buf = (char*)malloc(size);
-    memcpy(host_buf, data, size);
-    SYCL_CHECK(
-        CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
-                             .wait()));
-    free(host_buf);
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,
-                                                const ggml_tensor *tensor,
-                                                void *data, size_t offset,
-                                                size_t size) try {
-
-    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-
-    ggml_sycl_set_device(ctx->device);
-    auto stream = dpct::dev_mgr::instance().get_device(ctx->device).default_queue();
-
-    SYCL_CHECK(CHECK_TRY_ERROR(
-        stream.memcpy(data, (const char *)tensor->data + offset, size)
-            .wait()));
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
-                    const void *ptr_src, size_t size) {
-    char *host_buf = (char *)malloc(size);
-    q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
-    q_dst.memcpy((char *)ptr_dst, host_buf, size).wait();
-    free(host_buf);
-}
-
-static bool
-ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
-                                    const ggml_tensor *src,
-                                    ggml_tensor *dst) try {
-    if (ggml_backend_buffer_is_sycl(src->buffer)) {
-        ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context;
-        ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context;
-
-        ggml_sycl_set_device(src_ctx->device);
-        /*
-        DPCT1009:198: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            dpct::dev_mgr::instance().get_device(src_ctx->device).queues_wait_and_throw()));
-        ggml_sycl_set_device(dst_ctx->device);
-        /*
-        DPCT1009:199: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
-        /*
-        DPCT1009:200: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
-
-        queue_ptr stream_dst = dst_ctx->stream;
-        queue_ptr stream_src = src_ctx->stream;
-        size_t size = ggml_nbytes(src);
-
-        //todo. it's dirty solutino to walkaroud known issue:device2device cross GPUs.
-        dev2dev_memcpy(*stream_dst, *stream_src, dst->data, src->data, size);
-
-//todo, it's known issue:error in device2device cross GPUs. reused when the issue is fixed. DON"T remove
-#if 0
-        SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(
-            (char *)dst->data, (const char *)src->data, size).wait()));
-
-        /*
-        DPCT1009:201: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
-#endif
-        return true;
-    }
-    return false;
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-
-static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,
-                                           uint8_t value) try {
-     ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-
-    ggml_sycl_set_device(ctx->device);
-    queue_ptr stream = ctx->stream;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw()));
-
-    SYCL_CHECK(CHECK_TRY_ERROR((*stream)
-                                    .memset(ctx->dev_ptr, value, buffer->size)
-                                    .wait()));
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
-    /* .free_buffer     = */ ggml_backend_sycl_buffer_free_buffer,
-    /* .get_base        = */ ggml_backend_sycl_buffer_get_base,
-    /* .init_tensor     = */ ggml_backend_sycl_buffer_init_tensor,
-    /* .memset_tensor   = */ NULL,
-    /* .set_tensor      = */ ggml_backend_sycl_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_sycl_buffer_get_tensor,
-    /* .cpy_tensor      = */ ggml_backend_sycl_buffer_cpy_tensor,
-    /* .clear           = */ ggml_backend_sycl_buffer_clear,
-    /* .reset           = */ NULL,
-};
-
-// sycl buffer type
-struct ggml_backend_sycl_buffer_type_context {
-    int device;
-    std::string name;
-
-    // each buffer type has its own stream
-    queue_ptr stream = nullptr;
-};
-
-static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    ggml_backend_sycl_buffer_type_context * ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
-
-    return ctx->name.c_str();
-}
-
-static ggml_backend_buffer_t
-ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
-                                           size_t size) try {
-    ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
-    ggml_sycl_set_device(buft_ctx->device);
-    const queue_ptr stream = buft_ctx->stream;
-    size = std::max(size, (size_t)1); // syclMalloc returns null for size 0
-
-    void * dev_ptr;
-    SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device(
-                                    size, *stream)));
-    if (!dev_ptr) {
-        fprintf(stderr, "%s: can't malloc %lu Bytes memory on device", __func__, size);
-        return nullptr;
-    }
-    ggml_backend_sycl_buffer_context * ctx = new  ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr, buft_ctx->stream);
-    return ggml_backend_buffer_init(buft, ggml_backend_sycl_buffer_interface, ctx, size);
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    return 128;
-    GGML_UNUSED(buft);
-}
-
-static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-    return dpct::get_current_device().get_max_mem_alloc_size();
-
-    GGML_UNUSED(buft);
-}
-
-static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
-    size_t size = ggml_nbytes(tensor);
-    int64_t ne0 = tensor->ne[0];
-
-    if (ggml_is_quantized(tensor->type)) {
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
-    }
-
-    return size;
-
-    GGML_UNUSED(buft);
-}
-
-static const ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
-    /* .get_name         = */ ggml_backend_sycl_buffer_type_get_name,
-    /* .alloc_buffer     = */ ggml_backend_sycl_buffer_type_alloc_buffer,
-    /* .get_alignment    = */ ggml_backend_sycl_buffer_type_get_alignment,
-    /* .get_max_size     = */ ggml_backend_sycl_buffer_type_get_max_size,
-    /* .get_alloc_size   = */ ggml_backend_sycl_buffer_type_get_alloc_size,
-    /* .is_host          = */ NULL,
-};
-
-ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
-    static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
-
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
-
-    auto dev_count = ggml_backend_sycl_get_device_count();
-
-    if (device>=dev_count or device<0) {
-        printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
-            device, dev_count-1);
-        GGML_ASSERT(device<dev_count);
-    }
-    static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
-
-    static bool ggml_backend_sycl_buffer_type_initialized = false;
-
-    if (!ggml_backend_sycl_buffer_type_initialized) {
-        for (int i = 0; i < dev_count; i++) {
-            auto & device_i = dpct::dev_mgr::instance().get_device(i);
-            queue_ptr stream = &(device_i.default_queue());
-            ggml_backend_sycl_buffer_types[i] = {
-                /* .iface    = */ ggml_backend_sycl_buffer_type_interface,
-                /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), i),
-                /* .context  = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
-            };
-        }
-        ggml_backend_sycl_buffer_type_initialized = true;
-    }
-    return &ggml_backend_sycl_buffer_types[device];
-}
-
-ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
-
-    int device = ctx->device;
-    if (device>=ggml_sycl_info().device_count or device<0) {
-        printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
-            device, ggml_sycl_info().device_count-1);
-        GGML_ASSERT(device<ggml_sycl_info().device_count);
-    }
-    static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
-
-    static bool ggml_backend_sycl_buffer_type_initialized = false;
-
-    if (!ggml_backend_sycl_buffer_type_initialized) {
-        for (int i = 0; i < ggml_sycl_info().device_count; i++) {
-            ggml_backend_sycl_buffer_types[i] = {
-                /* .iface    = */ ggml_backend_sycl_buffer_type_interface,
-                /* .device   = */ nullptr,
-                /* .context  = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), ctx->stream(i, 0)},
-            };
-        }
-        ggml_backend_sycl_buffer_type_initialized = true;
-    }
-    return &ggml_backend_sycl_buffer_types[device];
-}
-
-// sycl split buffer
-
-static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split) {
-    int64_t min_compute_capability = INT_MAX;
-    int64_t max_compute_capability = INT_MIN;
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        if (tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? tensor_split[i + 1] : 1.0f)) {
-            if (min_compute_capability > ggml_sycl_info().devices[i].cc) {
-                min_compute_capability = ggml_sycl_info().devices[i].cc;
-            }
-            if (max_compute_capability < ggml_sycl_info().devices[i].cc) {
-                max_compute_capability = ggml_sycl_info().devices[i].cc;
-            }
-        }
-    }
-
-    switch(type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-            return max_compute_capability >= VER_GEN9 ? 128 : 64;
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-            return 64;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_F32:
-            return 1;
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ4_NL:
-            return max_compute_capability >= VER_GEN9 ? 128 : 64;
-        case GGML_TYPE_IQ3_S:
-            return max_compute_capability >= VER_GEN9 ? 128 : 64;
-        case GGML_TYPE_Q6_K:
-            return 64;
-        default:
-            GGML_ABORT("fatal error");
-    }
-}
-
-static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split, int id) {
-    const int64_t nrows = ggml_nrows(tensor);
-    const int64_t rounding = get_row_rounding(tensor->type, tensor_split);
-
-    *row_low = id == 0 ? 0 : nrows*tensor_split[id];
-    *row_low -= *row_low % rounding;
-    if (id == ggml_sycl_info().device_count - 1) {
-        *row_high = nrows;
-    } else {
-        *row_high = nrows*tensor_split[id + 1];
-        *row_high -= *row_high % rounding;
-    }
-}
-
-static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
-    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
-
-    return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
-}
-
-struct ggml_backend_sycl_split_buffer_type_context {
-    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
-};
-
-struct ggml_backend_sycl_split_buffer_context {
-    ~ggml_backend_sycl_split_buffer_context() try {
-        for (ggml_tensor_extra_gpu * extra : tensor_extras) {
-            for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-                for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
-                    if (extra->events[i][is] != nullptr) {
-                        /*
-                        DPCT1009:206: SYCL uses exceptions to report errors and
-                        does not use the error codes. The original code was
-                        commented out and a warning string was inserted. You
-                        need to rewrite this code.
-                        */
-                        SYCL_CHECK(CHECK_TRY_ERROR(
-                            dpct::destroy_event(extra->events[i][is])));
-                    }
-                }
-                if (extra->data_device[i] != nullptr) {
-                    /*
-                    DPCT1009:207: SYCL uses exceptions to report errors and does
-                    not use the error codes. The original code was commented out
-                    and a warning string was inserted. You need to rewrite this
-                    code.
-                    */
-                    ggml_sycl_set_device(i);
-                    SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(
-                        extra->data_device[i], *(streams[i]))));
-                }
-            }
-            delete extra;
-        }
-    }
-    catch (sycl::exception const &exc) {
-      std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-                << ", line:" << __LINE__ << std::endl;
-      std::exit(1);
-    }
-
-    std::vector<ggml_tensor_extra_gpu *> tensor_extras;
-    std::vector<queue_ptr> streams;
-};
-
-static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
-    delete ctx;
-}
-
-static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
-    // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced
-    return (void *)0x1000;
-
-    GGML_UNUSED(buffer);
-}
-
-static void
-ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
-                                           ggml_tensor *tensor) try {
-    GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
-
-    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
-    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
-
-    const int64_t ne0 = tensor->ne[0];
-
-    ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
-
-    ctx->tensor_extras.push_back(extra);
-        ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
-
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        int64_t row_low, row_high;
-        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
-
-        int64_t nrows_split = row_high - row_low;
-        if (nrows_split == 0) {
-            continue;
-        }
-
-        size_t size = ggml_nbytes_split(tensor, nrows_split);
-        const size_t original_size = size;
-
-        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
-
-        // FIXME: do not crash if cudaMalloc fails
-        // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
-        ggml_sycl_set_device(i);
-        const queue_ptr stream = ctx->streams[i];
-        char * buf;
-        /*
-        DPCT1009:208: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
-        SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device(
-                                        size, *stream)));
-        if (!buf) {
-            char err_buf[1024];
-            snprintf(err_buf, 1023, "%s: can't malloc %lu Bytes memory on device", __func__, size);
-            throw std::runtime_error(err_buf);
-        }
-        // set padding to 0 to avoid possible NaN values
-        if (size > original_size) {
-            /*
-            DPCT1009:209: SYCL uses exceptions to report errors and does not use
-            the error codes. The original code was commented out and a warning
-            string was inserted. You need to rewrite this code.
-            */
-            SYCL_CHECK(CHECK_TRY_ERROR(
-                (*stream)
-                    .memset(buf + original_size, 0, size - original_size)
-                    .wait()));
-        }
-
-        extra->data_device[i] = buf;
-
-        for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
-            /*
-            DPCT1009:210: SYCL uses exceptions to report errors and does not use
-            the error codes. The original code was commented out and a warning
-            string was inserted. You need to rewrite this code.
-            */
-            SYCL_CHECK(
-                CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event()));
-        }
-    }
-    tensor->backend = GGML_BACKEND_TYPE_GPU_SPLIT;
-    tensor->extra = extra;
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void
-ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer,
-                                          ggml_tensor *tensor, const void *data,
-                                          size_t offset, size_t size) try {
-    // split tensors must always be set in their entirety at once
-    GGML_ASSERT(offset == 0);
-    GGML_ASSERT(size == ggml_nbytes(tensor));
-
-    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
-    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
-
-    const int64_t ne0 = tensor->ne[0];
-    const size_t nb1 = tensor->nb[1];
-    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
-
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        int64_t row_low, row_high;
-        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
-
-        int64_t nrows_split = row_high - row_low;
-        if (nrows_split == 0) {
-            continue;
-        }
-
-        const size_t offset_split = row_low*nb1;
-        size_t size = ggml_nbytes_split(tensor, nrows_split);
-        const size_t original_size = size;
-
-        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
-
-        const char * buf_host = (const char *)data + offset_split;
-        /*
-        DPCT1009:211: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
-        ggml_sycl_set_device(i);
-        const queue_ptr stream = ctx->streams[i];
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            (*stream)
-                .memcpy(extra->data_device[i], buf_host, original_size)
-                .wait()));
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void
-ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer,
-                                          const ggml_tensor *tensor, void *data,
-                                          size_t offset, size_t size) try {
-    // split tensors must always be set in their entirety at once
-    GGML_ASSERT(offset == 0);
-    GGML_ASSERT(size == ggml_nbytes(tensor));
-
-    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
-    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
-
-    const int64_t ne0 = tensor->ne[0];
-    const size_t nb1 = tensor->nb[1];
-    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
-
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        int64_t row_low, row_high;
-        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
-
-        int64_t nrows_split = row_high - row_low;
-        if (nrows_split == 0) {
-            continue;
-        }
-
-        const size_t offset_split = row_low*nb1;
-        size_t size = ggml_nbytes_split(tensor, nrows_split);
-        const size_t original_size = size;
-
-        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
-
-        char * buf_host = (char *)data + offset_split;
-        /*
-        DPCT1009:212: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
-        ggml_sycl_set_device(i);
-        const queue_ptr stream = ctx->streams[i];
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            (*stream)
-                .memcpy(buf_host, extra->data_device[i], original_size)
-                .wait()));
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_backend_sycl_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
-    GGML_UNUSED(buffer);
-    GGML_UNUSED(value);
-}
-
-static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = {
-    /* .free_buffer     = */ ggml_backend_sycl_split_buffer_free_buffer,
-    /* .get_base        = */ ggml_backend_sycl_split_buffer_get_base,
-    /* .init_tensor     = */ ggml_backend_sycl_split_buffer_init_tensor,
-    /* .memset_tensor   = */ NULL,
-    /* .set_tensor      = */ ggml_backend_sycl_split_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_sycl_split_buffer_get_tensor,
-    /* .cpy_tensor      = */ NULL,
-    /* .clear           = */ ggml_backend_sycl_split_buffer_clear,
-    /* .reset           = */ NULL,
-};
-
-// sycl split buffer type
-
-static const char * ggml_backend_sycl_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    return GGML_SYCL_NAME "_Split";
-
-    GGML_UNUSED(buft);
-}
-
-static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
-   return buffer->buft->iface.get_name == ggml_backend_sycl_split_buffer_type_get_name;
-}
-
-static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
-    // instead, we allocate them for each tensor separately in init_tensor
-    // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
-    // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
-    ggml_backend_sycl_split_buffer_context * ctx = new ggml_backend_sycl_split_buffer_context();
-
-    return ggml_backend_buffer_init(buft, ggml_backend_sycl_split_buffer_interface, ctx, size);
-}
-
-static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    return 128;
-    GGML_UNUSED(buft);
-}
-
-static size_t ggml_backend_sycl_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
-    ggml_backend_sycl_split_buffer_type_context * ctx = (ggml_backend_sycl_split_buffer_type_context *)buft->context;
-
-    size_t total_size = 0;
-
-    const int64_t ne0 = tensor->ne[0];
-
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        int64_t row_low, row_high;
-        get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, i);
-
-        int64_t nrows_split = row_high - row_low;
-        if (nrows_split == 0) {
-            continue;
-        }
-
-        total_size += ggml_nbytes_split(tensor, nrows_split);
-
-        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
-    }
-
-    return total_size;
-}
-
-static bool ggml_backend_sycl_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
-    return false;
-
-    GGML_UNUSED(buft);
-}
-
-static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface = {
-    /* .get_name         = */ ggml_backend_sycl_split_buffer_type_get_name,
-    /* .alloc_buffer     = */ ggml_backend_sycl_split_buffer_type_alloc_buffer,
-    /* .get_alignment    = */ ggml_backend_sycl_split_buffer_type_get_alignment,
-    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
-    /* .get_alloc_size   = */ ggml_backend_sycl_split_buffer_type_get_alloc_size,
-    /* .is_host          = */ ggml_backend_sycl_split_buffer_type_is_host,
-};
-
-ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
-    static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
-
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n");
-    ggml_check_sycl();
-    // FIXME: this is not thread safe
-    static std::map<std::array<float, GGML_SYCL_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
-
-    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split_arr = {};
-
-    bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_SYCL_MAX_DEVICES, [](float x) { return x == 0.0f; });
-    if (all_zero) {
-        tensor_split_arr = ggml_sycl_info().default_tensor_split;
-    } else {
-        float split_sum = 0.0f;
-        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-            tensor_split_arr[i] = split_sum;
-            split_sum += tensor_split[i];
-        }
-        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-            tensor_split_arr[i] /= split_sum;
-        }
-    }
-
-    auto it = buft_map.find(tensor_split_arr);
-    if (it != buft_map.end()) {
-        return &it->second;
-    }
-
-    struct ggml_backend_buffer_type buft {
-        /* .iface   = */ ggml_backend_sycl_split_buffer_type_interface,
-        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0),
-        /* .context = */ new ggml_backend_sycl_split_buffer_type_context{tensor_split_arr},
-    };
-
-    auto result = buft_map.emplace(tensor_split_arr, buft);
-    return &result.first->second;
-}
-
-// host buffer type
-
-static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
-    return GGML_SYCL_NAME "_Host";
-
-    GGML_UNUSED(buft);
-}
-
-static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    ggml_sycl_host_free(buffer->context);
-}
-
-static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    void * ptr = ggml_sycl_host_malloc(size);
-
-    if (ptr == nullptr) {
-        // fallback to cpu buffer
-        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
-    }
-
-    // FIXME: this is a hack to avoid having to implement a new buffer type
-    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
-    buffer->buft = buft;
-    buffer->iface.free_buffer = ggml_backend_sycl_host_buffer_free_buffer;
-
-    return buffer;
-}
-
-ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type() {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_host_buffer_type\n");
-    static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_type_host = {
-        /* .iface    = */ {
-            /* .get_name         = */ ggml_backend_sycl_host_buffer_type_name,
-            /* .alloc_buffer     = */ ggml_backend_sycl_host_buffer_type_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
-            /* .get_max_size     = */ NULL, // TODO: return device.maxBufferLength
-            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
-            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
-        },
-        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0),
-        /* .context  = */ nullptr,
-    };
-
-    return &ggml_backend_sycl_buffer_type_host;
-}
-
-// buffer pool for sycl (legacy)
-struct ggml_sycl_pool_leg : public ggml_sycl_pool {
-    static const int MAX_SYCL_BUFFERS = 256;
-
-    int device;
-    queue_ptr qptr;
-    struct ggml_sycl_buffer {
-        void * ptr = nullptr;
-        size_t size = 0;
-    };
-
-    ggml_sycl_buffer buffer_pool[MAX_SYCL_BUFFERS] = {};
-    size_t pool_size = 0;
-
-    explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) :
-        qptr(qptr_),
-        device(device_) {
-    }
-
-    ~ggml_sycl_pool_leg() {
-        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
-            ggml_sycl_buffer & b = buffer_pool[i];
-            if (b.ptr != nullptr) {
-                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
-                pool_size -= b.size;
-            }
-        }
-        GGML_ASSERT(pool_size == 0);
-    }
-
-    void * alloc(size_t size, size_t * actual_size) override {
-#ifdef DEBUG_sycl_MALLOC
-        int nnz = 0;
-        size_t max_size = 0;
-#endif
-        size_t best_diff = 1ull << 36;
-        int ibest = -1;
-        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
-            ggml_sycl_buffer& b = buffer_pool[i];
-            if (b.ptr != nullptr) {
-#ifdef DEBUG_sycl_MALLOC
-                ++nnz;
-                if (b.size > max_size) max_size = b.size;
-#endif
-                if (b.size >= size) {
-                    size_t diff = b.size - size;
-                    if (diff < best_diff) {
-                        best_diff = diff;
-                        ibest = i;
-                        if (!best_diff) {
-                            void * ptr = b.ptr;
-                            *actual_size = b.size;
-                            b.ptr = nullptr;
-                            b.size = 0;
-                            return ptr;
-                        }
-                    }
-                }
-            }
-        }
-        if (ibest >= 0) {
-            ggml_sycl_buffer& b = buffer_pool[ibest];
-            void * ptr = b.ptr;
-            *actual_size = b.size;
-            b.ptr = nullptr;
-            b.size = 0;
-            return ptr;
-        }
-        void * ptr;
-        size_t look_ahead_size = (size_t) (1.05 * size);
-
-        SYCL_CHECK(
-            CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device(
-                                look_ahead_size, *qptr)));
-        if (!ptr) {
-            fprintf(stderr, "%s: can't malloc %lu Bytes memory on device", __func__, look_ahead_size);
-            return nullptr;
-        }
-
-        *actual_size = look_ahead_size;
-        pool_size += look_ahead_size;
-
-    #ifdef DEBUG_SYCL_MALLOC
-        fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
-                (uint32_t)(max_size/1024/1024), (uint32_t)(g_sycl_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
-    #endif
-        // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg look_ahead_size=%lu, return %p\n", look_ahead_size, ptr);
-        return ptr;
-    }
-
-    void free(void * ptr, size_t size) override {
-        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
-            ggml_sycl_buffer& b = buffer_pool[i];
-            if (b.ptr == nullptr) {
-                b.ptr = ptr;
-                b.size = size;
-                return;
-            }
-        }
-        fprintf(stderr, "WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n");
-        SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr)));
-        pool_size -= size;
-    }
-};
-
-std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
-    // TBD: NO VMM support
-    // if (ggml_sycl_info().devices[device].vmm) {
-    //     return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(device));
-    // }
-   return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));
-}
-
-// TBD pool with virtual memory management
-// struct ggml_sycl_pool_vmm : public ggml_sycl_pool
-
-/// kernels
-
-typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
-typedef void (*ggml_sycl_func_t)(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
-typedef void (*ggml_sycl_op_mul_mat_t)(
-    ggml_backend_sycl_context & ctx,
-    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
-    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
-    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
-    const int64_t src1_ncols, const int64_t src1_padded_row_size,
-    const queue_ptr &stream);
-
-
-
-template<int QUANT_BLOCK_TILE>
-static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
-
-    if (ix >= kx_padded) {
-        return;
-    }
-
-    const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                   item_ct1.get_local_id(1);
-
-    const int i_padded = iy*kx_padded + ix;
-
-    block_q8_1 * y = (block_q8_1 *) vy;
-
-    const int ib = i_padded / QK8_1; // block index
-    const int iqs = i_padded % QK8_1; // quant index
-    typedef  sycl::vec<float, QUANT_BLOCK_TILE> TC;
-    typedef  sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
-    TC zeros;
-    TQ qzeros;
-#pragma unroll
-    for (int i = 0; i < QUANT_BLOCK_TILE; i++)
-    {
-        zeros[i] = 0.f;
-        qzeros[i] = 0;
-    }
-    const TC xi = ix < kx ? *(TC *)&x[iy * kx + ix] : zeros;
-    float sum = xi[0];
-    float amax = sycl::fabs(xi[0]);
-#pragma unroll
-    for (int i = 1; i < QUANT_BLOCK_TILE; i++)
-    {
-        sum += xi[i];
-        amax = sycl::fmax(sycl::fabs(xi[i]), amax);
-    }
-    sum = warp_reduce_sum(sum, item_ct1);
-    amax = warp_reduce_max(amax, item_ct1);
-
-    const float d = amax / 127;
-    TQ q = qzeros;
-    if (amax != 0.0f)
-    {
-#pragma unroll
-        for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
-            q[i] = sycl::round(xi[i] / d);
-        }
-    }
-
-    *(TQ *)&y[ib].qs[iqs] = q;
-
-    if (iqs > 0) {
-        return;
-    }
-
-    reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
-    reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
-}
-
-template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void k_get_rows(
-            const void * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12,
-            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
-    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                     item_ct1.get_local_id(2)) *
-                    2;
-    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) /
-                    ne12;
-    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) %
-                    ne12;
-
-    if (i00 >= ne00) {
-        return;
-    }
-
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
-
-    const int ib = i00/qk; // block index
-    const int iqs = (i00%qk)/qr; // quant index
-    const int iybs = i00 - i00%qk; // dst block start index
-    const int y_offset = qr == 1 ? 1 : qk/2;
-
-    // dequantize
-    dfloat2 v;
-    dequantize_kernel(src0_row, ib, iqs, v);
-
-    dst_row[iybs + iqs + 0] = v.x();
-    dst_row[iybs + iqs + y_offset] = v.y();
-}
-
-template<typename src0_t, typename dst_t>
-static void k_get_rows_float(
-            const src0_t * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12,
-            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
-    const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                    item_ct1.get_local_id(2);
-    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) /
-                    ne12;
-    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) %
-                    ne12;
-
-    if (i00 >= ne00) {
-        return;
-    }
-
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
-
-    dst_row[i00] = src0_row[i00];
-}
-
-static void mul_mat_p021_f16_f32(
-    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
-    const sycl::nd_item<3> &item_ct1) {
-
-    const sycl::half *x = (const sycl::half *)vx;
-
-    const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                      item_ct1.get_local_id(1);
-    const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                        item_ct1.get_local_id(0);
-    const int channel_x = channel / (nchannels_y / nchannels_x);
-
-    const int nrows_y = ncols_x;
-    const int nrows_dst = nrows_x;
-    const int row_dst = row_x;
-
-    float tmp = 0.0f;
-
-    for (int col_x0 = 0; col_x0 < ncols_x;
-         col_x0 += item_ct1.get_local_range(2)) {
-        const int col_x = col_x0 + item_ct1.get_local_id(2);
-
-        if (col_x >= ncols_x) {
-            break;
-        }
-
-        // x is transposed and permuted
-        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
-        const float xi =
-            sycl::vec<sycl::half, 1>(x[ix])
-                .convert<float, sycl::rounding_mode::automatic>()[0];
-
-        const int row_y = col_x;
-
-
-        // y is not transposed but permuted
-        const int iy = channel*nrows_y + row_y;
-
-        tmp += xi * y[iy];
-    }
-
-    // dst is not transposed and not permuted
-    const int idst = channel*nrows_dst + row_dst;
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[idst] = tmp;
-    }
-}
-
-static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
-    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
-    const sycl::nd_item<3> &item_ct1) {
-
-    const sycl::half *x = (const sycl::half *)vx;
-
-    const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                      item_ct1.get_local_id(1);
-    const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                        item_ct1.get_local_id(0);
-    const int channel_x = channel / channel_x_divisor;
-
-    const int nrows_y   = ncols_x;
-    const int nrows_dst = nrows_x;
-    const int row_dst   = row_x;
-
-    const int idst = channel*nrows_dst + row_dst;
-
-    float tmp = 0.0f;
-
-    for (int col_x0 = 0; col_x0 < ncols_x;
-         col_x0 += item_ct1.get_local_range(2)) {
-        const int col_x = col_x0 + item_ct1.get_local_id(2);
-
-        if (col_x >= ncols_x) {
-            break;
-        }
-
-        const int row_y = col_x;
-
-        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const int iy = channel*nrows_y + row_y;
-
-        const float xi =
-            sycl::vec<sycl::half, 1>(x[ix])
-                .convert<float, sycl::rounding_mode::automatic>()[0];
-
-        tmp += xi * y[iy];
-    }
-
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
-
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[idst] = tmp;
-    }
-}
-
-static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    float * dsti = (float *) cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    sycl::half *dsti = (sycl::half *)cdsti;
-
-    *dsti = sycl::vec<float, 1>(*xi)
-                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-}
-
-static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
-    const sycl::half *xi = (const sycl::half *)cxi;
-    sycl::half *dsti = (sycl::half *)cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
-    const sycl::half *xi = (const sycl::half *)cxi;
-    float * dsti = (float *) cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
-    const int16_t *xi = (const int16_t *)cxi;
-    int16_t *dsti = (int16_t *)cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
-    const int32_t *xi = (const int32_t *)cxi;
-    int32_t *dsti = (int32_t *)cdsti;
-
-    *dsti = *xi;
-}
-
-template <cpy_kernel_t cpy_1>
-static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
-                        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                        const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                        const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= ne) {
-        return;
-    }
-
-    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
-    // then combine those indices with the corresponding byte offsets to get the total offsets
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
-
-    cpy_1(cx + x_offset, cdst + dst_offset);
-}
-
-static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q8_0 * dsti = (block_q8_0 *) cdsti;
-
-    float amax = 0.0f; // absolute max
-
-    for (int j = 0; j < QK8_0; j++) {
-        const float v = xi[j];
-        amax = sycl::fmax(amax, sycl::fabs((float)v));
-    }
-
-    const float d = amax / ((1 << 7) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->d = d;
-
-    for (int j = 0; j < QK8_0; ++j) {
-        const float x0 = xi[j]*id;
-
-        dsti->qs[j] = sycl::round((float)x0);
-    }
-}
-
-static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q4_0 * dsti = (block_q4_0 *) cdsti;
-
-    float amax = 0.0f;
-    float vmax = 0.0f;
-
-    for (int j = 0; j < QK4_0; ++j) {
-        const float v = xi[j];
-        if (amax < sycl::fabs((float)v)) {
-            amax = sycl::fabs((float)v);
-            vmax = v;
-        }
-    }
-
-    const float d  = vmax / -8;
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->d = d;
-
-    for (int j = 0; j < QK4_0/2; ++j) {
-        const float x0 = xi[0       + j]*id;
-        const float x1 = xi[QK4_0/2 + j]*id;
-
-        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f));
-        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f));
-
-        dsti->qs[j]  = xi0;
-        dsti->qs[j] |= xi1 << 4;
-    }
-}
-
-static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q4_1 * dsti = (block_q4_1 *) cdsti;
-
-    float vmin = FLT_MAX;
-    float vmax = -FLT_MAX;
-
-    for (int j = 0; j < QK4_1; ++j) {
-        const float v = xi[j];
-
-        if (v < vmin) vmin = v;
-        if (v > vmax) vmax = v;
-    }
-
-    const float d  = (vmax - vmin) / ((1 << 4) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->dm.x() = d;
-    dsti->dm.y() = vmin;
-
-    for (int j = 0; j < QK4_1/2; ++j) {
-        const float x0 = (xi[0       + j] - vmin)*id;
-        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
-
-        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f));
-        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f));
-
-        dsti->qs[j]  = xi0;
-        dsti->qs[j] |= xi1 << 4;
-    }
-}
-
-template <cpy_kernel_t cpy_blck, int qk>
-static void cpy_f32_q(const char * cx, char * cdst, const int ne,
-                      const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                      const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                      const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
-    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                   item_ct1.get_local_id(2)) *
-                  qk;
-
-    if (i >= ne) {
-        return;
-    }
-
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
-
-    cpy_blck(cx + x_offset, cdst + dst_offset);
-}
-
-static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
-                           const sycl::nd_item<3> &item_ct1) {
-    const int row = item_ct1.get_group(1);
-    const int col = item_ct1.get_local_id(2);
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += item_ct1.get_local_range(2)) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum, item_ct1);
-
-    if (col == 0) {
-        dst[row] = sum;
-    }
-}
-
-
-template<typename T>
-static inline void ggml_sycl_swap(T & a, T & b) {
-    T tmp = a;
-    a = b;
-    b = tmp;
-}
-
-template <ggml_sort_order order>
-__dpct_inline__ static void
-k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
-                  const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
-    // bitonic sort
-    int col = item_ct1.get_local_id(2);
-    int row = item_ct1.get_group(1);
-
-    if (col >= ncols_pad) {
-        return;
-    }
-
-    const float * x_row = x + row * ncols;
-    auto dst_row = (int *)dpct_local;
-
-    // initialize indices
-    dst_row[col] = col;
-
-    item_ct1.barrier(sycl::access::fence_space::local_space);
-
-    for (int k = 2; k <= ncols_pad; k *= 2) {
-        for (int j = k / 2; j > 0; j /= 2) {
-            int ixj = col ^ j;
-            if (ixj > col) {
-                if ((col & k) == 0) {
-                    if (dst_row[col] >= ncols ||
-                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
-                    ) {
-                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
-                    }
-                } else {
-                    if (dst_row[ixj] >= ncols ||
-                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
-                    ) {
-                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
-                    }
-                }
-            }
-            /*
-            DPCT1118:1: SYCL group functions and algorithms must be encountered
-            in converged control flow. You may need to adjust the code.
-            */
-            item_ct1.barrier(sycl::access::fence_space::local_space);
-        }
-    }
-
-    // copy the result to dst without the padding
-    if (col < ncols) {
-        dst[row * ncols + col] = dst_row[col];
-    }
-}
-
-
-static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
-                              const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-
-    if (col >= ncols) {
-        return;
-    }
-
-    const int i = row*ncols + col;
-    //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
-    //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
-    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
-}
-
-static void scale_f32(const float * x, float * dst, const float scale, const int k,
-                      const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-
-    dst[i] = scale * x[i];
-}
-
-static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
-                      const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-
-    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
-}
-
-template <typename Ti, typename To>
-static  void pool2d_nchw_kernel(
-        const int ih, const int iw, const int oh, const int ow,
-        const int kh, const int kw, const int sh, const int sw,
-        const int ph, const int pw, const int parallel_elements,
-        const Ti* src, To* dst, const enum ggml_op_pool op,
-        const sycl::nd_item<3> &item_ct1) {
-        int idx = item_ct1.get_local_id(2) +
-                  item_ct1.get_group(2) * item_ct1.get_local_range(2);
-        if (idx >= parallel_elements) {
-            return;
-        }
-
-        const int I_HW = ih * iw;
-        const int O_HW = oh * ow;
-        const int nc = idx / O_HW;
-        const int cur_oh = idx % O_HW / ow;
-        const int cur_ow = idx % O_HW % ow;
-        const Ti* i_ptr = src + nc * I_HW;
-        To* o_ptr = dst + nc * O_HW;
-        const int start_h = cur_oh * sh - ph;
-        const int bh = sycl::max(0, start_h);
-        const int eh = sycl::min(ih, start_h + kh);
-        const int start_w = cur_ow * sw - pw;
-        const int bw = sycl::max(0, start_w);
-        const int ew = sycl::min(iw, start_w + kw);
-
-        To res = 0;
-
-        switch (op) {
-            case GGML_OP_POOL_AVG: res = 0; break;
-            case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
-        }
-
-        for (int i = bh; i < eh; i += 1) {
-            for (int j = bw; j < ew; j += 1) {
-#if DPCT_COMPATIBILITY_TEMP >= 350
-                /*
-                DPCT1098:106: The '*' expression is used instead of the __ldg
-                call. These two expressions do not provide the exact same
-                functionality. Check the generated code for potential precision
-                and/or performance issues.
-                */
-                Ti cur = *(i_ptr + i * iw + j);
-#else
-                Ti cur = i_ptr[i * iw + j];
-#endif
-                switch (op) {
-                    case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
-                    case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
-                }
-            }
-        }
-        o_ptr[cur_oh * ow + cur_ow] = res;
-}
-
-template <int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                          ggml_tensor *dst, const void *src0_dd,
-                          const int32_t *src1_dd, float *dst_dd,
-                          queue_ptr stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
-    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
-    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
-    // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
-
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
-
-    GGML_ASSERT(ne00 % 2 == 0);
-
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             k_get_rows<qk, qr, dq>(
-                                 src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
-                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
-                         });
-
-    (void) dst;
-}
-
-template <typename src0_t>
-static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const src0_t *src0_dd, const int32_t *src1_dd,
-                                float *dst_dd, queue_ptr stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
-    const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
-    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
-    // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
-
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
-
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
-                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
-            });
-    }
-
-    (void) dst;
-}
-
-
-static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
-                                   const int ky, const int kx_padded,
-                                   queue_ptr stream) {
-    const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
-    const sycl::range<3> num_blocks(1, ky, block_num_x);
-    int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
-    static_assert(QK8_1 % WARP_SIZE == 0);
-    const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(num_blocks * block_size, block_size),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
-            });
-    }
-}
-
-static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
-                                           float *dst, const int ncols_x,
-                                           const int nrows_x,
-                                           const int nchannels_x,
-                                           const int nchannels_y,
-                                           queue_ptr stream) {
-
-    const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
-    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
-                                     nchannels_y, item_ct1);
-            });
-    }
-}
-
-static void ggml_mul_mat_vec_nc_f16_f32_sycl(
-    const void *vx, const float *y, float *dst, const int ncols_x,
-    const int nrows_x, const int row_stride_x, const int nchannels_x,
-    const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
-
-    const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
-    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
-                                       row_stride_x, channel_stride_x,
-                                       nchannels_y / nchannels_x, item_ct1);
-            });
-    }
-}
-
-static void
-ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00,
-                      const int ne01, const int ne02, const int nb00,
-                      const int nb01, const int nb02, const int nb03,
-                      const int ne10, const int ne11, const int ne12,
-                      const int nb10, const int nb11, const int nb12,
-                      const int nb13, queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00,
-                                           nb01, nb02, nb03, ne10, ne11, ne12,
-                                           nb10, nb11, nb12, nb13, item_ct1);
-            });
-    }
-}
-
-static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
-    }
-}
-
-static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
-    }
-}
-
-static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne,
-                                   const int ne00, const int ne01,
-                                   const int ne02, const int nb00,
-                                   const int nb01, const int nb02,
-                                   const int nb03, const int ne10,
-                                   const int ne11, const int ne12,
-                                   const int nb10, const int nb11,
-                                   const int nb12, const int nb13,
-                                   queue_ptr stream) {
-
-    GGML_ASSERT(ne % QK8_0 == 0);
-    const int num_blocks = ne / QK8_0;
-    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
-                                           sycl::range<3>(1, 1, 1)),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(
-                                 cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                 item_ct1);
-                         });
-}
-
-static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne,
-                                   const int ne00, const int ne01,
-                                   const int ne02, const int nb00,
-                                   const int nb01, const int nb02,
-                                   const int nb03, const int ne10,
-                                   const int ne11, const int ne12,
-                                   const int nb10, const int nb11,
-                                   const int nb12, const int nb13,
-                                   queue_ptr stream) {
-
-    GGML_ASSERT(ne % QK4_0 == 0);
-    const int num_blocks = ne / QK4_0;
-    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
-                                           sycl::range<3>(1, 1, 1)),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(
-                                 cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                 item_ct1);
-                         });
-}
-
-static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne,
-                                   const int ne00, const int ne01,
-                                   const int ne02, const int nb00,
-                                   const int nb01, const int nb02,
-                                   const int nb03, const int ne10,
-                                   const int ne11, const int ne12,
-                                   const int nb10, const int nb11,
-                                   const int nb12, const int nb13,
-                                   queue_ptr stream) {
-
-    GGML_ASSERT(ne % QK4_1 == 0);
-    const int num_blocks = ne / QK4_1;
-    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
-                                           sycl::range<3>(1, 1, 1)),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(
-                                 cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                 item_ct1);
-                         });
-}
-
-static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
-    }
-}
-
-static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        // dpct::has_capability_or_fail(stream->get_device(),
-        //                              {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
-    }
-}
-
-static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        // dpct::has_capability_or_fail(stream->get_device(),
-        //                              {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
-    }
-}
-
-static void scale_f32_sycl(const float *x, float *dst, const float scale,
-                           const int k, queue_ptr stream) {
-    const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            scale_f32(x, dst, scale, k, item_ct1);
-        });
-}
-
-static void clamp_f32_sycl(const float *x, float *dst, const float min,
-                           const float max, const int k,
-                           queue_ptr stream) {
-    const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            clamp_f32(x, dst, min, max, k, item_ct1);
-        });
-}
-
-static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
-                              const int nrows, queue_ptr stream) {
-    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-    const sycl::range<3> block_nums(1, nrows, 1);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1)
-                             [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                                 k_sum_rows_f32(x, dst, ncols, item_ct1);
-                             });
-}
-
-static int next_power_of_2(int x) {
-    int n = 1;
-    while (n < x) {
-        n *= 2;
-    }
-    return n;
-}
-
-static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
-                                 const int nrows, ggml_sort_order order,
-                                 queue_ptr stream) {
-    // bitonic sort requires ncols to be power of 2
-    const int ncols_pad = next_power_of_2(ncols);
-
-    const sycl::range<3> block_dims(1, 1, ncols_pad);
-    const sycl::range<3> block_nums(1, nrows, 1);
-    const size_t shared_mem = ncols_pad * sizeof(int);
-
-    if (order == GGML_SORT_ORDER_ASC) {
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
-                sycl::range<1>(shared_mem), cgh);
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
-                        x, dst, ncols, ncols_pad, item_ct1,
-                        dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
-                            .get());
-                });
-        });
-    } else if (order == GGML_SORT_ORDER_DESC) {
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
-                sycl::range<1>(shared_mem), cgh);
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
-                        x, dst, ncols, ncols_pad, item_ct1,
-                        dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
-                            .get());
-                });
-        });
-    } else {
-        GGML_ABORT("fatal error");
-    }
-}
-
-static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
-                               const int nrows, queue_ptr stream) {
-    const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
-    const sycl::range<3> block_nums(1, nrows, 1);
-    const size_t shared_mem = 256 * sizeof(float);
-
-    stream->submit([&](sycl::handler &cgh) {
-        sycl::local_accessor<float, 1> shared_data(
-            sycl::range<1>(shared_mem/sizeof(float)), cgh);
-        sycl::local_accessor<int, 1> shared_indices(
-            sycl::range<1>(shared_mem/sizeof(float)), cgh);
-
-        cgh.parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                const int tid = item_ct1.get_local_id(2);
-                const int row = item_ct1.get_global_id(1);
-
-                float max_val = -INFINITY;
-                int max_idx = -1;
-
-                for (int col = tid; col < ncols; col += 256) {
-                    float val = x[row * ncols + col];
-                    if (val > max_val) {
-                        max_val = val;
-                        max_idx = col;
-                    }
-                }
-
-                shared_data[tid] = max_val;
-                shared_indices[tid] = max_idx;
-                item_ct1.barrier(sycl::access::fence_space::local_space);
-
-                for (int stride = 256/2; stride > 0; stride >>= 1) {
-                    if (tid < stride) {
-                        float val1 = shared_data[tid];
-                        float val2 = shared_data[tid + stride];
-                        if (val2 > val1) {
-                            shared_data[tid] = val2;
-                            shared_indices[tid] = shared_indices[tid + stride];
-                        }
-                    }
-                    item_ct1.barrier(sycl::access::fence_space::local_space);
-                }
-
-
-                if (tid == 0) {
-                    dst[row] = shared_indices[0];
-                }
-            });
-    });
-}
-static void diag_mask_inf_f32_sycl(const float *x, float *dst,
-                                   const int ncols_x, const int nrows_x,
-                                   const int rows_per_channel, const int n_past,
-                                   queue_ptr stream) {
-    const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1);
-    const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE;
-    const sycl::range<3> block_nums(1, block_num_x, nrows_x);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             diag_mask_inf_f32(x, dst, ncols_x,
-                                               rows_per_channel, n_past,
-                                               item_ct1);
-                         });
-}
-
-static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
-                                          const struct ggml_tensor *src,
-                                          int64_t i3, int64_t i2,
-                                          int64_t i1_low, int64_t i1_high,
-                                          queue_ptr stream) try {
-
-    dpct::memcpy_direction kind;
-    char * src_ptr;
-    if (src->backend == GGML_BACKEND_TYPE_CPU) {
-        kind = dpct::host_to_device;
-        src_ptr = (char *) src->data;
-        // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d  GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr);
-    } else if (src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
-        GGML_ASSERT(src->backend != GGML_BACKEND_TYPE_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
-        kind = dpct::device_to_device;
-        ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
-        int id;
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            id = get_current_device_id()));
-        // GGML_SYCL_DEBUG("current device index %d\n", id);
-        src_ptr = (char *) extra->data_device[id];
-    } else {
-        // GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n");
-        GGML_ABORT("fatal error");
-    }
-    char * dst_ptr = (char *) dst;
-
-    GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
-    GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
-    const enum ggml_type type = src->type;
-    const int64_t ts = ggml_type_size(type);
-    const int64_t bs = ggml_blck_size(type);
-    int64_t i1_diff = i1_high - i1_low;
-
-    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
-    if (nb0 == ts && nb1 == ts*ne0/bs) {
-        // GGML_SYCL_DEBUG("stream->memcpy: dst_ptr=%p, x=%p, size=%lu\n", dst_ptr, x, i1_diff * nb1);
-        // return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1));
-        return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,
-                                    kind, *stream));
-
-    } else if (nb0 == ts) {
-        return CHECK_TRY_ERROR(
-            dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,
-                                    ts * ne0 / bs, i1_diff, kind, *stream));
-    } else {
-        for (int64_t i1 = 0; i1 < i1_diff; i1++) {
-            const void * rx = (const void *) ((const char *) x + i1*nb1);
-            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
-            // pretend the row is a matrix with cols=1
-            dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
-                rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));
-            /*
-            DPCT1001:85: The statement could not be removed.
-            */
-            /*
-            DPCT1000:86: Error handling if-stmt was detected but could not be
-            rewritten.
-            */
-            if (r != 0) return r;
-        }
-        return 0;
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                  const float *src0_d, const float *src1_d,
-                                  float *dst_d, const queue_ptr &stream) {
-
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
-    const int32_t * src1_i32 = (const int32_t *) src1_d;
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
-                                src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_F32:
-            get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q4_0:
-            get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q4_1:
-            get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q5_0:
-            get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q5_1:
-            get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q8_0:
-            get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        default:
-            // TODO: k-quants
-            fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
-            GGML_ABORT("fatal error");
-            break;
-    }
-}
-
-
-static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const float *src0_d, const float *src1_d,
-                                float *dst_d,
-                                const queue_ptr &main_stream) {
-
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
-
-    (void) src1;
-    (void) src1_d;
-}
-
-
-inline void ggml_sycl_op_mul_mat_sycl(
-    ggml_backend_sycl_context & ctx,
-    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
-    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
-    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
-    const int64_t src1_ncols, const int64_t src1_padded_row_size,
-    const queue_ptr &stream) try {
-
-    GGML_ASSERT(src0_dd_i  != nullptr);
-    GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_dd_i   != nullptr);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne10 = src1->ne[0];
-
-    const int64_t ne0 = dst->ne[0];
-
-    const int64_t row_diff = row_high - row_low;
-
-    int id;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(id = get_current_device_id()));
-
-    // the main device has a larger memory buffer to hold the results from all GPUs
-    // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = id == ctx.device ? ne0 : row_diff;
-
-#ifdef GGML_SYCL_F16
-    bool use_fp16 = true;  // TODO(Yu) SYCL capability check
-#else
-    bool use_fp16 = false;
-#endif
-    if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
-        use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] &&
-        dst->op_params[0] == GGML_PREC_DEFAULT) {
-
-        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n");
-        ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
-        if (src0->type != GGML_TYPE_F16) {
-            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type);
-            GGML_ASSERT(to_fp16_sycl != nullptr);
-            size_t ne = row_diff*ne00;
-            src0_as_f16.alloc(ne);
-            to_fp16_sycl(src0_dd_i, src0_as_f16.get(), ne, stream);
-        }
-        const sycl::half *src0_ptr = src0->type == GGML_TYPE_F16
-                                         ? (const sycl::half *)src0_dd_i
-                                         : src0_as_f16.get();
-
-        ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
-        if (src1->type != GGML_TYPE_F16) {
-            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
-            GGML_ASSERT(to_fp16_sycl != nullptr);
-            size_t ne = src1_ncols*ne10;
-            src1_as_f16.alloc(ne);
-            to_fp16_sycl(src1_ddf_i, src1_as_f16.get(), ne, stream);
-        }
-        const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
-                ? (const sycl::half *)src1->data + src1_padded_row_size
-                                         : src1_as_f16.get();
-        ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
-
-        const sycl::half alpha_f16 = 1.0f;
-        const sycl::half beta_f16 = 0.0f;
-#if !GGML_SYCL_DNNL
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
-            *stream, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
-            &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
-            src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
-            dst_f16.get(), dpct::library_data_t::real_half, ldc,
-            dpct::library_data_t::real_half)));
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-#else
-        auto dnnl_stream = ctx.stream_dnnl(stream);
-        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-            src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
-#endif
-    }
-    else {
-        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
-        ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
-        ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
-        if (src0->type != GGML_TYPE_F32) {
-            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type);
-            GGML_ASSERT(to_fp32_sycl != nullptr);
-            src0_ddq_as_f32.alloc(row_diff*ne00);
-            to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
-        }
-        if (src1->type != GGML_TYPE_F32) {
-            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type);
-            GGML_ASSERT(to_fp32_sycl != nullptr);
-            src1_ddq_as_f32.alloc(src1_ncols*ne10);
-            to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
-        }
-        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
-        const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
-
-        const float alpha = 1.0f;
-        const float beta = 0.0f;
-#if !GGML_SYCL_DNNL
-        SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
-            *stream, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
-            dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
-            src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
-            dst_dd_i, ldc)));
-#else
-        auto dnnl_stream = ctx.stream_dnnl(stream);
-         DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
-            src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
-#endif
-    }
-    (void) dst;
-    (void) src1_ddq_i;
-    (void) src1_padded_row_size;
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const float *src0_dd, const float *src1_dd,
-                                float *dst_dd, const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    const int32_t * opts = (const int32_t *)dst->op_params;
-    enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
-    const int k0 = opts[1];
-    const int k1 = opts[2];
-    const int s0 = opts[3];
-    const int s1 = opts[4];
-    const int p0 = opts[5];
-    const int p1 = opts[6];
-
-    const int64_t IH = src0->ne[1];
-    const int64_t IW = src0->ne[0];
-
-    const int64_t N = dst->ne[3];
-    const int64_t OC = dst->ne[2];
-    const int64_t OH = dst->ne[1];
-    const int64_t OW = dst->ne[0];
-
-    const int parallel_elements = N * OC * OH * OW;
-    const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
-    sycl::range<3> block_nums(1, 1, num_blocks);
-    main_stream->parallel_for(
-        sycl::nd_range<3>(block_nums *
-                              sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            pool2d_nchw_kernel(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0,
-                               parallel_elements, src0_dd, dst_dd, op,
-                               item_ct1);
-        });
-
-    (void) src1;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                  const float *src0_dd, const float *src1_dd,
-                                  float *dst_dd,
-                                  const queue_ptr &main_stream) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    const int64_t ne = ggml_nelements(src0);
-
-    sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                  const float *src0_dd, const float *src1_dd,
-                                  float *dst_dd,
-                                  const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    const int64_t ncols = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
-
-    sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1, ggml_tensor *dst,
-                                 const float *src0_dd, const float *src1_dd,
-                                 float *dst_dd,
-                                 const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_I32);
-
-    const int64_t ncols = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
-
-    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
-
-    argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1, ggml_tensor *dst,
-                                 const float *src0_dd, const float *src1_dd,
-                                 float *dst_dd,
-                                 const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_I32);
-
-    const int64_t ncols = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
-
-    argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                       const ggml_tensor *src1,
-                                       ggml_tensor *dst, const float *src0_dd,
-                                       const float *src1_dd, float *dst_dd,
-                                       const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int nrows0 = ggml_nrows(src0);
-
-    const int n_past = ((int32_t *) dst->op_params)[0];
-
-    diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                               ggml_tensor *dst, const float *src0_dd,
-                               const float *src1_dd, float *dst_dd,
-                               const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    float scale;
-    memcpy(&scale, dst->op_params, sizeof(float));
-
-    scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
-    /*
-    DPCT1010:87: SYCL uses exceptions to report errors and does not use the
-    error codes. The call was replaced with 0. You need to rewrite this code.
-    */
-    SYCL_CHECK(0);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                               ggml_tensor *dst, const float *src0_dd,
-                               const float *src1_dd, float *dst_dd,
-                               const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    float min;
-    float max;
-    memcpy(&min, dst->op_params, sizeof(float));
-    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
-
-    clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
-    /*
-    DPCT1010:88: SYCL uses exceptions to report errors and does not use the
-    error codes. The call was replaced with 0. You need to rewrite this code.
-    */
-    SYCL_CHECK(0);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
-    static bool peer_access_enabled = false;
-
-    const bool enable_peer_access = n_tokens <= GGML_SYCL_PEER_MAX_BATCH_SIZE;
-
-    if (peer_access_enabled == enable_peer_access) {
-        return;
-    }
-
-#ifdef NDEBUG
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        SYCL_CHECK(ggml_sycl_set_device(i));
-    }
-
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        SYCL_CHECK(ggml_sycl_set_device(i));
-
-        for (int id_other = 0; id_other < ggml_sycl_info().device_count; ++id_other) {
-            if (i == id_other) {
-                continue;
-            }
-            if (i != main_device && id_other != main_device) {
-                continue;
-            }
-
-            // int can_access_peer;
-            // SYCL_CHECK(syclDeviceCanAccessPeer(&can_access_peer, id, id_other));
-            // if (can_access_peer) {
-            //     if (enable_peer_access) {
-            //         SYCL_CHECK(syclDeviceEnablePeerAccess(id_other, 0));
-            //     } else {
-            //         SYCL_CHECK(syclDeviceDisablePeerAccess(id_other));
-            //     }
-            // }
-        }
-    }
-#endif // NDEBUG
-
-    peer_access_enabled = enable_peer_access;
-}
-
-static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1, ggml_tensor *dst,
-                                 ggml_sycl_op_mul_mat_t op,
-                                 const bool convert_src1_to_q8_1) try {
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
-
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
-    const int64_t nrows1 = ggml_nrows(src1);
-
-    GGML_ASSERT(ne03 == ne13);
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-
-    const int nb2 = dst->nb[2];
-    const int nb3 = dst->nb[3];
-
-    GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
-
-    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
-
-    const int64_t i02_divisor = ne12 / ne02;
-
-    const size_t src0_ts = ggml_type_size(src0->type);
-    const size_t src0_bs = ggml_blck_size(src0->type);
-    const size_t q8_1_ts = sizeof(block_q8_1);
-    const size_t q8_1_bs = QK8_1;
-
-    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
-    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
-    ggml_tensor_extra_gpu *  dst_extra = (ggml_tensor_extra_gpu *)  dst->extra;
-
-    const bool src0_is_contiguous = ggml_is_contiguous(src0);
-    const bool src1_is_contiguous = ggml_is_contiguous(src1);
-
-    int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
-
-    const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
-    GGML_ASSERT(!(split && ne02 > 1));
-    GGML_ASSERT(!(split && ne03 > 1));
-    GGML_ASSERT(!(split && ne02 < ne12));
-
-    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
-    if (split) {
-        // TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_TYPE_GPU_SPLIT check
-        // GGML_ASSERT(src0->buffer != nullptr && src0->buffer->buft == ...);
-        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
-        tensor_split = buft_ctx->tensor_split;
-    }
-
-    struct dev_data {
-        ggml_sycl_pool_alloc<char> src0_dd_alloc;
-        ggml_sycl_pool_alloc<float> src1_ddf_alloc;
-        ggml_sycl_pool_alloc<char> src1_ddq_alloc;
-        ggml_sycl_pool_alloc<float> dst_dd_alloc;
-
-        char *src0_dd = nullptr;
-        float *src1_ddf = nullptr; // float
-        char *src1_ddq = nullptr;  // q8_1
-        float *dst_dd = nullptr;
-
-        int64_t row_low;
-        int64_t row_high;
-    };
-
-    dev_data dev[GGML_SYCL_MAX_DEVICES];
-
-    int used_devices = 0;
-    queue_ptr main_stream = ctx.stream();
-
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        // by default, use all rows
-        dev[i].row_low  = 0;
-        dev[i].row_high = ne01;
-
-        // for multi GPU, get the row boundaries from tensor split
-        // and round to mul_mat_q tile sizes
-        if (split) {
-            const int64_t rounding = get_row_rounding(src0->type, tensor_split);
-
-            if (i != 0) {
-                dev[i].row_low  = ne01*tensor_split[i];
-                if (dev[i].row_low < ne01) {
-                    dev[i].row_low -= dev[i].row_low % rounding;
-                }
-            }
-
-            if (i != ggml_sycl_info().device_count - 1) {
-                dev[i].row_high  = ne01*tensor_split[i + 1];
-                if (dev[i].row_high < ne01) {
-                    dev[i].row_high -= dev[i].row_high % rounding;
-                }
-            }
-        }
-    }
-
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
-            continue;
-        }
-
-        used_devices++;
-
-        const bool src1_on_device = i == ctx.device;
-        const bool  dst_on_device = i == ctx.device;
-
-        ggml_sycl_set_device(i);
-        queue_ptr stream = ctx.stream(i, 0);
-
-        if (src0_is_contiguous) {
-            dev[i].src0_dd = (char *) src0->data;
-        } else {
-            dev[i].src0_dd = dev[i].src0_dd_alloc.alloc(ctx.pool(i), ggml_nbytes(src0));
-        }
-
-        if (src1_on_device && src1_is_contiguous) {
-            dev[i].src1_ddf = (float *) src1->data;
-        } else {
-            dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
-        }
-
-        if (convert_src1_to_q8_1) {
-            dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
-
-            if (src1_on_device && src1_is_contiguous) {
-                quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
-                /*
-                DPCT1010:90: SYCL uses exceptions to report errors and does not
-                use the error codes. The call was replaced with 0. You need to
-                rewrite this code.
-                */
-                SYCL_CHECK(0);
-            }
-        }
-
-        if (dst_on_device) {
-            dev[i].dst_dd = (float *) dst->data;
-        } else {
-            const size_t size_dst_ddf = split ? (dev[i].row_high - dev[i].row_low)*ne1 : ggml_nelements(dst);
-            dev[i].dst_dd = dev[i].dst_dd_alloc.alloc(ctx.pool(i), size_dst_ddf);
-        }
-    }
-
-    // if multiple devices are used they need to wait for the main device
-    // here an event is recorded that signals that the main device has finished calculating the input data
-    if (split && used_devices > 1) {
-        ggml_sycl_set_device(ctx.device);
-        /*
-        DPCT1024:91: The original code returned the error code that was further
-        consumed by the program logic. This original code was replaced with 0.
-        You may need to rewrite the program logic consuming the error code.
-        */
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            *src0_extra->events[ctx.device][0] =
-                ctx.stream()->ext_oneapi_submit_barrier()));
-    }
-
-    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
-    for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
-        const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
-        const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
-
-        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-            if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
-                continue;
-            }
-
-            const bool src1_on_device = i == ctx.device;
-            const bool  dst_on_device = i == ctx.device;
-            const int64_t row_diff = dev[i].row_high - dev[i].row_low;
-
-            ggml_sycl_set_device(i);
-            queue_ptr stream = ctx.stream(i, is);
-
-            // wait for main GPU data if necessary
-            if (split && (i != ctx.device || is != 0)) {
-                /*
-                DPCT1009:163: SYCL uses exceptions to report errors and does not
-                use the error codes. The original code was commented out and a
-                warning string was inserted. You need to rewrite this code.
-                */
-                SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
-                    {*src0_extra->events[ctx.device][0]})));
-            }
-
-            for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
-                const int64_t i03 = i0 / ne12;
-                const int64_t i02 = i0 % ne12;
-
-                const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
-
-                // for split tensors the data begins at i0 == i0_offset_low
-                char  *  src0_dd_i =  dev[i].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
-                float * src1_ddf_i = dev[i].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
-                char  * src1_ddq_i = dev[i].src1_ddq +  src1_ddq_i_offset;
-                float *   dst_dd_i =   dev[i].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
-
-                // the main device memory buffer can be on VRAM scratch, with space for all partial results
-                // in that case an offset on dst_ddf_i is needed
-                if (i == ctx.device) {
-                    dst_dd_i += dev[i].row_low; // offset is 0 if no tensor split
-                }
-
-                // copy src0, src1 to device if necessary
-                if (src1_is_contiguous) {
-                    if (i != ctx.device) {
-                        if (convert_src1_to_q8_1) {
-                            char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
-                          SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
-                                src1_ddq_i, src1_ddq_i_source,
-                                src1_ncols * src1_padded_col_size * q8_1_ts /
-                                    q8_1_bs).wait()));
-                        } else {
-
-                            float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
-                            src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
-
-                            SYCL_CHECK(CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream,
-                                src1_ddf_i, src1_ddf_i_source,
-                                src1_ncols * ne10 * sizeof(float))));
-                        }
-                    }
-                } else if (src1_on_device && !src1_is_contiguous) {
-                    SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
-                                   src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
-                } else {
-                    GGML_ABORT("fatal error");
-                }
-
-                if (convert_src1_to_q8_1 && !src1_is_contiguous) {
-                    quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
-                    /*
-                    DPCT1010:92: SYCL uses exceptions to report errors and does
-                    not use the error codes. The call was replaced with 0. You
-                    need to rewrite this code.
-                    */
-                    SYCL_CHECK(0);
-                }
-
-                if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
-                    SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[i].row_low, dev[i].row_high, stream));
-                }
-                if (src1->type == GGML_TYPE_F16) {
-                    src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;
-                }
-                // do the computation
-                SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
-                    dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
-                /*
-                DPCT1010:93: SYCL uses exceptions to report errors and does not
-                use the error codes. The call was replaced with 0. You need to
-                rewrite this code.
-                */
-                SYCL_CHECK(0);
-
-                // copy dst to host or other device if necessary
-                if (!dst_on_device) {
-                    void * dst_off_device = dst->data;
-                    if (split) {
-                        // src0 = weight matrix is saved as a transposed matrix for better memory layout.
-                        // dst is NOT transposed.
-                        // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
-                        // Instead they need to be copied to the correct slice in ne0 = dst row index.
-                        // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
-                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
-                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
-                        dhf_dst_i += src1_col_0*ne0 + dev[i].row_low;
-
-                        SYCL_CHECK(CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
-                            dhf_dst_i, ne0 * sizeof(float), dst_dd_i,
-                            row_diff * sizeof(float), row_diff * sizeof(float),
-                            src1_ncols, dpct::device_to_device, *stream)));
-                    } else {
-                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
-                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
-                        dhf_dst_i += src1_col_0*ne0;
-                        SYCL_CHECK(CHECK_TRY_ERROR(
-                            stream->memcpy(dhf_dst_i, dst_dd_i,
-                                           src1_ncols * ne0 * sizeof(float)).wait()));
-                    }
-                }
-
-                // add event for the main device to wait on until other device is done
-                if (split && (i != ctx.device || is != 0)) {
-                    /*
-                    DPCT1024:94: The original code returned the error code that
-                    was further consumed by the program logic. This original
-                    code was replaced with 0. You may need to rewrite the
-                    program logic consuming the error code.
-                    */
-                    SYCL_CHECK(CHECK_TRY_ERROR(
-                        *src0_extra->events[i][is] =
-                            stream->ext_oneapi_submit_barrier()));
-                }
-            }
-        }
-    }
-
-    // main device waits for all other devices to be finished
-    if (split && ggml_sycl_info().device_count > 1) {
-        int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
-        is_max = is_max <= GGML_SYCL_MAX_STREAMS ? is_max : GGML_SYCL_MAX_STREAMS;
-
-        ggml_sycl_set_device(ctx.device);
-        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-            if (dev[i].row_low == dev[i].row_high) {
-                continue;
-            }
-            for (int64_t is = 0; is < is_max; ++is) {
-                SYCL_CHECK(CHECK_TRY_ERROR(
-                    ctx.stream()->ext_oneapi_submit_barrier(
-                        {*src0_extra->events[i][is]})));
-            }
-        }
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-
-static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_repeat);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_get_rows);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                       const ggml_tensor *src1,
-                                       ggml_tensor *dst) try {
-    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
-    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
-    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-
-    const int64_t ne12 = src1->ne[2];
-
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();
-
-    void  * src0_ddq = src0->data;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
-
-    ggml_mul_mat_p021_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                     const ggml_tensor *src1,
-                                     ggml_tensor *dst) try {
-    GGML_ASSERT(!ggml_is_transposed(src0));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-    GGML_ASSERT(!ggml_is_permuted(src0));
-    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-
-    const int64_t nb01 = src0->nb[1];
-    const int64_t nb02 = src0->nb[2];
-
-    const int64_t ne12 = src1->ne[2];
-
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();
-
-    void  * src0_ddq = src0->data;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
-
-    const int64_t row_stride_x = nb01 / sizeof(sycl::half);
-    const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
-
-    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
-                                   const sycl::half *src1_as_f16, char *dst,
-                                   const void **ptrs_src, void **ptrs_dst,
-                                   int64_t ne12, int64_t ne13, int64_t ne23,
-                                   size_t nb02, size_t nb03, size_t nb12,
-                                   size_t nb13, size_t nbd2, size_t nbd3,
-                                   int64_t r2, int64_t r3,
-                                   const sycl::nd_item<3> &item_ct1) {
-    int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                  item_ct1.get_local_id(2);
-    int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
-                  item_ct1.get_local_id(1);
-
-    if (i13 >= ne13 || i12 >= ne12) {
-        return;
-    }
-
-    int64_t i03 = i13 / r3;
-    int64_t i02 = i12 / r2;
-
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
-}
-
-static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
-                                             const ggml_tensor *src0,
-                                             const ggml_tensor *src1,
-                                             ggml_tensor *dst) try {
-    GGML_ASSERT(!ggml_is_transposed(src0));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int64_t ne_dst = ggml_nelements(dst);
-
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();;
-
-    void * src0_ddq = src0->data;
-    sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf = (float *) dst->data;
-
-    // convert src1 to fp16
-    ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_sycl != nullptr);
-        to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
-    }
-    sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
-                                                       : src1_f16_alloc.get();
-
-    char * dst_t;
-
-    dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
-    dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
-
-    // dst strides
-    size_t nbd2 = dst->nb[2];
-    size_t nbd3 = dst->nb[3];
-
-    const float alpha_f32 = 1.0f;
-    const float beta_f32 = 0.0f;
-
-    const void * alpha = &alpha_f32;
-    const void * beta  = &beta_f32;
-
-    dst_t = (char *) dst_ddf;
-
-    GGML_ASSERT(ne12 % ne02 == 0);
-    GGML_ASSERT(ne13 % ne03 == 0);
-
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
-    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
-        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const char *)src0_as_f16, dpct::library_data_t::real_half,
-            nb01 / nb00, nb02 / nb00,
-            (const char *)src1_f16, dpct::library_data_t::real_half,
-            nb11 / nb10, nb12 / nb10, beta,
-            (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
-            ne12 * ne13, cu_compute_type)));
-    } else {
-        const int ne23 = ne12*ne13;
-
-        ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
-        ggml_sycl_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
-
-        sycl::range<3> block_dims(1, ne12, ne13);
-        /*
-        DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(main_stream->get_device(),
-                                         {sycl::aspect::fp16});
-
-            main_stream->submit([&](sycl::handler &cgh) {
-                const void **ptrs_src_get = ptrs_src.get();
-                void **ptrs_dst_get = ptrs_dst.get();
-                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
-                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
-                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
-                                 [=](sycl::nd_item<3> item_ct1) {
-                                     k_compute_batched_ptrs(
-                                         src0_as_f16, src1_f16,
-                                         dst_t, ptrs_src_get,
-                                         ptrs_dst_get, ne12, ne13, ne23,
-                                         nb02, nb03, nb12_scaled, nb13_scaled,
-                                         nbd2, nbd3, r2, r3, item_ct1);
-                                 });
-            });
-        }
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const void **)(ptrs_src.get() + 0 * ne23),
-            dpct::library_data_t::real_half, nb01 / nb00,
-            (const void **)(ptrs_src.get() + 1 * ne23),
-            dpct::library_data_t::real_half, nb11 / nb10, beta,
-            (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
-            cu_compute_type)));
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
-    // TODO: accuracy issues in MMQ
-    return false;
-}
-
-bool ggml_sycl_supports_dmmv(enum ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_F16:
-            return true;
-        default:
-            return false;
-    }
-}
-
-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
-    int64_t min_compute_capability = INT_MAX;
-
-    if (split) {
-        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
-        auto & tensor_split = buft_ctx->tensor_split;
-        for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
-            // skip devices that are not going to do any work:
-            if (tensor_split[id] >= (id + 1 < ggml_sycl_info().device_count ? tensor_split[id + 1] : 1.0f)) {
-                continue;
-            }
-
-            if (min_compute_capability > ggml_sycl_info().devices[id].cc) {
-                min_compute_capability = ggml_sycl_info().devices[id].cc;
-            }
-        }
-    } else {
-        min_compute_capability    = ggml_sycl_info().devices[ctx.device].cc;
-    }
-
-    // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
-
-    bool use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
-
-    bool use_mul_mat_q =  ggml_sycl_supports_mmq(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
-
-    // mmvq and mmq need the __dp4a instruction which is available for gen12+
-    // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
-    use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
-#ifdef SYCL_USE_XMX
-    use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
-#endif // SYCL_USE_XMX
-
-    // mmvq path is faster in the CUDA backend.
-    if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
-        use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
-
-    if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
-        // KQ single-batch
-        ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
-        // KQV single-batch
-        ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
-        // KQ + KQV multi-batch
-        ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
-    } else if (use_dequantize_mul_mat_vec) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-    } else if (use_mul_mat_vec_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
-    } else if (use_mul_mat_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
-    } else {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-    }
-}
-
-
-struct mmid_row_mapping {
-    int32_t i1;
-    int32_t i2;
-};
-
-__dpct_inline__ static void k_copy_src1_to_contiguous(
-    const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
-    int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
-    const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
-    int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
-    const sycl::nd_item<3> &item_ct1, int &src1_row) {
-    int32_t iid1 = item_ct1.get_group(2);
-    int32_t id = item_ct1.get_group(1);
-
-    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
-
-    if (row_id_i != i02) {
-        return;
-    }
-
-    const int64_t i11 = id % ne11;
-    const int64_t i12 = iid1;
-
-    if (item_ct1.get_local_id(2) == 0) {
-        src1_row =
-            dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
-                cur_src1_row, 1);
-        row_mapping[src1_row] = {id, iid1};
-    }
-    /*
-    DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
-    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
-    performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
-
-    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
-    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
-
-#pragma unroll
-    for (int i = item_ct1.get_local_id(2); i < ne10;
-         i += item_ct1.get_local_range(2)) {
-        src1_row_contiguous[i] = src1_row_original[i];
-    }
-}
-
-__dpct_inline__ static void k_copy_dst_from_contiguous(
-    char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
-    const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
-    size_t nb2, const sycl::nd_item<3> &item_ct1) {
-    int32_t i = item_ct1.get_group(2);
-
-    const int32_t i1 = row_mapping[i].i1;
-    const int32_t i2 = row_mapping[i].i2;
-
-    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
-    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
-
-#pragma unroll
-    for (int j = item_ct1.get_local_id(2); j < ne0;
-         j += item_ct1.get_local_range(2)) {
-        dst_row_original[j] = dst_row_contiguous[j];
-    }
-}
-
-static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1,
-                                 ggml_tensor *dst) try {
-    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
-
-    const ggml_tensor *ids = dst->src[2];
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const queue_ptr stream = ctx.stream();
-
-    const int64_t n_as = ne02;
-    const int64_t n_ids = ids->ne[0];
-
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    const char * ids_dev = (const char *) ids->data;
-
-    SYCL_CHECK(CHECK_TRY_ERROR(
-        stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
-    SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
-
-    ggml_tensor src0_row = *src0;
-    ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row = *dst;
-
-    char *src0_original = (char *)src0->data;
-    char *src1_original = (char *)src1->data;
-    char *dst_original = (char *)dst->data;
-
-    src0_row.ne[2] = 1;
-    src0_row.ne[3] = 1;
-    src0_row.nb[3] = nb02;
-
-    src1_row.ne[1] = 1;
-    src1_row.ne[2] = 1;
-    src1_row.ne[3] = 1;
-    src1_row.nb[2] = nb11;
-    src1_row.nb[3] = nb11;
-
-    dst_row.ne[1] = 1;
-    dst_row.ne[2] = 1;
-    dst_row.ne[3] = 1;
-    dst_row.nb[2] = nb1;
-    dst_row.nb[3] = nb1;
-    if (ne12 == 1) {
-        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-            for (int64_t id = 0; id < n_ids; id++) {
-                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-                GGML_ASSERT(i02 >= 0 && i02 < n_as);
-
-                const int64_t i11 = id % ne11;
-                const int64_t i12 = iid1;
-
-                const int64_t i1 = id;
-                const int64_t i2 = i12;
-
-            src0_row.data = src0_original + i02*nb02;
-            src1_row.data = src1_original + + i11*nb11 + i12*nb12;
-            dst_row.data = dst_original + i1*nb1   + i2*nb2;
-
-            ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
-            }
-        }
-    } else {
-        ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
-        ggml_sycl_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
-
-        src1_row.data = src1_contiguous.get();
-        dst_row.data  =  dst_contiguous.get();
-
-        for (int64_t i02 = 0; i02 < n_as; i02++) {
-            int64_t num_src1_rows = 0;
-            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-                for (int64_t id = 0; id < n_ids; id++) {
-                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-
-                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
-
-                    if (row_id_i != i02) {
-                        continue;
-                    }
-
-                    num_src1_rows++;
-                }
-            }
-
-            if (num_src1_rows == 0) {
-                continue;
-            }
-
-
-            ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-            ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-            SYCL_CHECK(CHECK_TRY_ERROR(
-                stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
-
-            {
-                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
-                sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
-                stream->submit([&](sycl::handler &cgh) {
-                    sycl::local_accessor<int, 0> src1_row_acc(cgh);
-
-                    char *__restrict src1_contiguous_get =
-                        src1_contiguous.get();
-                    int *__restrict dev_cur_src1_row_get =
-                        dev_cur_src1_row.get();
-                    mmid_row_mapping *__restrict dev_row_mapping_get =
-                        dev_row_mapping.get();
-                    size_t ids_nb_ct6 = ids->nb[1];
-                    size_t ids_nb_ct7 = ids->nb[0];
-
-                    cgh.parallel_for(
-                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-                        [=](sycl::nd_item<3> item_ct1) {
-                            k_copy_src1_to_contiguous(
-                                src1_original, src1_contiguous_get,
-                                dev_cur_src1_row_get,
-                                dev_row_mapping_get, ids_dev, i02,
-                                ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
-                                item_ct1, src1_row_acc);
-                        });
-                });
-            }
-
-            src0_row.data = src0_original + i02*nb02;
-
-            GGML_ASSERT(nb11 == sizeof(float)*ne10);
-            GGML_ASSERT(nb1 == sizeof(float)*ne0);
-            src1_row.ne[1] = num_src1_rows;
-
-            src1_row.nb[1] = nb11;
-            src1_row.nb[2] = num_src1_rows*nb11;
-            src1_row.nb[3] = num_src1_rows*nb11;
-
-            dst_row.ne[1] = num_src1_rows;
-            dst_row.nb[1] = nb1;
-            dst_row.nb[2] = num_src1_rows*nb1;
-            dst_row.nb[3] = num_src1_rows*nb1;
-
-            ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
-
-            {
-                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
-                sycl::range<3> grid_dims(1, 1, num_src1_rows);
-                stream->submit([&](sycl::handler &cgh) {
-                    const char *__restrict dst_contiguous_get =
-                        dst_contiguous.get();
-                    const mmid_row_mapping *__restrict dev_row_mapping_get =
-                        dev_row_mapping.get();
-
-                    cgh.parallel_for(
-                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-                        [=](sycl::nd_item<3> item_ct1) {
-                            k_copy_dst_from_contiguous(dst_original,
-                                                       dst_contiguous_get,
-                                                       dev_row_mapping_get,
-                                                       ne0, nb1, nb2, item_ct1);
-                        });
-                });
-            }
-        }
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_scale);
-}
-
-static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_clamp);
-}
-
-static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                          ggml_tensor *dst) try {
-    const int64_t ne = ggml_nelements(src0);
-    GGML_ASSERT(ne == ggml_nelements(src1));
-
-    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
-    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
-
-    GGML_TENSOR_BINARY_OP_LOCALS01;
-
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();
-
-    char * src0_ddc = (char *) src0->data;
-    char * src1_ddc = (char *) src1->data;
-
-    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f16_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
-        ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
-        ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else {
-        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
-                ggml_type_name(src0->type), ggml_type_name(src1->type));
-        GGML_ABORT("fatal error");
-    }
-
-    (void) dst;
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    // TODO: why do we pass dst as src1 here?
-    ggml_sycl_cpy(ctx, src0, dst, nullptr);
-    (void) src1;
-}
-
-static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_diag_mask_inf);
-}
-
-static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_soft_max);
-}
-
-static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rope);
-}
-
-static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pool2d);
-}
-
-static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
-}
-
-static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum);
-}
-
-static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
-}
-
-static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
-}
-
-static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argmax);
-}
-
-static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    (void) src0;
-    (void) src1;
-    (void) dst;
-}
-
-void ggml_sycl_set_main_device(const int main_device) try {
-    if (dpct::get_current_device_id() == main_device) return;
-    check_allow_gpu_index(main_device);
-    dpct::select_device(main_device);
-
-    if (g_ggml_sycl_debug) {
-        dpct::device_info prop;
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
-            prop, dpct::dev_mgr::instance().get_device(main_device))));
-        fprintf(stderr, "Using device %d (%s) as main device\n",
-                main_device, prop.get_name());
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * tensor) {
-    if (!g_sycl_loaded) return false;
-
-    ggml_sycl_func_t func;
-
-    switch (tensor->op) {
-        case GGML_OP_ARGMAX:
-            func = ggml_sycl_argmax;
-            break;
-        case GGML_OP_CONV_TRANSPOSE_1D:
-            func = ggml_sycl_op_conv_transpose_1d;
-            break;
-        case GGML_OP_REPEAT:
-            func = ggml_sycl_repeat;
-            break;
-        case GGML_OP_GET_ROWS:
-            func = ggml_sycl_get_rows;
-            break;
-        case GGML_OP_DUP:
-            func = ggml_sycl_dup;
-            break;
-        case GGML_OP_ADD:
-        case GGML_OP_ADD1: // TODO: more efficient implementation
-            func = ggml_sycl_add;
-            break;
-        case GGML_OP_SUB:
-            func = ggml_sycl_sub;
-            break;
-        case GGML_OP_ACC:
-            func = ggml_sycl_acc;
-            break;
-        case GGML_OP_MUL:
-            func = ggml_sycl_mul;
-            break;
-        case GGML_OP_LOG:
-            func = ggml_sycl_log;
-            break;
-        case GGML_OP_DIV:
-            func = ggml_sycl_div;
-            break;
-        case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(tensor)) {
-                case GGML_UNARY_OP_NEG:
-                    func = ggml_sycl_neg;
-                    break;
-                case GGML_UNARY_OP_STEP:
-                    func = ggml_sycl_step;
-                    break;
-                case GGML_UNARY_OP_GELU:
-                    func = ggml_sycl_gelu;
-                    break;
-                case GGML_UNARY_OP_SILU:
-                    func = ggml_sycl_silu;
-                    break;
-                case GGML_UNARY_OP_GELU_QUICK:
-                    func = ggml_sycl_gelu_quick;
-                    break;
-                case GGML_UNARY_OP_TANH:
-                    func = ggml_sycl_tanh;
-                    break;
-                case GGML_UNARY_OP_RELU:
-                    func = ggml_sycl_relu;
-                    break;
-                case GGML_UNARY_OP_SIGMOID:
-                    func = ggml_sycl_sigmoid;
-                    break;
-                case GGML_UNARY_OP_HARDSIGMOID:
-                    func = ggml_sycl_hardsigmoid;
-                    break;
-                case GGML_UNARY_OP_HARDSWISH:
-                    func = ggml_sycl_hardswish;
-                    break;
-                case GGML_UNARY_OP_EXP:
-                    func = ggml_sycl_exp;
-                    break;
-                default:
-                    return false;
-            }
-            break;
-        case GGML_OP_NORM:
-            func = ggml_sycl_norm;
-            break;
-        case GGML_OP_GROUP_NORM:
-            func = ggml_sycl_group_norm;
-            break;
-        case GGML_OP_CONCAT:
-            func = ggml_sycl_op_concat;
-            break;
-        case GGML_OP_UPSCALE:
-            func = ggml_sycl_upscale;
-            break;
-        case GGML_OP_PAD:
-            func = ggml_sycl_pad;
-            break;
-        case GGML_OP_LEAKY_RELU:
-            func = ggml_sycl_leaky_relu;
-            break;
-        case GGML_OP_RMS_NORM:
-            func = ggml_sycl_rms_norm;
-            break;
-        case GGML_OP_MUL_MAT:
-            if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
-                return false;
-            }
-            func = ggml_sycl_mul_mat;
-            break;
-        case GGML_OP_MUL_MAT_ID:
-            if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
-                return false;
-            }
-            func = ggml_sycl_mul_mat_id;
-            break;
-        case GGML_OP_OUT_PROD:
-            func = ggml_sycl_op_out_prod;
-            break;
-        case GGML_OP_SCALE:
-            func = ggml_sycl_scale;
-            break;
-        case GGML_OP_SQR:
-            func = ggml_sycl_sqr;
-            break;
-        case GGML_OP_SQRT:
-            func = ggml_sycl_sqrt;
-            break;
-        case GGML_OP_SIN:
-            func = ggml_sycl_sin;
-            break;
-        case GGML_OP_COS:
-            func = ggml_sycl_cos;
-            break;
-        case GGML_OP_CLAMP:
-            func = ggml_sycl_clamp;
-            break;
-        case GGML_OP_CPY:
-            func = ggml_sycl_cpy;
-            break;
-        case GGML_OP_CONT:
-            func = ggml_sycl_dup;
-            break;
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_TRANSPOSE:
-            func = ggml_sycl_nop;
-            break;
-        case GGML_OP_DIAG_MASK_INF:
-            func = ggml_sycl_diag_mask_inf;
-            break;
-        case GGML_OP_SOFT_MAX:
-            func = ggml_sycl_soft_max;
-            break;
-        case GGML_OP_ROPE:
-            func = ggml_sycl_rope;
-            break;
-        case GGML_OP_IM2COL:
-            func = ggml_sycl_im2col;
-            break;
-        case GGML_OP_POOL_2D:
-            func = ggml_sycl_pool2d;
-            break;
-        case GGML_OP_SUM:
-            func = ggml_sycl_sum;
-            break;
-        case GGML_OP_SUM_ROWS:
-            func = ggml_sycl_sum_rows;
-            break;
-        case GGML_OP_ARGSORT:
-            func = ggml_sycl_argsort;
-            break;
-        case GGML_OP_TIMESTEP_EMBEDDING:
-            func = ggml_sycl_op_timestep_embedding;
-            break;
-        case GGML_OP_RWKV_WKV6:
-            func = ggml_sycl_op_rwkv_wkv6;
-            break;
-        default:
-            return false;
-    }
-
-    if (tensor->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(tensor->src[0]->buffer)) {
-        ggml_sycl_set_peer_access(tensor->src[1]->ne[1], ctx.device);
-    }
-
-    func(ctx, tensor->src[0], tensor->src[1], tensor);
-    return true;
-}
-
-GGML_API void ggml_backend_sycl_get_device_description(int device, char *description,
-                                      size_t description_size) try {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_description\n");
-    dpct::device_info prop;
-    SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
-        prop, dpct::dev_mgr::instance().get_device(device))));
-    snprintf(description, description_size, "%s", prop.get_name());
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-void ggml_backend_sycl_get_device_memory(int device, size_t *free,
-                                                   size_t *total) try {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n");
-    ggml_sycl_set_device(device);
-
-    /*
-    DPCT1009:218: SYCL uses exceptions to report errors and does not use the
-    error codes. The original code was commented out and a warning string was
-    inserted. You need to rewrite this code.
-    */
-    /*
-    DPCT1106:217: 'cudaMemGetInfo' was migrated with the Intel extensions for
-    device information which may not be supported by all compilers or runtimes.
-    You may need to adjust the code.
-    */
-    SYCL_CHECK(CHECK_TRY_ERROR(
-        dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total)));
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-////////////////////////////////////////////////////////////////////////////////
-
-// backend
-
-static const char * ggml_backend_sycl_get_name(ggml_backend_t backend) {
-
-    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-
-    return sycl_ctx->name.c_str();
-}
-
-static void ggml_backend_sycl_free(ggml_backend_t backend) {
-    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-
-    delete sycl_ctx;
-    delete backend;
-}
-
-static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
-                                               ggml_tensor *tensor,
-                                               const void *data, size_t offset,
-                                               size_t size) try {
-    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
-
-    GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
-    const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
-    SYCL_CHECK(CHECK_TRY_ERROR(
-        (stream)->memcpy((char *)tensor->data + offset, data, size)));
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
-                                               const ggml_tensor *tensor,
-                                               void *data, size_t offset,
-                                               size_t size) try {
-    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
-
-    GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
-    const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
-    SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
-        data, (const char *)tensor->data + offset, size).wait()));
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
-                                               const ggml_tensor *src,
-                                               ggml_tensor *dst) try {
-    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && ggml_backend_buffer_is_sycl(src->buffer)) {
-        /*
-        DPCT1009:215: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
-        const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
-        SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
-            dst->data, src->data, ggml_nbytes(dst)).wait()));
-        return true;
-    }
-
-    return false;
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try {
-    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
-    SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait()));
-
-    GGML_UNUSED(backend);
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    ggml_sycl_set_main_device(sycl_ctx->device);
-
-
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_tensor * node = cgraph->nodes[i];
-        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
-            continue;
-        }
-#ifndef NDEBUG
-        assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
-        for (int j = 0; j < GGML_MAX_SRC; j++) {
-            if (node->src[j] != nullptr) {
-                assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
-            }
-        }
-#endif
-        bool ok = ggml_sycl_compute_forward(*sycl_ctx, node);
-        if (!ok) {
-            fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
-        }
-        GGML_ASSERT(ok);
-    }
-
-    return GGML_STATUS_SUCCESS;
-}
-
-static void ggml_backend_sycl_event_record(ggml_backend_t backend, ggml_backend_event_t event)
-try
-{
-    ggml_backend_sycl_context *sycl_ctx =
-        (ggml_backend_sycl_context *)backend->context;
-    sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
-
-    const queue_ptr &stream = sycl_ctx->stream(sycl_ctx->device, 0);
-    // Record the current state of the queue
-    SYCL_CHECK(CHECK_TRY_ERROR(*sycl_event = stream->ext_oneapi_submit_barrier()));
-}
-catch (sycl::exception const &exc)
-{
-    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-              << ", line:" << __LINE__ << std::endl;
-    std::exit(1);
-}
-
-static void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try {
-    ggml_backend_sycl_context* sycl_ctx = static_cast<ggml_backend_sycl_context*>(backend->context);
-    sycl::event* sycl_event = static_cast<sycl::event*>(event->context);
-
-    if (ggml_backend_is_sycl(backend)) {
-        SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
-    } else
-        GGML_ABORT("fatal error");
-} catch (sycl::exception const& exc) {
-    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-              << ", line:" << __LINE__ << std::endl;
-    std::exit(1);
-}
-
-static ggml_backend_i ggml_backend_sycl_interface = {
-    /* .get_name                = */ ggml_backend_sycl_get_name,
-    /* .free                    = */ ggml_backend_sycl_free,
-    /* .set_tensor_async        = */ ggml_backend_sycl_set_tensor_async,
-    /* .get_tensor_async        = */ ggml_backend_sycl_get_tensor_async,
-    /* .cpy_tensor_async        = */ NULL, // ggml_backend_sycl_cpy_tensor_async,
-                                           // // TODO: update for the new
-                                           // interface
-    /* .synchronize             = */ ggml_backend_sycl_synchronize,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_sycl_graph_compute,
-    /* .event_record            = */ ggml_backend_sycl_event_record,
-    /* .event_wait              = */ ggml_backend_sycl_event_wait,
-};
-
-static ggml_guid_t ggml_backend_sycl_guid() {
-    static ggml_guid guid = { 0x58, 0x05, 0x13, 0x8f, 0xcd, 0x3a, 0x61, 0x9d, 0xe7, 0xcd, 0x98, 0xa9, 0x03, 0xfd, 0x7c, 0x53 };
-    return &guid;
-}
-
-bool ggml_backend_is_sycl(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_sycl_guid());
-}
-
-int ggml_backend_sycl_get_device_count() {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
-    return ggml_sycl_info().device_count;
-}
-
-
-// backend device
-
-struct ggml_backend_sycl_device_context {
-    int device;
-    std::string name;
-    std::string description;
-};
-
-static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) {
-    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
-    return ctx->name.c_str();
-}
-
-static const char * ggml_backend_sycl_device_get_description(ggml_backend_dev_t dev) {
-    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
-    return ctx->description.c_str();
-}
-
-static void ggml_backend_sycl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
-    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
-    ggml_sycl_set_device(ctx->device);
-    SYCL_CHECK(CHECK_TRY_ERROR(
-    dpct::dev_mgr::instance().get_device(ctx->device).get_memory_info(*free, *total)));
-}
-
-static enum ggml_backend_dev_type ggml_backend_sycl_device_get_type(ggml_backend_dev_t dev) {
-    GGML_UNUSED(dev);
-    return GGML_BACKEND_DEVICE_TYPE_GPU;
-}
-
-static void ggml_backend_sycl_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
-    props->name        = ggml_backend_sycl_device_get_name(dev);
-    props->description = ggml_backend_sycl_device_get_description(dev);
-    props->type        = ggml_backend_sycl_device_get_type(dev);
-    ggml_backend_sycl_device_get_memory(dev, &props->memory_free, &props->memory_total);
-
-    bool host_buffer = getenv("GGML_SYCL_NO_PINNED") == nullptr;
-#ifdef GGML_SYCL_NO_PEER_COPY
-    bool events = false;
-#else
-    bool events = true;
-#endif
-
-    props->caps = {
-        /* .async                 = */ true,
-        /* .host_buffer           = */ host_buffer,
-        /* .buffer_from_host_ptr  = */ false,
-        /* .events                = */ events,
-    };
-}
-
-static ggml_backend_t ggml_backend_sycl_device_init(ggml_backend_dev_t dev, const char * params) {
-    GGML_UNUSED(params);
-    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
-    return ggml_backend_sycl_init(ctx->device);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_buffer_type(ggml_backend_dev_t dev) {
-    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
-    return ggml_backend_sycl_buffer_type(ctx->device);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_host_buffer_type(ggml_backend_dev_t dev) {
-    GGML_UNUSED(dev);
-    return ggml_backend_sycl_host_buffer_type();
-}
-
-static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
-    GGML_UNUSED(dev);
-    GGML_UNUSED(ptr);
-    GGML_UNUSED(size);
-    GGML_UNUSED(max_tensor_size);
-    return nullptr;
-}
-
-static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
-    switch (op->op) {
-        case GGML_OP_CONV_TRANSPOSE_1D:
-            {
-                ggml_type src0_type = op->src[0]->type;
-                ggml_type src1_type = op->src[1]->type;
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
-                    return true;
-                }
-                return false;
-            } break;
-        case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_NEG:
-                case GGML_UNARY_OP_STEP:
-                case GGML_UNARY_OP_GELU:
-                case GGML_UNARY_OP_SILU:
-                case GGML_UNARY_OP_RELU:
-                case GGML_UNARY_OP_SIGMOID:
-                case GGML_UNARY_OP_HARDSIGMOID:
-                case GGML_UNARY_OP_HARDSWISH:
-                case GGML_UNARY_OP_GELU_QUICK:
-                case GGML_UNARY_OP_TANH:
-                case GGML_UNARY_OP_EXP:
-                    return ggml_is_contiguous(op->src[0]);
-                default:
-                    return false;
-            }
-            break;
-        case GGML_OP_MUL_MAT:
-        case GGML_OP_MUL_MAT_ID:
-            {
-                struct ggml_tensor * a;
-                struct ggml_tensor * b;
-                if (op->op == GGML_OP_MUL_MAT) {
-                    a = op->src[0];
-                    b = op->src[1];
-                    if (ggml_is_permuted(a) || ggml_is_permuted(b)) {
-                        // TODO: fix like https://github.com/ggerganov/llama.cpp/pull/10021
-                        return false;
-                    }
-                } else {
-                    a = op->src[2];
-                    b = op->src[1];
-                }
-                if (a->ne[3] != b->ne[3]) {
-                    return false;
-                }
-                ggml_type a_type = a->type;
-                if (a_type == GGML_TYPE_IQ4_NL  || a_type == GGML_TYPE_IQ4_XS ||
-                    a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S  ||
-                    a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S ||
-                    a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M
-                    ) {
-                    if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
-                        return false;
-                    }
-                }
-                ggml_type src0_type = op->src[0]->type;
-                if (src0_type == GGML_TYPE_BF16) {
-                    return false;
-                }
-                return true;
-            } break;
-        case GGML_OP_OUT_PROD:
-            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
-        case GGML_OP_GET_ROWS:
-            {
-                switch (op->src[0]->type) {
-                    case GGML_TYPE_F16:
-                    case GGML_TYPE_F32:
-                    case GGML_TYPE_Q4_0:
-                    case GGML_TYPE_Q4_1:
-                    case GGML_TYPE_Q5_0:
-                    case GGML_TYPE_Q5_1:
-                    case GGML_TYPE_Q8_0:
-                        return true;
-                    default:
-                        return false;
-                }
-            } break;
-        case GGML_OP_CPY:
-            {
-                ggml_type src0_type = op->src[0]->type;
-                ggml_type src1_type = op->src[1]->type;
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
-                    return true;
-                }
-                return false;
-            } break;
-        case GGML_OP_CONCAT:
-            {
-                ggml_type src0_type = op->src[0]->type;
-                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
-            } break;
-        case GGML_OP_DUP:
-        case GGML_OP_ARGMAX:
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_REPEAT:
-        case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_TRANSPOSE:
-        case GGML_OP_NORM:
-        case GGML_OP_ADD:
-        case GGML_OP_ADD1:
-        case GGML_OP_LOG:
-        case GGML_OP_SUB:
-        case GGML_OP_MUL:
-        case GGML_OP_DIV:
-        case GGML_OP_RMS_NORM:
-        case GGML_OP_SCALE:
-        case GGML_OP_SQR:
-        case GGML_OP_SQRT:
-        case GGML_OP_SIN:
-        case GGML_OP_COS:
-        case GGML_OP_CLAMP:
-            return true;
-        case GGML_OP_CONT:
-            return op->src[0]->type != GGML_TYPE_BF16;
-        case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_SOFT_MAX:
-            return true;
-        case GGML_OP_ROPE:
-            return ggml_is_contiguous(op->src[0]);
-        case GGML_OP_IM2COL:
-            // TODO: add support for the new F32 operations
-            return op->src[0]->type == GGML_TYPE_F16;
-        case GGML_OP_POOL_2D:
-        case GGML_OP_SUM:
-        case GGML_OP_SUM_ROWS:
-        case GGML_OP_ARGSORT:
-        case GGML_OP_ACC:
-        case GGML_OP_GROUP_NORM:
-        case GGML_OP_UPSCALE:
-        case GGML_OP_PAD:
-        case GGML_OP_LEAKY_RELU:
-        case GGML_OP_TIMESTEP_EMBEDDING:
-        case GGML_OP_RWKV_WKV6:
-            return true;
-        default:
-            return false;
-    }
-
-    GGML_UNUSED(dev);
-}
-
-static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    if (buft->iface.get_name != ggml_backend_sycl_buffer_type_get_name) {
-        return false;
-    }
-    ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
-    ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
-    return buft_ctx->device == sycl_ctx->device;
-}
-
-static int64_t get_op_batch_size(const ggml_tensor * op) {
-    switch (op->op) {
-        case GGML_OP_GET_ROWS:
-            return op->ne[1]; // this will increse the speed of prefill in test
-        case GGML_OP_MUL_MAT:
-            return op->ne[1];
-        case GGML_OP_MUL_MAT_ID:
-        case GGML_OP_ROPE:
-            return op->ne[2];
-        default:
-            return ggml_nrows(op);
-    }
-}
-
-static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
-    const int min_batch_size = 32;
-    return get_op_batch_size(op) >= min_batch_size;
-    GGML_UNUSED(dev);
-}
-
-static ggml_backend_event_t
-ggml_backend_sycl_device_event_new(ggml_backend_dev_t dev) {
-
-#ifdef GGML_SYCL_NO_PEER_COPY
-    return nullptr;
-#else
-  sycl::event *event_ptr = new sycl::event();
-
-  return new ggml_backend_event{
-      /* .device = */ dev,
-      /* .context = */ event_ptr,
-  };
-#endif
-}
-
-static void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
-  GGML_UNUSED(dev);
-  if (event == nullptr) {
-    return;
-  }
-
-  if (event->context != nullptr) {
-    sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
-    delete sycl_event;
-    event->context = nullptr;
-  }
-
-  delete event;
-} catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-
-static void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
-  GGML_UNUSED(dev);
-
-  sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
-  SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
-} catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static const ggml_backend_device_i ggml_backend_sycl_device_interface = {
-    /* .get_name                = */ ggml_backend_sycl_device_get_name,
-    /* .get_description         = */ ggml_backend_sycl_device_get_description,
-    /* .get_memory              = */ ggml_backend_sycl_device_get_memory,
-    /* .get_type                = */ ggml_backend_sycl_device_get_type,
-    /* .get_props               = */ ggml_backend_sycl_device_get_props,
-    /* .init_backend            = */ ggml_backend_sycl_device_init,
-    /* .get_buffer_type         = */ ggml_backend_sycl_device_get_buffer_type,
-    /* .get_host_buffer_type    = */ ggml_backend_sycl_device_get_host_buffer_type,
-    /* .buffer_from_host_ptr    = */ ggml_backend_sycl_device_buffer_from_host_ptr,
-    /* .supports_op             = */ ggml_backend_sycl_device_supports_op,
-    /* .supports_buft           = */ ggml_backend_sycl_device_supports_buft,
-    /* .offload_op              = */ ggml_backend_sycl_device_offload_op,
-    /* .event_new               = */ ggml_backend_sycl_device_event_new,
-    /* .event_free              = */ ggml_backend_sycl_device_event_free,
-    /* .event_synchronize       = */ ggml_backend_sycl_device_event_synchronize,
-};
-
-// backend reg
-
-struct ggml_backend_sycl_reg_context {
-    std::vector<ggml_backend_dev_t> devices;
-};
-
-static const char * ggml_backend_sycl_reg_get_name(ggml_backend_reg_t reg) {
-    GGML_UNUSED(reg);
-    return GGML_SYCL_NAME;
-}
-
-static size_t ggml_backend_sycl_reg_get_device_count(ggml_backend_reg_t reg) {
-    ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context;
-    return ctx->devices.size();
-}
-
-static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t reg, size_t index) {
-    ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context;
-    GGML_ASSERT(index < ctx->devices.size());
-    return ctx->devices[index];
-}
-
-static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) {
-    GGML_UNUSED(reg);
-
-    // TODO: update to the current function signature
-    //if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
-    //    return (void *)ggml_backend_sycl_split_buffer_type;
-    //}
-
-    // SYCL doesn't support registering host memory, left here for reference
-    // "ggml_backend_register_host_buffer"
-    // "ggml_backend_unregister_host_buffer"
-    return nullptr;
-}
-
-static const ggml_backend_reg_i ggml_backend_sycl_reg_interface = {
-    /* .get_name          = */ ggml_backend_sycl_reg_get_name,
-    /* .get_device_count  = */ ggml_backend_sycl_reg_get_device_count,
-    /* .get_device_get    = */ ggml_backend_sycl_reg_get_device,
-    /* .get_proc_address  = */ ggml_backend_sycl_reg_get_proc_address,
-};
-
-
-// backend registry
-
-ggml_backend_reg_t ggml_backend_sycl_reg() {
-    static ggml_backend_reg reg;
-    static bool initialized = false;
-
-    {
-        static std::mutex mutex;
-        std::lock_guard<std::mutex> lock(mutex);
-        if (!initialized) {
-            ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context;
-
-            for (int i = 0; i < ggml_sycl_info().device_count; i++) {
-                ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context;
-                dev_ctx->device = i;
-                dev_ctx->name = GGML_SYCL_NAME + std::to_string(i);
-
-                ggml_sycl_set_device(i);
-
-                dpct::device_info prop;
-                SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
-                    prop, dpct::dev_mgr::instance().get_device(i))));
-
-                dev_ctx->description = prop.get_name();
-
-                ggml_backend_dev_t dev = new ggml_backend_device {
-                    /* .interface = */ ggml_backend_sycl_device_interface,
-                    /* .reg       = */ &reg,
-                    /* .context   = */ dev_ctx
-                };
-                ctx->devices.push_back(dev);
-            }
-
-            reg = ggml_backend_reg {
-                /* .interface = */ ggml_backend_sycl_reg_interface,
-                /* .context   = */ ctx
-            };
-        }
-
-        initialized = true;
-    }
-
-    return &reg;
-}
-
-ggml_backend_t ggml_backend_sycl_init(int device) {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_init\n");
-    ggml_check_sycl();
-
-    check_allow_gpu_index(device);
-
-    ggml_backend_sycl_context * ctx = new ggml_backend_sycl_context(device);
-    if (ctx == nullptr) {
-        fprintf(stderr, "%s: error: failed to allocate context\n", __func__);
-        return nullptr;
-    };
-
-    ggml_backend_t sycl_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_sycl_guid(),
-        /* .interface = */ ggml_backend_sycl_interface,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
-        /* .context   = */ ctx
-    };
-
-    return sycl_backend;
-}
-
diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp
deleted file mode 100644 (file)
index 169b5a3..0000000
+++ /dev/null
@@ -1,7648 +0,0 @@
-#include "ggml-vulkan.h"
-#include <vulkan/vulkan_core.h>
-#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF)
-#include <chrono>
-#endif
-
-#include <vulkan/vulkan.hpp>
-
-#include <algorithm>
-#include <cmath>
-#include <iomanip>
-#include <iostream>
-#include <tuple>
-#include <vector>
-#include <sstream>
-#include <utility>
-#include <memory>
-#include <limits>
-#include <map>
-#include <unordered_map>
-#include <memory>
-#include <mutex>
-#include <future>
-#include <thread>
-
-#include "ggml-impl.h"
-#include "ggml-backend-impl.h"
-
-#include "ggml-vulkan-shaders.hpp"
-
-#define VK_API_VERSION VK_API_VERSION_1_2
-
-#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
-
-#define VK_VENDOR_ID_AMD 0x1002
-#define VK_VENDOR_ID_APPLE 0x106b
-#define VK_VENDOR_ID_INTEL 0x8086
-#define VK_VENDOR_ID_NVIDIA 0x10de
-
-#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32
-
-#define GGML_VK_MAX_NODES 8192
-
-#define MAX_VK_BUFFERS 256
-
-#ifndef K_QUANTS_PER_ITERATION
-#define K_QUANTS_PER_ITERATION 1
-#else
-static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
-#endif
-
-#define VK_CHECK(err, msg)                                          \
-    do {                                                            \
-        vk::Result err_ = (err);                                    \
-        if (err_ != vk::Result::eSuccess) {                         \
-            fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n",  \
-                #err, to_string(err_).c_str(), __FILE__, __LINE__); \
-            exit(1);                                                \
-        }                                                           \
-    } while (0)
-
-#ifdef GGML_VULKAN_DEBUG
-#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl
-#else
-#define VK_LOG_DEBUG(msg) ((void) 0)
-#endif // GGML_VULKAN_DEBUG
-
-struct ggml_backend_vk_context;
-
-struct vk_queue {
-    uint32_t queue_family_index;
-    vk::Queue queue;
-    vk::CommandPool pool;
-    uint32_t cmd_buffer_idx;
-    std::vector<vk::CommandBuffer> cmd_buffers;
-
-    vk::PipelineStageFlags stage_flags;
-
-    bool transfer_only;
-};
-
-struct vk_pipeline_struct {
-    std::string name;
-    vk::ShaderModule shader_module;
-    vk::DescriptorSetLayout dsl;
-    std::vector<vk::DescriptorPool> descriptor_pools;
-    std::vector<vk::DescriptorSet> descriptor_sets;
-    uint32_t descriptor_set_idx;
-    vk::PipelineLayout layout;
-    vk::Pipeline pipeline;
-    uint32_t push_constant_size;
-    uint32_t parameter_count;
-    std::array<uint32_t, 3> wg_denoms;
-    uint32_t align;
-};
-
-typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
-typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;
-
-static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
-
-struct vk_matmul_pipeline_struct {
-    vk_pipeline l, m, s;
-    vk_pipeline a_l, a_m, a_s;
-};
-
-typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
-
-struct vk_device_struct;
-typedef std::shared_ptr<vk_device_struct> vk_device;
-typedef std::weak_ptr<vk_device_struct> vk_device_ref;
-
-struct vk_buffer_struct;
-typedef std::shared_ptr<vk_buffer_struct> vk_buffer;
-typedef std::weak_ptr<vk_buffer_struct> vk_buffer_ref;
-
-struct ggml_backend_vk_buffer_type_context {
-    std::string name;
-    vk_device device;
-};
-
-static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft);
-static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
-static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft);
-static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft);
-static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor);
-static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
-    /* .get_name         = */ ggml_backend_vk_buffer_type_name,
-    /* .alloc_buffer     = */ ggml_backend_vk_buffer_type_alloc_buffer,
-    /* .get_alignment    = */ ggml_backend_vk_buffer_type_get_alignment,
-    /* .get_max_size     = */ ggml_backend_vk_buffer_type_get_max_size,
-    /* .get_alloc_size   = */ ggml_backend_vk_buffer_type_get_alloc_size,
-    /* .is_host          = */ NULL,
-};
-
-#ifdef GGML_VULKAN_MEMORY_DEBUG
-class vk_memory_logger;
-#endif
-#ifdef GGML_VULKAN_PERF
-class vk_perf_logger;
-#endif
-static void ggml_vk_destroy_buffer(vk_buffer& buf);
-
-struct vk_device_struct {
-    std::mutex mutex;
-
-    vk::PhysicalDevice physical_device;
-    vk::PhysicalDeviceProperties properties;
-    std::string name;
-    uint64_t max_memory_allocation_size;
-    bool fp16;
-    vk::Device device;
-    uint32_t vendor_id;
-    vk_queue compute_queue;
-    vk_queue transfer_queue;
-    bool single_queue;
-    uint32_t subgroup_size;
-    bool uma;
-
-    size_t idx;
-
-    vk_matmul_pipeline pipeline_matmul_f32;
-    vk_matmul_pipeline pipeline_matmul_f32_f16;
-    vk_matmul_pipeline pipeline_matmul_f16;
-    vk_matmul_pipeline pipeline_matmul_f16_f32;
-    vk_pipeline pipeline_matmul_split_k_reduce;
-
-    vk_matmul_pipeline pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
-
-    vk_matmul_pipeline pipeline_matmul_id_f32;
-    vk_matmul_pipeline pipeline_matmul_id_f16;
-    vk_matmul_pipeline pipeline_matmul_id_f16_f32;
-
-    vk_matmul_pipeline pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
-
-    vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
-    vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
-    vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT];
-    vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
-
-    vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
-    vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
-    vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
-    vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
-    vk_pipeline pipeline_acc_f32;
-    vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
-    vk_pipeline pipeline_mul_f32;
-    vk_pipeline pipeline_div_f32;
-    vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
-    vk_pipeline pipeline_upscale_f32;
-    vk_pipeline pipeline_scale_f32;
-    vk_pipeline pipeline_sqr_f32;
-    vk_pipeline pipeline_sin_f32;
-    vk_pipeline pipeline_cos_f32;
-    vk_pipeline pipeline_clamp_f32;
-    vk_pipeline pipeline_pad_f32;
-    vk_pipeline pipeline_repeat_f32;
-    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
-    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
-    vk_pipeline pipeline_norm_f32;
-    vk_pipeline pipeline_group_norm_f32;
-    vk_pipeline pipeline_rms_norm_f32;
-    vk_pipeline pipeline_gelu_f32;
-    vk_pipeline pipeline_gelu_quick_f32;
-    vk_pipeline pipeline_silu_f32;
-    vk_pipeline pipeline_relu_f32;
-    vk_pipeline pipeline_leaky_relu_f32;
-    vk_pipeline pipeline_tanh_f32;
-    vk_pipeline pipeline_diag_mask_inf_f32;
-    vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
-    vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
-    vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
-    vk_pipeline pipeline_argsort_f32;
-    vk_pipeline pipeline_sum_rows_f32;
-    vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
-    vk_pipeline pipeline_timestep_embedding_f32;
-    vk_pipeline pipeline_pool2d_f32;
-
-    std::unordered_map<std::string, vk_pipeline_ref> pipelines;
-    std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
-
-    std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
-
-    vk::Fence fence;
-    vk_buffer sync_staging;
-
-    ggml_backend_buffer_type buffer_type;
-
-#ifdef GGML_VULKAN_MEMORY_DEBUG
-    std::unique_ptr<vk_memory_logger> memory_logger;
-#endif
-#ifdef GGML_VULKAN_PERF
-    std::unique_ptr<vk_perf_logger> perf_logger;
-#endif
-
-    ~vk_device_struct() {
-        VK_LOG_DEBUG("destroy device " << name);
-
-        device.destroyFence(fence);
-
-        ggml_vk_destroy_buffer(sync_staging);
-
-        device.destroyCommandPool(compute_queue.pool);
-        if (!single_queue) {
-            device.destroyCommandPool(transfer_queue.pool);
-        }
-
-        for (auto& pipeline : pipelines) {
-            if (pipeline.second.expired()) {
-                continue;
-            }
-
-            vk_pipeline pl = pipeline.second.lock();
-            ggml_vk_destroy_pipeline(device, pl);
-        }
-        pipelines.clear();
-
-        device.destroy();
-    }
-};
-
-struct vk_buffer_struct {
-    vk::Buffer buffer = VK_NULL_HANDLE;
-    vk::DeviceMemory device_memory = VK_NULL_HANDLE;
-    vk::MemoryPropertyFlags memory_property_flags;
-    void * ptr;
-    size_t size = 0;
-
-    vk_device device;
-
-    ~vk_buffer_struct() {
-        if (size == 0) {
-            return;
-        }
-        VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")");
-
-        device->device.freeMemory(device_memory);
-        device->device.destroyBuffer(buffer);
-    }
-};
-
-struct vk_subbuffer {
-    vk_buffer buffer;
-    uint64_t offset;
-    uint64_t size;
-
-    operator vk::DescriptorBufferInfo() const {
-        return { buffer->buffer, offset, size };
-    }
-};
-
-struct vk_semaphore {
-    vk::Semaphore s;
-    uint64_t value;
-};
-
-struct vk_submission {
-    vk::CommandBuffer buffer;
-    std::vector<vk_semaphore> wait_semaphores;
-    std::vector<vk_semaphore> signal_semaphores;
-};
-
-typedef std::vector<vk_submission> vk_sequence;
-
-struct vk_mat_mat_push_constants {
-    uint32_t M; uint32_t N; uint32_t K;
-    uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
-    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
-    uint32_t k_split;
-    uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
-};
-struct vk_mat_vec_push_constants {
-    uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
-    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
-    uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
-};
-
-struct vk_mat_mat_id_push_constants {
-    uint32_t M; uint32_t N; uint32_t K;
-    uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
-    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
-    uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
-};
-struct vk_mat_vec_id_push_constants {
-    uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
-    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
-    uint32_t nei0; uint32_t ne11;
-};
-
-struct vk_op_push_constants {
-    uint32_t KX;
-    uint32_t KY;
-    float param1;
-    float param2;
-};
-
-struct vk_op_unary_push_constants {
-    uint32_t ne;
-    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
-    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
-    uint32_t d_offset;
-    float param1; float param2;
-};
-
-struct vk_op_binary_push_constants {
-    uint32_t ne;
-    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
-    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
-    uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
-    uint32_t d_offset;
-    float param1; float param2; int32_t param3;
-};
-
-struct vk_op_diag_mask_push_constants {
-    uint32_t ncols;
-    uint32_t rows_per_channel;
-    int32_t n_past;
-};
-
-struct vk_op_rope_push_constants {
-    uint32_t ncols;
-    uint32_t n_dims;
-    float freq_scale;
-    uint32_t p_delta_rows;
-    float freq_base;
-    float ext_factor;
-    float attn_factor;
-    float corr_dims[2];
-    float theta_scale;
-    uint32_t has_ff;
-};
-
-struct vk_op_soft_max_push_constants {
-    uint32_t KX;
-    uint32_t KY;
-    float scale;
-    float max_bias;
-    float m0;
-    float m1;
-    uint32_t n_head_log2;
-};
-
-struct vk_op_argsort_push_constants {
-    uint32_t ncols;
-    uint32_t ncols_pad;
-    int32_t order;
-};
-
-struct vk_op_im2col_push_constants {
-    uint32_t batch_offset; uint32_t offset_delta;
-    uint32_t IC;
-    uint32_t IW; uint32_t IH;
-    uint32_t OW; uint32_t OH;
-    uint32_t KW; uint32_t KH;
-    uint32_t pelements;
-    uint32_t CHW;
-    int32_t s0; int32_t s1;
-    int32_t p0; int32_t p1;
-    int32_t d0; int32_t d1;
-};
-
-struct vk_op_timestep_embedding_push_constants {
-    uint32_t nb1;
-    uint32_t dim;
-    uint32_t max_period;
-};
-
-struct vk_op_pool2d_push_constants {
-    uint32_t IW; uint32_t IH;
-    uint32_t OW; uint32_t OH;
-    uint32_t OC;
-    uint32_t pelements;
-    uint32_t op;
-    int32_t k0; int32_t k1;
-    int32_t s0; int32_t s1;
-    int32_t p0; int32_t p1;
-};
-
-// Allow pre-recording command buffers
-struct vk_staging_memcpy {
-    vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
-
-    void * dst;
-    const void * src;
-    size_t n;
-};
-
-struct vk_op_upscale_push_constants {
-    uint32_t ne; uint32_t d_offset;
-    uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
-    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
-    float sf0; float sf1; float sf2; float sf3;
-};
-
-struct vk_context_struct {
-    vk_submission * s;
-    std::vector<vk_sequence> seqs;
-
-    int exit_tensor_idx;
-
-    std::vector<vk_staging_memcpy> in_memcpys;
-    std::vector<vk_staging_memcpy> out_memcpys;
-
-    vk_queue * q;
-};
-typedef std::shared_ptr<vk_context_struct> vk_context;
-typedef std::weak_ptr<vk_context_struct> vk_context_ref;
-
-struct ggml_vk_garbage_collector {
-    std::vector<vk_semaphore> tl_semaphores;
-    std::vector<vk_semaphore> semaphores;
-    std::vector<vk::Event> events;
-    std::vector<vk_buffer> temp_buffers;
-    std::vector<vk_context> contexts;
-};
-
-#if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG)
-#define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl
-
-static std::string format_size(size_t size) {
-    const size_t kib = 1024;
-    const size_t mib = kib * 1024;
-    const size_t gib = mib * 1024;
-
-    std::ostringstream oss;
-    oss << std::fixed << std::setprecision(2);
-
-    if (size >= gib) {
-        oss << static_cast<double>(size) / gib << " GiB";
-    } else if (size >= mib) {
-        oss << static_cast<double>(size) / mib << " MiB";
-    } else if (size >= kib) {
-        oss << static_cast<double>(size) / kib << " KiB";
-    } else {
-        oss << size << " B";
-    }
-
-    return oss.str();
-}
-
-static std::mutex log_mutex;
-
-class vk_memory_logger {
-public:
-    vk_memory_logger(): total_device(0), total_host(0) {}
-    void log_allocation(vk_buffer_ref buf_ref, size_t size);
-    void log_deallocation(vk_buffer_ref buf_ref);
-
-private:
-    std::map<vk::Buffer, size_t> allocations; // Track allocations
-    size_t total_device;
-    size_t total_host;
-};
-#else
-#define VK_LOG_MEMORY(msg) ((void) 0)
-#endif // GGML_VULKAN_MEMORY_DEBUG
-
-#if defined(GGML_VULKAN_PERF)
-
-class vk_perf_logger {
-public:
-    void print_timings() {
-        std::cerr << "----------------\nVulkan Timings:" << std::endl;
-        for (const auto& t : timings) {
-            uint64_t total = 0;
-            for (const auto& time : t.second) {
-                total += time;
-            }
-            std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " ms" << std::endl;
-        }
-
-        timings.clear();
-    }
-
-    void log_timing(const ggml_tensor * node, uint64_t time) {
-        if (node->op == GGML_OP_UNARY) {
-            timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time);
-            return;
-        }
-        if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
-            const uint64_t m = node->src[0]->ne[1];
-            const uint64_t n = node->src[1]->ne[1];
-            const uint64_t k = node->src[1]->ne[0];
-            std::string name = ggml_op_name(node->op);
-            if (n == 1) {
-                name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k);
-            } else {
-                name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
-            }
-            timings[name].push_back(time);
-            return;
-        }
-        timings[ggml_op_name(node->op)].push_back(time);
-    }
-private:
-    std::map<std::string, std::vector<uint64_t>> timings;
-};
-#endif // GGML_VULKAN_PERF
-
-struct ggml_backend_vk_context {
-    std::string name;
-
-    vk_device device;
-
-    size_t semaphore_idx, event_idx;
-    ggml_vk_garbage_collector gc;
-    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
-    vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
-    vk::Fence fence;
-
-    vk_buffer buffer_pool[MAX_VK_BUFFERS];
-
-    vk_context_ref compute_ctx;
-    vk_context_ref transfer_ctx;
-
-    std::vector<vk_context_ref> tensor_ctxs;
-};
-
-static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000;  // NOLINT
-
-static uint64_t vk_tensor_offset(const ggml_tensor * tensor) {
-    if (tensor->view_src) {
-        return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base;
-    }
-    return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
-}
-
-struct ggml_backend_vk_buffer_context {
-    vk_device_ref device;
-    vk_buffer dev_buffer;
-    std::string name;
-
-    ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) :
-        device(device),
-        dev_buffer(dev_buffer),
-        name(name) {
-    }
-
-    ~ggml_backend_vk_buffer_context() {
-        ggml_vk_destroy_buffer(dev_buffer);
-    }
-};
-
-#ifdef GGML_VULKAN_MEMORY_DEBUG
-void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {
-    std::lock_guard<std::mutex> guard(log_mutex);
-    vk_buffer buf = buf_ref.lock();
-    const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
-    const std::string type = device ? "device" : "host";
-    allocations[buf->buffer] = size;
-    total_device += device ? size : 0;
-    total_host += device ? 0 : size;
-    VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
-}
-
-void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
-    if (buf_ref.expired() || buf_ref.lock()->size == 0) {
-        return;
-    }
-
-    std::lock_guard<std::mutex> guard(log_mutex);
-    vk_buffer buf = buf_ref.lock();
-    const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
-    std::string type = device ? "device" : "host";
-    auto it = allocations.find(buf->buffer);
-    total_device -= device ? it->second : 0;
-    total_host -= device ? 0 : it->second;
-    if (it != allocations.end()) {
-        VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
-        allocations.erase(it);
-    } else {
-        VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer);
-    }
-}
-#endif // GGML_VULKAN_MEMORY_DEBUG
-
-struct vk_instance_t {
-    vk::Instance instance;
-
-    std::vector<size_t> device_indices;
-    vk_device devices[GGML_VK_MAX_DEVICES];
-};
-
-static bool vk_instance_initialized = false;
-static vk_instance_t vk_instance;
-
-#ifdef GGML_VULKAN_CHECK_RESULTS
-static size_t vk_skip_checks;
-static size_t vk_output_tensor;
-
-static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
-static void ggml_vk_check_results_0(ggml_tensor * tensor);
-static void ggml_vk_check_results_1(ggml_tensor * tensor);
-#endif
-
-typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
-
-static void ggml_backend_vk_free(ggml_backend_t backend);
-
-// variables to track number of compiles in progress
-static uint32_t compile_count = 0;
-static std::mutex compile_count_mutex;
-static std::condition_variable compile_count_cond;
-
-static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {
-    VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
-    GGML_ASSERT(parameter_count > 0);
-    GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
-
-    pipeline = std::make_shared<vk_pipeline_struct>();
-    pipeline->name = name;
-    pipeline->parameter_count = parameter_count;
-    pipeline->push_constant_size = push_constant_size;
-    pipeline->wg_denoms = wg_denoms;
-    pipeline->align = align;
-
-    vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
-    pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
-
-    std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
-    std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
-    for (uint32_t i = 0; i < parameter_count; i++) {
-        dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute});
-        dsl_binding_flags.push_back({});
-    }
-
-    vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags };
-
-    vk::PushConstantRange pcr(
-        vk::ShaderStageFlagBits::eCompute,
-        0,
-        pipeline->push_constant_size
-    );
-
-    vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
-        {},
-        dsl_binding);
-    descriptor_set_layout_create_info.setPNext(&dslbfci);
-    pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
-
-    vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
-    vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
-    pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
-
-    pipeline->descriptor_set_idx = 0;
-
-    vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr);
-    pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info);
-
-    std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size());
-
-    for (size_t i = 0; i < specialization_constants.size(); i++) {
-        specialization_entries[i].constantID = i;
-        specialization_entries[i].offset = i * sizeof(uint32_t);
-        specialization_entries[i].size = sizeof(uint32_t);
-    }
-
-    vk::SpecializationInfo specialization_info(
-        specialization_entries.size(),
-        specialization_entries.data(),
-        specialization_constants.size() * sizeof(uint32_t),
-        specialization_constants.data()
-    );
-
-    vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
-            vk::PipelineShaderStageCreateFlags(),
-            vk::ShaderStageFlagBits::eCompute,
-            pipeline->shader_module,
-            entrypoint.c_str(),
-            &specialization_info);
-    vk::ComputePipelineCreateInfo compute_pipeline_create_info(
-        vk::PipelineCreateFlags(),
-        pipeline_shader_create_info,
-        pipeline->layout);
-    pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
-
-    {
-        std::lock_guard<std::mutex> guard(device->mutex);
-        device->pipelines.insert({ pipeline->name, pipeline });
-    }
-
-    {
-        std::lock_guard<std::mutex> guard(compile_count_mutex);
-        assert(compile_count > 0);
-        compile_count--;
-
-        // "Progress bar" for shader compiles
-        static uint32_t total_compile_count = 0;
-        if ((total_compile_count++ % 10) == 0) {
-            std::cerr << ".";
-        }
-    }
-    compile_count_cond.notify_all();
-}
-
-static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
-    VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")");
-    for (auto& pool : pipeline->descriptor_pools) {
-        device.destroyDescriptorPool(pool);
-    }
-    pipeline->descriptor_pools.clear();
-    pipeline->descriptor_sets.clear();
-    pipeline->descriptor_set_idx = 0;
-
-    device.destroyDescriptorSetLayout(pipeline->dsl);
-
-    device.destroyPipelineLayout(pipeline->layout);
-
-    device.destroyShaderModule(pipeline->shader_module);
-
-    device.destroyPipeline(pipeline->pipeline);
-}
-
-static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) {
-    VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
-    device->pipeline_descriptor_set_requirements[pipeline->name] += n;
-}
-
-static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) {
-    std::lock_guard<std::mutex> guard(device->mutex);
-
-    for (auto& pair : device->pipeline_descriptor_set_requirements) {
-        vk_pipeline pipeline = device->pipelines.at(pair.first).lock();
-        const uint64_t n = pair.second;
-
-        VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")");
-
-        if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) {
-            // Enough descriptors are available
-            continue;
-        }
-
-        uint32_t to_alloc = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size();
-        uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - pipeline->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE;
-        uint32_t pool_idx = pipeline->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE;
-
-        while (to_alloc > 0) {
-            const uint32_t alloc_count = std::min(pool_remaining, to_alloc);
-            to_alloc -= alloc_count;
-            pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE;
-
-            if (pool_idx >= pipeline->descriptor_pools.size()) {
-                vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
-                vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
-                pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
-            }
-
-            std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
-            for (uint32_t i = 0; i < alloc_count; i++) {
-                layouts[i] = pipeline->dsl;
-            }
-            vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[pool_idx], alloc_count, layouts.data());
-            std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
-            pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end());
-
-            pool_idx++;
-        }
-    }
-}
-
-static void ggml_pipeline_cleanup(vk_pipeline& pipeline) {
-    VK_LOG_DEBUG("ggml_pipeline_cleanup(" << pipeline->name << ")");
-    pipeline->descriptor_set_idx = 0;
-}
-
-static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_queue& q) {
-    VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
-    std::lock_guard<std::mutex> guard(device->mutex);
-
-    if (q.cmd_buffers.size() > q.cmd_buffer_idx) {
-        // Reuse command buffer
-        return q.cmd_buffers[q.cmd_buffer_idx++];
-    }
-
-    vk::CommandBufferAllocateInfo command_buffer_alloc_info(
-        q.pool,
-        vk::CommandBufferLevel::ePrimary,
-        1);
-    const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
-    auto buf = cmd_buffers.front();
-
-    q.cmd_buffers.push_back(buf);
-    q.cmd_buffer_idx++;
-
-    return buf;
-}
-
-static vk_submission ggml_vk_create_submission(vk_device& device, vk_queue& q, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
-    VK_LOG_DEBUG("ggml_vk_create_submission()");
-    vk_submission s;
-    s.buffer = ggml_vk_create_cmd_buffer(device, q);
-    s.wait_semaphores = std::move(wait_semaphores);
-    s.signal_semaphores = std::move(signal_semaphores);
-    return s;
-}
-
-static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
-    if (ctx->seqs.empty()) {
-        if (fence) {
-            ctx->q->queue.submit({}, fence);
-        }
-        return;
-    }
-    VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")");
-
-    std::vector<std::vector<uint64_t>> tl_wait_vals;
-    std::vector<std::vector<uint64_t>> tl_signal_vals;
-    std::vector<std::vector<vk::Semaphore>> tl_wait_semaphores;
-    std::vector<std::vector<vk::Semaphore>> tl_signal_semaphores;
-    std::vector<vk::TimelineSemaphoreSubmitInfo> tl_submit_infos;
-    std::vector<vk::SubmitInfo> submit_infos;
-    int idx = -1;
-    std::vector<std::vector<vk::PipelineStageFlags>> stage_flags;
-
-    size_t reserve = 0;
-
-    for (const auto& sequence : ctx->seqs) {
-        reserve += sequence.size();
-    }
-
-    // Pre-reserve vectors to prevent reallocation, which invalidates pointers
-    tl_wait_semaphores.reserve(reserve);
-    tl_wait_vals.reserve(reserve);
-    tl_signal_semaphores.reserve(reserve);
-    tl_signal_vals.reserve(reserve);
-    tl_submit_infos.reserve(reserve);
-    submit_infos.reserve(reserve);
-    stage_flags.reserve(reserve);
-
-    for (const auto& sequence : ctx->seqs) {
-        for (const auto& submission : sequence) {
-            stage_flags.push_back({});
-            idx++;
-            tl_wait_vals.push_back({});
-            tl_wait_semaphores.push_back({});
-            tl_signal_vals.push_back({});
-            tl_signal_semaphores.push_back({});
-            for (size_t i = 0; i < submission.wait_semaphores.size(); i++) {
-                stage_flags[idx].push_back(ctx->q->stage_flags);
-                tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value);
-                tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s);
-            }
-            for (size_t i = 0; i < submission.signal_semaphores.size(); i++) {
-                tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value);
-                tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s);
-            }
-            tl_submit_infos.push_back({
-                (uint32_t) submission.wait_semaphores.size(),
-                tl_wait_vals[idx].data(),
-                (uint32_t) submission.signal_semaphores.size(),
-                tl_signal_vals[idx].data(),
-            });
-            tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo;
-            tl_submit_infos[idx].pNext = nullptr;
-            vk::SubmitInfo si{
-                (uint32_t) submission.wait_semaphores.size(),
-                tl_wait_semaphores[idx].data(),
-                stage_flags[idx].data(),
-                1,
-                &submission.buffer,
-                (uint32_t) submission.signal_semaphores.size(),
-                tl_signal_semaphores[idx].data(),
-            };
-            si.setPNext(&tl_submit_infos[idx]);
-            submit_infos.push_back(si);
-        }
-    }
-
-    ctx->q->queue.submit(submit_infos, fence);
-
-    ctx->seqs.clear();
-}
-
-static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyProperties>& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) {
-    VK_LOG_DEBUG("ggml_vk_find_queue_family_index()");
-    const uint32_t qfsize = queue_family_props.size();
-
-    // Try with avoid preferences first
-    for (uint32_t i = 0; i < qfsize; i++) {
-        if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) {
-            return i;
-        }
-    }
-
-    // Fall back to only required
-    for (size_t i = 0; i < qfsize; i++) {
-        if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) {
-            return i;
-        }
-    }
-
-    // Fall back to reusing compute queue
-    for (size_t i = 0; i < qfsize; i++) {
-        if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) {
-            return i;
-        }
-    }
-
-    // Fall back to ignoring min_num_queries
-    for (size_t i = 0; i < qfsize; i++) {
-        if (queue_family_props[i].queueFlags & required) {
-            return i;
-        }
-    }
-
-    // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations.
-    // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional.
-    if (compute_index >= 0) {
-        return compute_index;
-    }
-
-    std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl;
-
-    for(auto &q_family : queue_family_props) {
-        std::cerr << "Queue number: "  + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl;
-    }
-    abort();
-}
-
-static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
-    VK_LOG_DEBUG("ggml_vk_create_queue()");
-    std::lock_guard<std::mutex> guard(device->mutex);
-
-    q.queue_family_index = queue_family_index;
-    q.transfer_only = transfer_only;
-
-    vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index);
-    q.pool = device->device.createCommandPool(command_pool_create_info_compute);
-
-    q.cmd_buffer_idx = 0;
-
-    q.queue = device->device.getQueue(queue_family_index, queue_index);
-
-    q.stage_flags = stage_flags;
-}
-
-static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) {
-    vk_context result = std::make_shared<vk_context_struct>();
-    VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")");
-    ctx->gc.contexts.emplace_back(result);
-    result->q = &q;
-    return result;
-}
-
-static vk_context ggml_vk_create_temporary_context(vk_queue& q) {
-    vk_context result = std::make_shared<vk_context_struct>();
-    VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")");
-    result->q = &q;
-    return result;
-}
-
-static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) {
-    VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
-    vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 };
-    vk::SemaphoreCreateInfo ci{};
-    ci.setPNext(&tci);
-    vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
-    ctx->gc.semaphores.push_back({ semaphore, 0 });
-    return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1];
-}
-
-static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) {
-    VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
-    if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) {
-        vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
-        vk::SemaphoreCreateInfo ci{};
-        ci.setPNext(&tci);
-        vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
-        ctx->gc.tl_semaphores.push_back({ semaphore, 0 });
-    }
-    return &ctx->gc.tl_semaphores[ctx->semaphore_idx++];
-}
-
-static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {
-    if (ctx->event_idx >= ctx->gc.events.size()) {
-        ctx->gc.events.push_back(ctx->device->device.createEvent({}));
-    }
-    return ctx->gc.events[ctx->event_idx++];
-}
-
-static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) {
-    VK_LOG_DEBUG("ggml_vk_queue_cleanup()");
-    std::lock_guard<std::mutex> guard(device->mutex);
-
-    // Requires command buffers to be done
-    device->device.resetCommandPool(q.pool);
-    q.cmd_buffer_idx = 0;
-}
-
-static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
-    for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) {
-        vk::MemoryType memory_type = mem_props->memoryTypes[i];
-        if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) &&
-            (flags & memory_type.propertyFlags) == flags &&
-            mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) {
-            return static_cast<int32_t>(i);
-        }
-    }
-    return UINT32_MAX;
-}
-
-static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
-    VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")");
-    if (size > device->max_memory_allocation_size) {
-        throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit");
-    }
-
-    std::lock_guard<std::mutex> guard(device->mutex);
-
-    vk_buffer buf = std::make_shared<vk_buffer_struct>();
-
-    if (size == 0) {
-        buf->size = 0;
-        return buf;
-    }
-
-    vk::BufferCreateInfo buffer_create_info{
-        vk::BufferCreateFlags(),
-        size,
-        vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst,
-        vk::SharingMode::eExclusive,
-        0,
-        nullptr,
-    };
-
-    buf->buffer = device->device.createBuffer(buffer_create_info);
-
-    vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
-
-    vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
-
-    uint32_t memory_type_index = UINT32_MAX;
-
-    memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
-    buf->memory_property_flags = req_flags;
-
-    if (memory_type_index == UINT32_MAX && fallback_flags) {
-        memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
-        buf->memory_property_flags = fallback_flags;
-    }
-
-    if (memory_type_index == UINT32_MAX) {
-        device->device.destroyBuffer(buf->buffer);
-        throw vk::OutOfDeviceMemoryError("No suitable memory type found");
-    }
-
-    try {
-        buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
-    } catch (const vk::SystemError& e) {
-        if (buf->memory_property_flags != fallback_flags) {
-            // Try again with fallback flags
-            memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
-            buf->memory_property_flags = fallback_flags;
-
-            try {
-                buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
-            }
-            catch (const vk::SystemError& e) {
-                device->device.destroyBuffer(buf->buffer);
-                throw e;
-            }
-        } else {
-            // Out of Host/Device memory, clean up buffer
-            device->device.destroyBuffer(buf->buffer);
-            throw e;
-        }
-    }
-    buf->ptr = nullptr;
-
-    if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
-        buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
-    }
-
-    device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
-
-    buf->device = device;
-    buf->size = size;
-
-#ifdef GGML_VULKAN_MEMORY_DEBUG
-    device->memory_logger->log_allocation(buf, size);
-#endif
-
-    return buf;
-}
-
-static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
-    try {
-        return ggml_vk_create_buffer(device, size, req_flags, fallback_flags);
-    } catch (const vk::SystemError& e) {
-        std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
-        std::cerr << "ggml_vulkan: " << e.what() << std::endl;
-        throw e;
-    }
-}
-
-static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
-    vk_buffer buf;
-    try {
-        if (device->uma) {
-            // Fall back to host memory type
-            buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
-        } else {
-            // use rebar if available, otherwise fallback to device only visible memory
-            buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
-        }
-    } catch (const vk::SystemError& e) {
-        std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
-        std::cerr << "ggml_vulkan: " << e.what() << std::endl;
-        throw e;
-    }
-
-    return buf;
-}
-
-static void ggml_vk_destroy_buffer(vk_buffer& buf) {
-    if (buf == nullptr) {
-        return;
-    }
-
-#ifdef GGML_VULKAN_MEMORY_DEBUG
-    if (buf->device != nullptr) {
-        buf->device->memory_logger->log_deallocation(buf);
-    }
-#endif
-
-    buf.reset();
-}
-
-static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
-    return { buf, 0, VK_WHOLE_SIZE };
-}
-
-static void ggml_vk_sync_buffers(vk_context& ctx) {
-    VK_LOG_DEBUG("ggml_vk_sync_buffers()");
-
-    const bool transfer_queue = ctx->q->transfer_only;
-
-    ctx->s->buffer.pipelineBarrier(
-        ctx->q->stage_flags,
-        ctx->q->stage_flags,
-        {},
-        { {
-          { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },
-          { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }
-        } },
-        {},
-        {}
-    );
-}
-
-static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) {
-    VK_LOG_DEBUG("ggml_vk_wait_events()");
-    if (events.empty()) {
-        return;
-    }
-
-    ctx->s->buffer.waitEvents(
-        events,
-        ctx->q->stage_flags,
-        ctx->q->stage_flags,
-        {},
-        {},
-        {}
-    );
-}
-
-static void ggml_vk_load_shaders(vk_device& device) {
-    VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
-
-    std::cerr << "ggml_vulkan: Compiling shaders";
-
-    // mulmat
-    std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
-    std::initializer_list<uint32_t> warptile_m = { 128,  64,  64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
-    std::initializer_list<uint32_t> warptile_s = { std::max(device->subgroup_size, 16u),  32,  32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
-
-    std::initializer_list<uint32_t> warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
-    std::initializer_list<uint32_t> warptile_mmq_m = { 128,  64,  64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
-    std::initializer_list<uint32_t> warptile_mmq_s = { std::max(device->subgroup_size, 16u),  32,  32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
-
-    std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
-    std::array<uint32_t, 3> m_wg_denoms = { 64,  64, 1 };
-    std::array<uint32_t, 3> s_wg_denoms = { 32,  32, 1 };
-
-    uint32_t l_align = 128;
-    uint32_t m_align =  64;
-    uint32_t s_align =  32;
-
-    device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_matmul_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_matmul_f16 = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
-
-    device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_matmul_id_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_matmul_id_f16 = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
-
-    std::vector<std::future<void>> compiles;
-    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
-        {
-            // wait until fewer than N compiles are in progress
-            uint32_t N = std::max(1u, std::thread::hardware_concurrency());
-            std::unique_lock<std::mutex> guard(compile_count_mutex);
-            while (compile_count >= N) {
-                compile_count_cond.wait(guard);
-            }
-            compile_count++;
-        }
-        compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));
-    };
-
-    if (device->fp16) {
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->l, "matmul_f32_f16_l", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->m, "matmul_f32_f16_m", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->s, "matmul_f32_f16_s", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_l, "matmul_f32_f16_aligned_l", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_m, "matmul_f32_f16_aligned_m", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_s, "matmul_f32_f16_aligned_s", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->l, "matmul_iq4_nl_f32_l", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->m, "matmul_iq4_nl_f32_m", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->s, "matmul_iq4_nl_f32_s", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_l, "matmul_iq4_nl_f32_aligned_l", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_m, "matmul_iq4_nl_f32_aligned_m", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_s, "matmul_iq4_nl_f32_aligned_s", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_1_f32_l", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_1_f32_m", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_1_f32_s", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_1_f32_aligned_l", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_1_f32_aligned_m", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_1_f32_aligned_s", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->l, "matmul_id_iq4_nl_f32_l", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->m, "matmul_id_iq4_nl_f32_m", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->s, "matmul_id_iq4_nl_f32_s", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_l, "matmul_id_iq4_nl_f32_aligned_l", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_m, "matmul_id_iq4_nl_f32_aligned_m", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_s, "matmul_id_iq4_nl_f32_aligned_s", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-    } else {
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->l, "matmul_f32_f16_l", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->m, "matmul_f32_f16_m", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->s, "matmul_f32_f16_s", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_l, "matmul_f32_f16_aligned_l", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_m, "matmul_f32_f16_aligned_m", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_s, "matmul_f32_f16_aligned_s", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->l, "matmul_iq4_nl_f32_l", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->m, "matmul_iq4_nl_f32_m", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->s, "matmul_iq4_nl_f32_s", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_l, "matmul_iq4_nl_f32_aligned_l", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_m, "matmul_iq4_nl_f32_aligned_m", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_s, "matmul_iq4_nl_f32_aligned_s", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_1_f32_l", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_1_f32_m", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_1_f32_s", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_1_f32_aligned_l", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_1_f32_aligned_m", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_1_f32_aligned_s", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->l, "matmul_id_iq4_nl_f32_l", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->m, "matmul_id_iq4_nl_f32_m", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->s, "matmul_id_iq4_nl_f32_s", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_l, "matmul_id_iq4_nl_f32_aligned_l", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_m, "matmul_id_iq4_nl_f32_aligned_m", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_s, "matmul_id_iq4_nl_f32_aligned_s", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-    }
-
-    // mul mat vec
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32",  mul_mat_vec_f32_f32_f32_len,  mul_mat_vec_f32_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32",  mul_mat_vec_f16_f32_f32_len,  mul_mat_vec_f16_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32",  mul_mat_vec_f32_f16_f32_len,  mul_mat_vec_f32_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32",  mul_mat_vec_f16_f16_f32_len,  mul_mat_vec_f16_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32",  mul_mat_vec_id_f32_f32_len,  mul_mat_vec_id_f32_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32",  mul_mat_vec_id_f16_f32_len,  mul_mat_vec_id_f16_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-
-    // dequant shaders
-    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16",   dequant_f32_len,  dequant_f32_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
-
-    // get_rows
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32",  get_rows_f32_len,  get_rows_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16",  get_rows_f16_len,  get_rows_f16_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32",  get_rows_f32_f32_len,  get_rows_f32_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32",  get_rows_f16_f32_len,  get_rows_f16_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
-
-    for (auto &c : compiles) {
-        c.wait();
-    }
-    std::cerr << "Done!" << std::endl;
-}
-
-static vk_device ggml_vk_get_device(size_t idx) {
-    VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
-
-    if (vk_instance.devices[idx] == nullptr) {
-        VK_LOG_DEBUG("Initializing new vk_device");
-        vk_device device = std::make_shared<vk_device_struct>();
-        vk_instance.devices[idx] = device;
-
-#ifdef GGML_VULKAN_MEMORY_DEBUG
-        device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());
-#endif
-#ifdef GGML_VULKAN_PERF
-        device->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
-#endif
-
-        size_t dev_num = vk_instance.device_indices[idx];
-
-        std::vector<vk::PhysicalDevice> physical_devices = vk_instance.instance.enumeratePhysicalDevices();
-
-        if (dev_num >= physical_devices.size()) {
-            std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
-            throw std::runtime_error("Device not found");
-        }
-
-        device->physical_device = physical_devices[dev_num];
-        const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
-
-        bool maintenance4_support = false;
-
-        // Check if maintenance4 is supported
-        for (const auto& properties : ext_props) {
-            if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
-                maintenance4_support = true;
-            }
-        }
-
-        vk::PhysicalDeviceProperties2 props2;
-        vk::PhysicalDeviceMaintenance3Properties props3;
-        vk::PhysicalDeviceMaintenance4Properties props4;
-        vk::PhysicalDeviceSubgroupProperties subgroup_props;
-        props2.pNext = &props3;
-        props3.pNext = &subgroup_props;
-        if (maintenance4_support) {
-            subgroup_props.pNext = &props4;
-        }
-        device->physical_device.getProperties2(&props2);
-        device->properties = props2.properties;
-
-        const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
-
-        if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
-            device->max_memory_allocation_size = std::stoi(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
-        } else if (maintenance4_support) {
-            device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
-        } else {
-            device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
-        }
-
-        device->vendor_id = device->properties.vendorID;
-        device->subgroup_size = subgroup_props.subgroupSize;
-        device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
-
-        bool fp16_storage = false;
-        bool fp16_compute = false;
-
-        for (const auto& properties : ext_props) {
-            if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
-                fp16_storage = true;
-            } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
-                fp16_compute = true;
-            }
-        }
-
-        const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
-        const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
-
-        device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
-
-        std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
-
-        // Try to find a non-graphics compute queue and transfer-focused queues
-        const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
-        const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
-
-        const float priorities[] = { 1.0f, 1.0f };
-        device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
-
-        std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
-        if (compute_queue_family_index != transfer_queue_family_index) {
-            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
-            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
-        } else if(!device->single_queue) {
-            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
-        } else {
-            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
-        }
-        vk::DeviceCreateInfo device_create_info;
-        std::vector<const char *> device_extensions;
-        vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
-
-        VkPhysicalDeviceFeatures2 device_features2;
-        device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
-        device_features2.pNext = nullptr;
-        device_features2.features = (VkPhysicalDeviceFeatures)device_features;
-
-        VkPhysicalDeviceVulkan11Features vk11_features;
-        vk11_features.pNext = nullptr;
-        vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
-        device_features2.pNext = &vk11_features;
-
-        VkPhysicalDeviceVulkan12Features vk12_features;
-        vk12_features.pNext = nullptr;
-        vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
-        vk11_features.pNext = &vk12_features;
-
-        vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
-
-        device->fp16 = device->fp16 && vk12_features.shaderFloat16;
-
-        if (!vk11_features.storageBuffer16BitAccess) {
-            std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
-            throw std::runtime_error("Unsupported device");
-        }
-
-        device_extensions.push_back("VK_KHR_16bit_storage");
-
-#ifdef GGML_VULKAN_VALIDATE
-        device_extensions.push_back("VK_KHR_shader_non_semantic_info");
-#endif
-
-        if (device->fp16) {
-            device_extensions.push_back("VK_KHR_shader_float16_int8");
-        }
-        device->name = GGML_VK_NAME + std::to_string(idx);
-
-        device_create_info = {
-            vk::DeviceCreateFlags(),
-            device_queue_create_infos,
-            {},
-            device_extensions
-        };
-        device_create_info.setPNext(&device_features2);
-        device->device = device->physical_device.createDevice(device_create_info);
-
-        // Queues
-        ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
-
-        // Shaders
-        ggml_vk_load_shaders(device);
-
-        if (!device->single_queue) {
-            const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
-            ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
-        } else {
-            // TODO: Use pointer or reference to avoid copy
-            device->transfer_queue = device->compute_queue;
-        }
-
-        device->buffer_type = {
-            /* .iface    = */ ggml_backend_vk_buffer_type_interface,
-            /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx),
-            /* .context  = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
-        };
-
-        device->fence = device->device.createFence({});
-
-        device->idx = idx;
-
-        return device;
-    }
-
-    return vk_instance.devices[idx];
-}
-
-
-static void ggml_vk_print_gpu_info(size_t idx) {
-    GGML_ASSERT(idx < vk_instance.device_indices.size());
-    size_t dev_num = vk_instance.device_indices[idx];
-    VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")");
-    GGML_ASSERT(vk_instance_initialized);
-
-    std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
-
-    if (dev_num >= devices.size()) {
-        std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
-        throw std::runtime_error("Device not found");
-    }
-
-    vk::PhysicalDevice physical_device = devices[dev_num];
-    std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
-
-    vk::PhysicalDeviceProperties2 props2;
-    vk::PhysicalDeviceMaintenance3Properties props3;
-    vk::PhysicalDeviceSubgroupProperties subgroup_props;
-    vk::PhysicalDeviceDriverProperties driver_props;
-    props2.pNext = &props3;
-    props3.pNext = &subgroup_props;
-    subgroup_props.pNext = &driver_props;
-    physical_device.getProperties2(&props2);
-
-    const size_t subgroup_size = subgroup_props.subgroupSize;
-    const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
-
-    bool fp16_storage = false;
-    bool fp16_compute = false;
-
-    for (auto properties : ext_props) {
-        if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
-            fp16_storage = true;
-        } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
-            fp16_compute = true;
-        }
-    }
-
-    const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
-    bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
-
-    bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
-
-    vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures();
-
-    VkPhysicalDeviceFeatures2 device_features2;
-    device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
-    device_features2.pNext = nullptr;
-    device_features2.features = (VkPhysicalDeviceFeatures)device_features;
-
-    VkPhysicalDeviceVulkan11Features vk11_features;
-    vk11_features.pNext = nullptr;
-    vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
-    device_features2.pNext = &vk11_features;
-
-    VkPhysicalDeviceVulkan12Features vk12_features;
-    vk12_features.pNext = nullptr;
-    vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
-    vk11_features.pNext = &vk12_features;
-
-    vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
-
-    fp16 = fp16 && vk12_features.shaderFloat16;
-
-    std::string device_name = props2.properties.deviceName.data();
-    GGML_LOG_DEBUG("ggml_vulkan: %d = %s (%s) | uma: %d | fp16: %d | warp size: %d\n",
-              idx, device_name.c_str(), driver_props.driverName, uma, fp16, subgroup_size);
-
-    if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
-        std::cerr << "ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want." << std::endl;
-    }
-}
-
-static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
-static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
-
-void ggml_vk_instance_init() {
-    if (vk_instance_initialized) {
-        return;
-    }
-    VK_LOG_DEBUG("ggml_vk_instance_init()");
-
-    vk_instance_initialized = true;
-
-    vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
-
-    const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
-    const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
-#ifdef __APPLE__
-    const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
-#endif
-
-    std::vector<const char*> layers;
-
-    if (validation_ext) {
-        layers.push_back("VK_LAYER_KHRONOS_validation");
-    }
-    std::vector<const char*> extensions;
-    if (validation_ext) {
-        extensions.push_back("VK_EXT_validation_features");
-    }
-#ifdef __APPLE__
-    if (portability_enumeration_ext) {
-        extensions.push_back("VK_KHR_portability_enumeration");
-    }
-#endif
-    vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
-#ifdef __APPLE__
-    if (portability_enumeration_ext) {
-        instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
-    }
-#endif
-
-    std::vector<vk::ValidationFeatureEnableEXT> features_enable;
-    vk::ValidationFeaturesEXT validation_features;
-
-    if (validation_ext) {
-        features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
-        validation_features = {
-            features_enable,
-            {},
-        };
-        validation_features.setPNext(nullptr);
-        instance_create_info.setPNext(&validation_features);
-        GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
-    }
-    vk_instance.instance = vk::createInstance(instance_create_info);
-
-    size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
-
-    // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
-    char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES");
-    if (devices_env != nullptr) {
-        std::string devices(devices_env);
-        std::replace(devices.begin(), devices.end(), ',', ' ');
-
-        std::stringstream ss(devices);
-        size_t tmp;
-        while (ss >> tmp) {
-            if(tmp >= num_available_devices) {
-                std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl;
-                throw std::runtime_error("Invalid Vulkan device index");
-            }
-            vk_instance.device_indices.push_back(tmp);
-        }
-    } else {
-        std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
-
-        // Make sure at least one device exists
-        if (devices.empty()) {
-            std::cerr << "ggml_vulkan: Error: No devices found." << std::endl;
-            GGML_ABORT("fatal error");
-        }
-
-        // Default to using all dedicated GPUs
-        for (size_t i = 0; i < devices.size(); i++) {
-            vk::PhysicalDeviceProperties2 new_props;
-            vk::PhysicalDeviceDriverProperties new_driver;
-            vk::PhysicalDeviceIDProperties new_id;
-            new_props.pNext = &new_driver;
-            new_driver.pNext = &new_id;
-            devices[i].getProperties2(&new_props);
-
-            if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) {
-                // Check if there are two physical devices corresponding to the same GPU
-                auto old_device = std::find_if(
-                    vk_instance.device_indices.begin(),
-                    vk_instance.device_indices.end(),
-                    [&devices, &new_id](const size_t k){
-                        vk::PhysicalDeviceProperties2 old_props;
-                        vk::PhysicalDeviceIDProperties old_id;
-                        old_props.pNext = &old_id;
-                        devices[k].getProperties2(&old_props);
-                        return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
-                    }
-                );
-                if (old_device == vk_instance.device_indices.end()) {
-                    vk_instance.device_indices.push_back(i);
-                } else {
-                    // There can be two physical devices corresponding to the same GPU if there are 2 different drivers
-                    // This can cause error when splitting layers aross the devices, need to keep only 1
-                    VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID");
-
-                    vk::PhysicalDeviceProperties2 old_props;
-                    vk::PhysicalDeviceDriverProperties old_driver;
-                    old_props.pNext = &old_driver;
-                    devices[*old_device].getProperties2(&old_props);
-
-                    std::map<vk::DriverId, int> driver_priorities {};
-                    int old_priority = std::numeric_limits<int>::max();
-                    int new_priority = std::numeric_limits<int>::max();
-
-                    // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id
-                    // Smaller number -> higher priority
-                    switch (old_props.properties.vendorID) {
-                        case VK_VENDOR_ID_AMD:
-                            driver_priorities[vk::DriverId::eMesaRadv] = 1;
-                            driver_priorities[vk::DriverId::eAmdOpenSource] = 2;
-                            driver_priorities[vk::DriverId::eAmdProprietary] = 3;
-                            break;
-                        case VK_VENDOR_ID_INTEL:
-                            driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1;
-                            driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2;
-                            break;
-                        case VK_VENDOR_ID_NVIDIA:
-                            driver_priorities[vk::DriverId::eNvidiaProprietary] = 1;
-#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235
-                            driver_priorities[vk::DriverId::eMesaNvk] = 2;
-#endif
-                            break;
-                    }
-
-                    if (driver_priorities.count(old_driver.driverID)) {
-                        old_priority = driver_priorities[old_driver.driverID];
-                    }
-                    if (driver_priorities.count(new_driver.driverID)) {
-                        new_priority = driver_priorities[new_driver.driverID];
-                    }
-
-                    if (new_priority < old_priority) {
-                        auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device);
-                        vk_instance.device_indices.erase(r, vk_instance.device_indices.end());
-                        vk_instance.device_indices.push_back(i);
-
-                        VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName);
-                    }
-                    else {
-                        VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl);
-                    }
-                }
-            }
-        }
-
-        // If no dedicated GPUs found, fall back to GPU 0
-        if (vk_instance.device_indices.empty()) {
-            vk_instance.device_indices.push_back(0);
-        }
-    }
-    GGML_LOG_DEBUG("ggml_vulkan: Found %d Vulkan devices:\n", vk_instance.device_indices.size());
-
-
-    for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
-        ggml_vk_print_gpu_info(i);
-    }
-}
-
-static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
-    VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")");
-    ggml_vk_instance_init();
-    GGML_ASSERT(idx < vk_instance.device_indices.size());
-
-    ctx->name = GGML_VK_NAME + std::to_string(idx);
-
-    ctx->device = ggml_vk_get_device(idx);
-
-    ctx->semaphore_idx = 0;
-    ctx->event_idx = 0;
-
-    ctx->prealloc_size_x = 0;
-    ctx->prealloc_size_y = 0;
-    ctx->prealloc_size_split_k = 0;
-
-    ctx->fence = ctx->device->device.createFence({});
-
-#ifdef GGML_VULKAN_CHECK_RESULTS
-    const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
-    vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks));
-    const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR");
-    vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor));
-#endif
-}
-
-static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
-    VK_LOG_DEBUG("ggml_vk_get_to_fp16()");
-    switch (type) {
-        case GGML_TYPE_F32:
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_IQ4_NL:
-            break;
-        default:
-            return nullptr;
-    }
-
-    return ctx->device->pipeline_dequant[type];
-}
-
-static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
-    VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
-    if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
-        return ctx->device->pipeline_matmul_f32;
-    }
-    if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
-        return ctx->device->pipeline_matmul_f32_f16;
-    }
-    if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
-        return ctx->device->pipeline_matmul_f16_f32;
-    }
-    if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
-        return ctx->device->pipeline_matmul_f16;
-    }
-
-    if (src1_type != GGML_TYPE_F32) {
-        return nullptr;
-    }
-
-    switch (src0_type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_IQ4_NL:
-            break;
-        default:
-            return nullptr;
-    }
-
-    return ctx->device->pipeline_dequant_mul_mat_mat[src0_type];
-}
-
-static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
-    VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
-    GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16);
-
-    switch (a_type) {
-        case GGML_TYPE_F32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_IQ4_NL:
-            break;
-        default:
-            return nullptr;
-    }
-
-    return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
-}
-
-static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
-    VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
-    if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
-        return ctx->device->pipeline_matmul_id_f32;
-    }
-    if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
-        return ctx->device->pipeline_matmul_id_f16_f32;
-    }
-    if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
-        return ctx->device->pipeline_matmul_id_f16;
-    }
-
-    GGML_ASSERT(src1_type == GGML_TYPE_F32);
-
-    switch (src0_type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_IQ4_NL:
-            break;
-        default:
-            return nullptr;
-    }
-
-    return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
-}
-
-static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
-    VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
-    GGML_ASSERT(b_type == GGML_TYPE_F32);
-
-    switch (a_type) {
-        case GGML_TYPE_F32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_IQ4_NL:
-            break;
-        default:
-            return nullptr;
-    }
-
-    return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
-}
-
-static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
-    VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")");
-    VK_LOG_MEMORY("ggml_vk_pool_malloc");
-
-    int best_i = -1;
-    size_t best_size = std::numeric_limits<size_t>::max(); //smallest unused buffer that fits our needs
-    int worst_i = -1;
-    size_t worst_size = 0; //largest unused buffer seen so far
-    for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
-        vk_buffer &b = ctx->buffer_pool[i];
-        if (b != nullptr && b->size >= size && b->size < best_size) {
-            best_i = i;
-            best_size = b->size;
-        }
-        if (b != nullptr && b->size > worst_size) {
-            worst_i = i;
-            worst_size = b->size;
-        }
-    }
-    if(best_i != -1) {
-        //found the smallest buffer that fits our needs
-        vk_buffer b = ctx->buffer_pool[best_i];
-        ctx->buffer_pool[best_i].reset();
-        return b;
-    }
-    if(worst_i != -1) {
-        //no buffer that fits our needs, resize largest one to save memory
-        vk_buffer& b = ctx->buffer_pool[worst_i];
-        ggml_vk_destroy_buffer(b);
-    }
-
-    return ggml_vk_create_buffer_device(ctx->device, size);
-}
-
-static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) {
-    VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")");
-    for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
-        vk_buffer& b = ctx->buffer_pool[i];
-        if (b == nullptr) {
-            b = buffer;
-            return;
-        }
-    }
-    std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl;
-    ggml_vk_destroy_buffer(buffer);
-}
-
-// Returns an available temporary buffer that may only be used temporarily, it will be reused
-static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) {
-    // Try to find existing temp buffer with enough capacity
-    for (auto& buffer : ctx->gc.temp_buffers) {
-        if (buffer->size >= size) {
-            return buffer;
-        }
-    }
-
-    VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")");
-
-    // Otherwise create new buffer
-    vk_buffer buf = ggml_vk_pool_malloc(ctx, size);
-    ctx->gc.temp_buffers.push_back(buf);
-
-    return buf;
-}
-
-static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
-    VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
-    vk_buffer buf = ggml_vk_create_buffer(device, size,
-        vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
-        vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
-
-    if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
-        fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
-            size/1024.0/1024.0);
-        device->device.freeMemory(buf->device_memory);
-        device->device.destroyBuffer(buf->buffer);
-        return nullptr;
-    }
-
-    device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
-
-    return buf->ptr;
-}
-
-static void ggml_vk_host_free(vk_device& device, void* ptr) {
-    if (ptr == nullptr) {
-        return;
-    }
-    VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
-    vk_buffer buf;
-    size_t index;
-    for (size_t i = 0; i < device->pinned_memory.size(); i++) {
-        const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
-        const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
-        if (ptr >= addr && ptr < endr) {
-            buf = std::get<2>(device->pinned_memory[i]);
-            index = i;
-            break;
-        }
-    }
-    if (buf == nullptr) {
-        fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n");
-        return;
-    }
-
-    ggml_vk_destroy_buffer(buf);
-
-    device->pinned_memory.erase(device->pinned_memory.begin() + index);
-}
-
-static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
-    buf = nullptr;
-    buf_offset = 0;
-    for (size_t i = 0; i < device->pinned_memory.size(); i++) {
-        const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
-        const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
-        if (ptr >= addr && ptr < endr) {
-            buf = std::get<2>(device->pinned_memory[i]);
-            buf_offset = ((const uint8_t *)ptr) - addr;
-            break;
-        }
-    }
-}
-
-static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bool one_time = true) {
-    vk_submission s;
-    s.buffer = ggml_vk_create_cmd_buffer(device, q);
-    if (one_time) {
-        s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
-    } else {
-        s.buffer.begin({ vk::CommandBufferUsageFlags{} });
-    }
-
-    return s;
-}
-
-
-
-static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
-    const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
-    const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
-    const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
-    VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {";
-    for (auto& buffer : descriptor_buffer_infos) {
-        std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
-    }
-    std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
-    GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size());
-    GGML_ASSERT(descriptor_buffer_infos.size() == pipeline->parameter_count);
-
-    vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++];
-    vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
-    ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
-
-    subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
-    subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
-    subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
-                                pipeline->layout,
-                                0,
-                                { descriptor_set },
-                                {});
-    subctx->s->buffer.dispatch(wg0, wg1, wg2);
-}
-
-static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
-    s.buffer.end();
-
-    s.wait_semaphores = std::move(wait_semaphores);
-    s.signal_semaphores = std::move(signal_semaphores);
-}
-
-static void ggml_vk_ctx_end(vk_context& ctx) {
-    VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")");
-    if (ctx->s == nullptr) {
-        return;
-    }
-
-    ctx->s->buffer.end();
-    ctx->s = nullptr;
-}
-
-static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
-    VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")");
-    if (subctx->s != nullptr) {
-        ggml_vk_ctx_end(subctx);
-    }
-
-    subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->q) });
-    subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
-}
-
-static size_t ggml_vk_align_size(size_t width, size_t align) {
-    VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")");
-    return CEIL_DIV(width, align) * align;
-}
-
-static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector<vk_staging_memcpy>* memcpys = nullptr) {
-    if (memcpys == nullptr) {
-        memcpy(dst, src, size);
-    } else {
-        memcpys->emplace_back(dst, src, size);
-    }
-}
-
-static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
-    if (device->sync_staging == nullptr || device->sync_staging->size < size) {
-        VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
-        ggml_vk_destroy_buffer(device->sync_staging);
-        device->sync_staging = ggml_vk_create_buffer_check(device, size,
-            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
-            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
-    }
-}
-
-static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) {
-    VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")");
-    GGML_ASSERT(!ggml_is_contiguous(tensor));
-    // Buffer is already mapped
-    if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
-        std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." << std::endl;
-        GGML_ABORT("fatal error");
-    }
-    // Check if src is pinned memory
-    vk_buffer buf;
-    size_t buf_offset;
-    ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset);
-
-    const uint64_t ne0 = tensor->ne[0];
-    const uint64_t ne1 = tensor->ne[1];
-    const uint64_t ne2 = tensor->ne[2];
-    const uint64_t ne3 = tensor->ne[3];
-    const uint64_t nb0 = tensor->nb[0];
-    const uint64_t nb1 = tensor->nb[1];
-    const uint64_t nb2 = tensor->nb[2];
-    const uint64_t nb3 = tensor->nb[3];
-    const ggml_type type = tensor->type;
-    const uint64_t ts = ggml_type_size(type);
-    const uint64_t bs = ggml_blck_size(type);
-
-    const uint64_t dstnb0 = ts;
-    const uint64_t dstnb1 = dstnb0*(ne0/bs);
-    const uint64_t dstnb2 = dstnb1*ne1;
-    const uint64_t dstnb3 = dstnb2*ne2;
-
-    const uint64_t ne = ggml_nelements(tensor);
-
-    if (buf != nullptr) {
-        // Memory is pinned, use as staging buffer
-        std::vector<vk::BufferCopy> slices;
-
-        for (uint64_t i3 = 0; i3 < ne3; i3++) {
-            for (uint64_t i2 = 0; i2 < ne2; i2++) {
-                // Find longest contiguous slice
-                if (ne1*nb1 == dstnb2) {
-                    slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 });
-                } else {
-                    for (uint64_t i1 = 0; i1 < ne1; i1++) {
-                        if (ne0*nb0/bs == dstnb1) {
-                            slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 });
-                        } else {
-                            const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
-                            const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
-                            for (uint64_t i0 = 0; i0 < ne0; i0++) {
-                                slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 });
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        ggml_vk_sync_buffers(subctx);
-        subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
-        return;
-    }
-
-    if (!sync_staging) {
-        GGML_ABORT("Asynchronous write to non-pinned memory not supported");
-    }
-
-    // Staging buffer required
-    vk_buffer& staging = ctx->device->sync_staging;
-    const uint64_t copy_size = ts*ne/bs;
-    ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
-    VkBufferCopy buf_copy{ 0, offset, copy_size };
-
-    ggml_vk_sync_buffers(subctx);
-    vkCmdCopyBuffer(subctx->s->buffer, staging->buffer, dst->buffer, 1, &buf_copy);
-
-    for (uint64_t i3 = 0; i3 < ne3; i3++) {
-        for (uint64_t i2 = 0; i2 < ne2; i2++) {
-            // Find longest contiguous slice
-            if (ne1*nb1 == dstnb2) {
-                deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys);
-            } else {
-                for (uint64_t i1 = 0; i1 < ne1; i1++) {
-                    if (ne0*nb0/bs == dstnb1) {
-                        deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys);
-                    } else {
-                        const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
-                        const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
-                        for (uint64_t i0 = 0; i0 < ne0; i0++) {
-                            deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys);
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
-    VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
-    // Buffer is already mapped
-    if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
-        std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl;
-        GGML_ABORT("fatal error");
-    }
-    // Check if src is pinned memory
-    vk_buffer buf = nullptr;
-    size_t buf_offset;
-    ggml_vk_host_get(dst->device, src, buf, buf_offset);
-
-    if (buf != nullptr) {
-        // Memory is pinned, use as staging buffer
-        std::vector<vk::BufferCopy> slices(1);
-        if (width == spitch) {
-            // Only do single write if stride is equal
-            slices[0].srcOffset = buf_offset;
-            slices[0].dstOffset = offset;
-            slices[0].size = width * height;
-        } else {
-            slices.resize(height);
-            for (size_t i = 0; i < height; i++) {
-                slices[i].srcOffset = buf_offset + i * spitch;
-                slices[i].dstOffset = offset + i * width;
-                slices[i].size = width;
-            }
-        }
-
-        ggml_vk_sync_buffers(subctx);
-        subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
-        return;
-    }
-    VK_LOG_DEBUG("STAGING");
-
-    if (!sync_staging) {
-        GGML_ABORT("Asynchronous write to non-pinned memory not supported");
-    }
-
-    // Staging buffer required
-    const size_t copy_size = width*height;
-    ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size);
-
-    vk_buffer& staging_buffer = dst->device->sync_staging;
-
-    VkBufferCopy buf_copy = {
-        0,
-        offset,
-        copy_size};
-
-    ggml_vk_sync_buffers(subctx);
-    vkCmdCopyBuffer(subctx->s->buffer, staging_buffer->buffer, dst->buffer, 1, &buf_copy);
-
-    if (width == spitch) {
-        deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
-    } else {
-        for (size_t i = 0; i < height; i++) {
-            deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
-        }
-    }
-}
-
-static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
-    VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
-    return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
-}
-
-static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) {
-    VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")");
-    // Buffer is already mapped
-    if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
-        GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
-
-        for (size_t i = 0; i < height; i++) {
-            memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
-        }
-    } else {
-        vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
-        ggml_vk_ctx_begin(dst->device, subctx);
-        ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
-        ggml_vk_ctx_end(subctx);
-
-        for (auto& cpy : subctx->in_memcpys) {
-            memcpy(cpy.dst, cpy.src, cpy.n);
-        }
-
-        ggml_vk_submit(subctx, dst->device->fence);
-        VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
-        dst->device->device.resetFences({ dst->device->fence });
-    }
-}
-
-static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) {
-    VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")");
-    ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1);
-}
-
-static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {
-    VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")");
-    GGML_ASSERT(width > 0);
-    GGML_ASSERT(height > 0);
-    GGML_ASSERT(src != nullptr);
-
-    // TODO: staging_offset is not used
-
-    // Check if dst is pinned memory
-    vk_buffer buf = nullptr;
-    size_t buf_offset;
-    ggml_vk_host_get(src->device, dst, buf, buf_offset);
-
-    std::vector<vk::BufferCopy> slices(1);
-    if (width == spitch && width == dpitch) {
-        // Only do single write if stride is equal
-        slices[0].srcOffset = offset;
-        slices[0].dstOffset = buf_offset;
-        slices[0].size = width * height;
-    } else {
-        slices.resize(height);
-        for (size_t i = 0; i < height; i++) {
-            slices[i].srcOffset = offset + i * spitch;
-            slices[i].dstOffset = buf_offset + i * dpitch;
-            slices[i].size = width;
-        }
-    }
-
-    if (buf != nullptr) {
-        // Memory is pinned, use as staging buffer
-        ggml_vk_sync_buffers(subctx);
-        subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);
-
-        return;
-    }
-    VK_LOG_DEBUG("STAGING");
-
-    if (!sync_staging) {
-        GGML_ABORT("Asynchronous read from non-pinned memory not supported");
-    }
-
-    // Fall back to staging buffer
-    const size_t copy_size = dpitch * height;
-    ggml_vk_ensure_sync_staging_buffer(src->device, copy_size);
-
-    vk_buffer& staging_buffer = src->device->sync_staging;
-
-    ggml_vk_sync_buffers(subctx);
-    subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
-
-    deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
-}
-
-static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) {
-    return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging);
-}
-
-static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
-    VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
-
-    // If the device is not an UMA device the memory is host-accessible through rebar. While writing
-    // through PCIe is sufficient fast reading back data from PCIe is slower than going through
-    // the HW device to host copy path.
-    if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) {
-        GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
-
-        memcpy(dst, (uint8_t *) src->ptr + offset, size);
-    } else {
-        vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
-        ggml_vk_ctx_begin(src->device, subctx);
-        ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true);
-        ggml_vk_ctx_end(subctx);
-
-        ggml_vk_submit(subctx, src->device->fence);
-        VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
-        src->device->device.resetFences({ src->device->fence });
-
-        for (auto& cpy : subctx->out_memcpys) {
-            memcpy(cpy.dst, cpy.src, cpy.n);
-        }
-    }
-}
-
-static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
-    VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")");
-    // Make sure both buffers are on same device
-    GGML_ASSERT(src->device == dst->device);
-
-    VkBufferCopy bc{ src_offset, dst_offset, size };
-
-    vkCmdCopyBuffer(ctx->s->buffer, src->buffer, dst->buffer, 1, &bc);
-}
-
-static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
-    if (src->device == dst->device) {
-        VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
-        // Copy within the device
-        vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
-        ggml_vk_ctx_begin(src->device, subctx);
-        ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
-        ggml_vk_ctx_end(subctx);
-        ggml_vk_submit(subctx, src->device->fence);
-        VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
-        src->device->device.resetFences({ src->device->fence });
-    } else {
-        VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
-        // Copy device to device
-        ggml_vk_ensure_sync_staging_buffer(src->device, size);
-        ggml_vk_ensure_sync_staging_buffer(dst->device, size);
-
-        // Copy to src staging buffer
-        ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
-        // memcpy to dst staging buffer
-        memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size);
-        // Copy to dst buffer
-        ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size);
-    }
-}
-
-static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
-    VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
-
-    vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
-    ggml_vk_ctx_begin(dst->device, subctx);
-    subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
-    ggml_vk_ctx_end(subctx);
-
-    ggml_vk_submit(subctx, dst->device->fence);
-    VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences");
-    dst->device->device.resetFences({ dst->device->fence });
-}
-
-static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
-    VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
-    // if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
-    //     return 4;
-    // }
-
-    return 1;
-
-    GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k);
-}
-
-static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
-    if (m <= 32 || n <= 32) {
-        return aligned ? mmp->a_s : mmp->s;
-    }
-    return aligned ? mmp->a_m : mmp->m;
-
-    GGML_UNUSED(ctx);
-}
-
-static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
-    return aligned ? mmp->a_m : mmp->m;
-
-    GGML_UNUSED(ctx);
-}
-
-static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
-    return aligned ? mmp->a_s : mmp->s;
-
-    GGML_UNUSED(ctx);
-}
-
-static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
-    switch (ctx->device->vendor_id) {
-    case VK_VENDOR_ID_AMD:
-        return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned);
-    case VK_VENDOR_ID_APPLE:
-        return ggml_vk_guess_matmul_pipeline_apple(ctx, mmp, aligned);
-    case VK_VENDOR_ID_INTEL:
-        return ggml_vk_guess_matmul_pipeline_intel(ctx, mmp, aligned);
-    default:
-        break;
-    }
-
-    if (m <= 32 || n <= 32) {
-        return aligned ? mmp->a_s : mmp->s;
-    }
-    if (m <= 64 || n <= 64) {
-        return aligned ? mmp->a_m : mmp->m;
-    }
-    return aligned ? mmp->a_l : mmp->l;
-}
-
-static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
-    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align;
-}
-
-static void ggml_vk_matmul(
-        ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
-        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
-        uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
-        uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) {
-        VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
-    ggml_vk_sync_buffers(subctx);
-    if (split_k == 1) {
-        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
-        return;
-    }
-
-    GGML_ASSERT(batch_stride_d == m * n);
-
-    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 };
-    // Make sure enough workgroups get assigned for split k to work
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
-    ggml_vk_sync_buffers(subctx);
-    const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
-}
-
-static void ggml_vk_matmul_id(
-        ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
-        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
-        uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
-        uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-        uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
-    VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
-        "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
-        "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
-        "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
-    ggml_vk_sync_buffers(subctx);
-    const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
-                                              nei0, nei1, nbi1, ne11 };
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
-}
-
-static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
-    return
-        tensor->nb[0] == ggml_type_size(tensor->type) &&
-        tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
-        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
-}
-
-static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {
-
-    // Choose "contiguous copy" shader if src/dst are contiguous
-    bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst));
-
-    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
-        if (contig) {
-            return ctx->device->pipeline_contig_cpy_f32_f32;
-        } else {
-            return ctx->device->pipeline_cpy_f32_f32;
-        }
-    }
-    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
-        if (contig) {
-            return ctx->device->pipeline_contig_cpy_f32_f16;
-        } else {
-            return ctx->device->pipeline_cpy_f32_f16;
-        }
-    }
-    if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
-        if (contig) {
-            return ctx->device->pipeline_contig_cpy_f16_f16;
-        } else {
-            return ctx->device->pipeline_cpy_f16_f16;
-        }
-    }
-
-    std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
-    GGML_ABORT("fatal error");
-}
-
-static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) {
-    VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
-    std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")");
-    const int tensor_type_size = ggml_type_size(tensor->type);
-
-    const uint32_t ne = ggml_nelements(tensor);
-    std::array<uint32_t, 3> elements;
-
-    if (ne > 262144) {
-        elements = { 512, 512, CEIL_DIV(ne, 262144) };
-    } else if (ne > 512) {
-        elements = { 512, CEIL_DIV(ne, 512), 1 };
-    } else {
-        elements = { ne, 1, 1 };
-    }
-
-    const vk_op_unary_push_constants pc = {
-        (uint32_t)ne,
-        (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
-        (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3],                       1                   , (uint32_t)tensor->ne[0]                   , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
-        0,
-        0.0f, 0.0f,
-    };
-    ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
-}
-
-static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
-    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
-    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
-    std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
-    GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);  // NOLINT
-    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
-
-    const uint64_t ne00 = src0->ne[0];
-    const uint64_t ne01 = src0->ne[1];
-    const uint64_t ne02 = src0->ne[2];
-    const uint64_t ne03 = src0->ne[3];
-
-    const uint64_t ne10 = src1->ne[0];
-    const uint64_t ne11 = src1->ne[1];
-    const uint64_t ne12 = src1->ne[2];
-    const uint64_t ne13 = src1->ne[3];
-
-    const uint64_t ne20 = dst->ne[0];
-    const uint64_t ne21 = dst->ne[1];
-
-    const uint64_t r2 = ne12 / ne02;
-    const uint64_t r3 = ne13 / ne03;
-
-    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
-    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
-
-    vk_buffer d_Qx;
-    size_t qx_buf_offset = 0;
-    vk_buffer d_Qy;
-    size_t qy_buf_offset = 0;
-
-    bool src0_uma = false;
-    bool src1_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
-        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
-        src0_uma = d_Qx != nullptr;
-        src1_uma = d_Qy != nullptr;
-    }
-
-    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
-    const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
-
-    const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
-
-    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
-
-    const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
-
-    if (qx_needs_dequant) {
-        // Fall back to dequant + f16 mulmat
-        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
-    }
-
-    // Not implemented
-    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
-
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-
-    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
-    const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
-
-    const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
-
-    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
-
-    const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
-    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
-    const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
-    const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
-    const uint64_t d_sz = sizeof(float) * d_ne;
-
-    vk_pipeline to_fp16_vk_0 = nullptr;
-    vk_pipeline to_fp16_vk_1 = nullptr;
-
-    if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
-    } else {
-        to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
-    }
-    if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
-    } else {
-        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
-    }
-    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
-    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
-
-    if (dryrun) {
-        const uint64_t x_sz_upd = x_sz * ne02 * ne03;
-        const uint64_t y_sz_upd = y_sz * ne12 * ne13;
-        const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0;
-        if (
-                (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
-                (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
-                (split_k > 1 && split_k_size > ctx->device->max_memory_allocation_size)) {
-            GGML_ABORT("Requested preallocation size is too large");
-        }
-        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
-            ctx->prealloc_size_x = x_sz_upd;
-        }
-        if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
-            ctx->prealloc_size_y = y_sz_upd;
-        }
-        if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
-            ctx->prealloc_size_split_k = split_k_size;
-        }
-
-        // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
-        if (qx_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
-        }
-        if (qy_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
-        }
-        if (split_k > 1) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
-        }
-        return;
-    }
-
-    vk_buffer d_D = dst_buf_ctx->dev_buffer;
-    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
-    GGML_ASSERT(d_D != nullptr);
-    GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03);
-    vk_buffer d_X;
-    uint64_t x_buf_offset = 0;
-    vk_buffer d_Y;
-    uint64_t y_buf_offset = 0;
-    if (!src0_uma) {
-        d_Qx = src0_buf_ctx->dev_buffer;
-        qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
-        GGML_ASSERT(d_Qx != nullptr);
-    }
-    if (!src1_uma) {
-        d_Qy = src1_buf_ctx->dev_buffer;
-        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
-        GGML_ASSERT(d_Qy != nullptr);
-    }
-    if (qx_needs_dequant) {
-        d_X = ctx->prealloc_x;
-        GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03);
-    } else {
-        d_X = d_Qx;
-        x_buf_offset = qx_buf_offset;
-        GGML_ASSERT(qx_sz == x_sz);
-    }
-    if (qy_needs_dequant) {
-        d_Y = ctx->prealloc_y;
-        GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03);
-    } else {
-        d_Y = d_Qy;
-        y_buf_offset = qy_buf_offset;
-        GGML_ASSERT(qy_sz == y_sz);
-    }
-
-    if (x_non_contig) {
-        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
-    } else if (qx_needs_dequant) {
-        const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
-        ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
-    }
-    if (y_non_contig) {
-        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
-    }
-
-    uint32_t stride_batch_x = ne00*ne01;
-    uint32_t stride_batch_y = ne10*ne11;
-
-    if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
-        stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
-    }
-
-    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
-        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
-    }
-
-    // compute
-    ggml_vk_matmul(
-        ctx, subctx, pipeline,
-        { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
-        { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
-        ne01, ne11, ne10,
-        ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
-        split_k, ne12*ne13, ne02, ne12, r2, r3
-    );  // NOLINT
-}
-
-static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
-    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
-    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
-    std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)");
-    GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);  // NOLINT
-    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
-
-    const uint64_t ne00 = src0->ne[0];
-    const uint64_t ne01 = src0->ne[1];
-    const uint64_t ne02 = src0->ne[2];
-    const uint64_t ne03 = src0->ne[3];
-
-    const uint64_t ne10 = src1->ne[0];
-    const uint64_t ne11 = src1->ne[1];
-    const uint64_t ne12 = src1->ne[2];
-    const uint64_t ne13 = src1->ne[3];
-
-    GGML_ASSERT(ne11 == 1);
-
-    const uint64_t ne20 = dst->ne[0];
-    const uint64_t ne21 = dst->ne[1];
-    const uint64_t ne22 = dst->ne[2];
-    const uint64_t ne23 = dst->ne[3];
-
-    const uint64_t r2 = ne12 / ne02;
-    const uint64_t r3 = ne13 / ne03;
-
-    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
-    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
-
-    vk_buffer d_Qx;
-    size_t qx_buf_offset = 0;
-    vk_buffer d_Qy;
-    size_t qy_buf_offset = 0;
-
-    bool src0_uma = false;
-    bool src1_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
-        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
-        src0_uma = d_Qx != nullptr;
-        src1_uma = d_Qy != nullptr;
-    }
-
-    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
-    const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
-
-    const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
-
-    const bool qx_needs_dequant = x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
-
-    // Not implemented
-    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
-
-    const uint64_t x_ne = ne01 * ne00;
-    const uint64_t y_ne = ne11 * ne10;
-    const uint64_t d_ne = ne11 * ne01;
-
-    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
-    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
-    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
-    const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
-    const uint64_t d_sz = sizeof(float) * d_ne;
-
-    vk_pipeline to_fp16_vk_0 = nullptr;
-    vk_pipeline to_fp16_vk_1 = nullptr;
-    if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
-    }
-    if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
-    } else {
-        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
-    }
-    vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type);
-    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
-    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
-    GGML_ASSERT(dmmv != nullptr);
-
-    if (dryrun) {
-        const uint64_t x_sz_upd = x_sz * ne02 * ne03;
-        const uint64_t y_sz_upd = y_sz * ne12 * ne13;
-        if (
-                (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
-                (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
-            GGML_ABORT("Requested preallocation size is too large");
-        }
-        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
-            ctx->prealloc_size_x = x_sz_upd;
-        }
-        if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
-            ctx->prealloc_size_y = y_sz_upd;
-        }
-
-        // Request descriptor sets
-        if (qx_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
-        }
-        if (qy_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
-        }
-        ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1);
-        return;
-    }
-
-    vk_buffer d_D = dst_buf_ctx->dev_buffer;
-    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
-    GGML_ASSERT(d_D != nullptr);
-    vk_buffer d_X;
-    uint64_t x_buf_offset = 0;
-    vk_buffer d_Y;
-    uint64_t y_buf_offset = 0;
-    if(!src0_uma) {
-        d_Qx = src0_buf_ctx->dev_buffer;
-        qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
-        GGML_ASSERT(d_Qx != nullptr);
-    }
-    if(!src1_uma) {
-        d_Qy = src1_buf_ctx->dev_buffer;
-        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
-        GGML_ASSERT(d_Qy != nullptr);
-    }
-    if (qx_needs_dequant) {
-        d_X = ctx->prealloc_x;
-    } else {
-        d_X = d_Qx;
-        x_buf_offset = qx_buf_offset;
-        GGML_ASSERT(qx_sz == x_sz);
-    }
-    if (qy_needs_dequant) {
-        d_Y = ctx->prealloc_y;
-    } else {
-        d_Y = d_Qy;
-        y_buf_offset = qy_buf_offset;
-        GGML_ASSERT(qy_sz == y_sz);
-    }
-
-    if (x_non_contig) {
-        GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
-        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
-    }
-    if (y_non_contig) {
-        GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
-        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
-    }
-
-    uint32_t stride_batch_x = ne00*ne01;
-    uint32_t stride_batch_y = ne10*ne11;
-
-    if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
-        stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
-    }
-
-    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
-        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
-    }
-
-    const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
-
-    uint32_t groups_x = ne01;
-    uint32_t groups_z = 1;
-
-    if (ne01 > max_groups_x) {
-        groups_z = 64;
-        groups_x /= groups_z;
-    }
-
-    // compute
-    const vk_mat_vec_push_constants pc = {
-        (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
-        (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
-    };
-    ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
-                              { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
-                              sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
-}
-
-static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
-    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
-    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
-    std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
-    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
-    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]);  // NOLINT
-    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]);  // NOLINT
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const uint64_t ne00 = src0->ne[0];
-    const uint64_t ne01 = src0->ne[1];
-    const uint64_t ne02 = src0->ne[2];
-    // const uint64_t ne03 = src0->ne[3];
-
-    const uint64_t ne10 = src1->ne[0];
-    const uint64_t ne11 = src1->ne[1];
-    const uint64_t ne12 = src1->ne[2];
-    // const uint64_t ne13 = src1->ne[3];
-
-    GGML_ASSERT(ne11 == 1);
-
-    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
-    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
-
-    vk_buffer d_Qy;
-    size_t qy_buf_offset = 0;
-
-    bool src1_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
-        src1_uma = d_Qy != nullptr;
-    }
-
-    const uint64_t x_ne = ne00 * ne01 * ne02;
-    const uint64_t y_ne = ne10 * ne11 * ne12;
-    const uint64_t d_ne = ne01 * ne11 * ne12;
-
-    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
-    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
-    const uint64_t d_sz = sizeof(float) * d_ne;
-
-    if (dryrun) {
-        // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
-        return;
-    }
-
-    vk_buffer d_D = dst_buf_ctx->dev_buffer;
-    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
-    GGML_ASSERT(d_D != nullptr);
-    vk_buffer d_Qx = src0_buf_ctx->dev_buffer;
-    const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
-    GGML_ASSERT(d_Qx != nullptr);
-    if (!src1_uma) {
-        d_Qy = src1_buf_ctx->dev_buffer;
-        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
-        GGML_ASSERT(d_Qx != nullptr);
-    }
-
-    const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
-    const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
-
-    const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
-    const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
-
-    // compute
-    const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
-    ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
-}
-
-static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
-    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
-    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
-    std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
-    GGML_ASSERT(!ggml_is_transposed(src0));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-    GGML_ASSERT(!ggml_is_permuted(src0));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const uint64_t ne00 = src0->ne[0];
-    const uint64_t ne01 = src0->ne[1];
-    const uint64_t ne02 = src0->ne[2];
-    // const uint64_t ne03 = src0->ne[3];
-
-    const uint64_t nb01 = src0->nb[1];
-    const uint64_t nb02 = src0->nb[2];
-
-    // const uint64_t ne10 = src1->ne[0];
-    const uint64_t ne11 = src1->ne[1];
-    const uint64_t ne12 = src1->ne[2];
-    // const uint64_t ne13 = src1->ne[3];
-
-    GGML_ASSERT(ne11 == 1);
-
-    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
-    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
-
-    vk_buffer d_Qy = nullptr;
-    size_t qy_buf_offset = 0;
-
-    bool src1_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
-        src1_uma = d_Qy != nullptr;
-    }
-
-    const uint64_t d_ne = ne01 * ne11 * ne12;
-
-    const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
-    const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
-
-    const uint64_t qx_sz = ggml_nbytes(src0);
-    const uint64_t qy_sz = ggml_nbytes(src1);
-    const uint64_t d_sz = sizeof(float) * d_ne;
-
-    if (dryrun) {
-        // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
-        return;
-    }
-
-    vk_buffer d_D = dst_buf_ctx->dev_buffer;
-    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
-    GGML_ASSERT(d_D != nullptr);
-    vk_buffer d_Qx = src0_buf_ctx->dev_buffer;
-    const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
-    GGML_ASSERT(d_Qx != nullptr);
-    if (!src1_uma) {
-        d_Qy = src1_buf_ctx->dev_buffer;
-        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
-        GGML_ASSERT(d_Qx != nullptr);
-    }
-
-    const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
-    const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
-
-    const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
-    const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
-
-    // compute
-    const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
-    ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
-        { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
-}
-
-static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
-    if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
-        // detect 0213 permutation, and batch size of 1
-        src0->nb[0] <= src0->nb[2] &&
-        src0->nb[2] <= src0->nb[1] &&
-        src0->nb[1] <= src0->nb[3] &&
-        src1->nb[0] <= src1->nb[2] &&
-        src1->nb[2] <= src1->nb[1] &&
-        src1->nb[1] <= src1->nb[3] &&
-        src0->ne[3] == 1 &&
-        src1->ne[3] == 1) {
-        ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
-    } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
-               !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
-        ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
-    } else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
-        ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
-    } else {
-        ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
-    }
-}
-
-static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) {
-    VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
-    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
-    std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
-    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
-    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
-    GGML_ASSERT(ids->type == GGML_TYPE_I32);
-
-    const uint64_t ne00 = src0->ne[0];
-    const uint64_t ne01 = src0->ne[1];
-    const uint64_t ne02 = src0->ne[2];
-    const uint64_t ne03 = src0->ne[3];
-
-    const uint64_t ne10 = src1->ne[0];
-    const uint64_t ne11 = src1->ne[1];
-    const uint64_t ne12 = src1->ne[2];
-    const uint64_t ne13 = src1->ne[3];
-
-    const uint64_t nei0 = ids->ne[0];
-    const uint64_t nei1 = ids->ne[1];
-    GGML_ASSERT(nei0 * nei1 <= 3072);
-
-    const uint32_t nbi1 = ids->nb[1];
-    const uint32_t nbi2 = ids->nb[2];
-
-    const uint64_t ne20 = dst->ne[0];
-    const uint64_t ne21 = dst->ne[1];
-    const uint64_t ne22 = dst->ne[2];
-    const uint64_t ne23 = dst->ne[3];
-
-    const uint64_t n_as = ne02;
-
-    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
-    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
-    ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
-
-    vk_buffer d_Qx;
-    size_t qx_buf_offset = 0;
-    vk_buffer d_Qy;
-    size_t qy_buf_offset = 0;
-    vk_buffer d_ids;
-    size_t ids_buf_offset = 0;
-
-    bool src0_uma = false;
-    bool src1_uma = false;
-    bool ids_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
-        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
-        ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
-        src0_uma = d_Qx != nullptr;
-        src1_uma = d_Qy != nullptr;
-        ids_uma = d_ids != nullptr;
-    }
-
-    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
-    const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
-
-    const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
-
-    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
-
-    const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
-
-    if (qx_needs_dequant) {
-        GGML_ABORT("fatal error");
-    }
-
-    // Not implemented
-    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
-
-    const uint64_t x_ne = ne01 * ne00;
-    const uint64_t y_ne = ne11 * ne10;
-    const uint64_t d_ne = ne21 * ne20;
-
-    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, nei1));
-    const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
-
-    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, nei1, aligned);
-
-    const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
-    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
-    const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
-    const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
-    const uint64_t ids_sz = nbi2;
-    const uint64_t d_sz = sizeof(float) * d_ne;
-
-    vk_pipeline to_fp16_vk_0 = nullptr;
-    vk_pipeline to_fp16_vk_1 = nullptr;
-
-    if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
-    } else {
-        to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
-    }
-    if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
-    } else {
-        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
-    }
-    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
-    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
-
-    if (dryrun) {
-        const uint64_t x_sz_upd = x_sz * ne02 * ne03;
-        const uint64_t y_sz_upd = y_sz * ne12 * ne13;
-        if (
-                (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
-                (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
-            GGML_ABORT("Requested preallocation size is too large");
-        }
-        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
-            ctx->prealloc_size_x = x_sz_upd;
-        }
-        if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
-            ctx->prealloc_size_y = y_sz_upd;
-        }
-
-        // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
-        if (qx_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
-        }
-        if (qy_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
-        }
-        return;
-    }
-
-    vk_buffer d_D = dst_buf_ctx->dev_buffer;
-    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
-    GGML_ASSERT(d_D != nullptr);
-    vk_buffer d_X;
-    uint64_t x_buf_offset = 0;
-    vk_buffer d_Y;
-    uint64_t y_buf_offset = 0;
-    if (!src0_uma) {
-        d_Qx = src0_buf_ctx->dev_buffer;
-        qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
-        GGML_ASSERT(d_Qx != nullptr);
-    }
-    if (!src1_uma) {
-        d_Qy = src1_buf_ctx->dev_buffer;
-        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
-        GGML_ASSERT(d_Qy != nullptr);
-    }
-    if (!ids_uma) {
-        d_ids = ids_buf_ctx->dev_buffer;
-        ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
-        GGML_ASSERT(d_ids != nullptr);
-    }
-    if (qx_needs_dequant) {
-        d_X = ctx->prealloc_x;
-        GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03);
-    } else {
-        d_X = d_Qx;
-        x_buf_offset = qx_buf_offset;
-        GGML_ASSERT(qx_sz == x_sz);
-    }
-    if (qy_needs_dequant) {
-        d_Y = ctx->prealloc_y;
-        GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03);
-    } else {
-        d_Y = d_Qy;
-        y_buf_offset = qy_buf_offset;
-        GGML_ASSERT(qy_sz == y_sz);
-    }
-
-    if (x_non_contig) {
-        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
-    } else if (qx_needs_dequant) {
-        const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
-        ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
-            { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
-    }
-    if (y_non_contig) {
-        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
-    }
-
-    uint32_t stride_batch_x = ne00*ne01;
-    uint32_t stride_batch_y = ne10*ne11;
-
-    if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
-        stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
-    }
-
-    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
-        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
-    }
-
-    // compute
-    ggml_vk_matmul_id(
-        ctx, subctx, pipeline,
-        { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
-        { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
-        ne01, ne21, ne10, ne10, ne10, ne01,
-        stride_batch_x, stride_batch_y, ne20*ne21,
-        n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11
-    );  // NOLINT
-}
-
-static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) {
-    VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
-    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
-    std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
-    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
-    std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
-    GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);  // NOLINT
-    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
-    GGML_ASSERT(ids->type == GGML_TYPE_I32);
-
-    const uint64_t ne00 = src0->ne[0];
-    const uint64_t ne01 = src0->ne[1];
-    const uint64_t ne02 = src0->ne[2];
-    const uint64_t ne03 = src0->ne[3];
-
-    const uint64_t ne10 = src1->ne[0];
-    const uint64_t ne11 = src1->ne[1];
-    const uint64_t ne12 = src1->ne[2];
-    const uint64_t ne13 = src1->ne[3];
-
-    const uint64_t nei0 = ids->ne[0];
-    const uint64_t nei1 = ids->ne[1];
-
-    const uint64_t nbi2 = ids->nb[2];
-
-    GGML_ASSERT(nei1 == 1);
-
-    const uint64_t ne20 = dst->ne[0];
-    const uint64_t ne21 = dst->ne[1];
-    const uint64_t ne22 = dst->ne[2];
-    const uint64_t ne23 = dst->ne[3];
-
-    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
-    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
-    ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
-
-    vk_buffer d_Qx;
-    size_t qx_buf_offset = 0;
-    vk_buffer d_Qy;
-    size_t qy_buf_offset = 0;
-    vk_buffer d_ids;
-    size_t ids_buf_offset = 0;
-
-    bool src0_uma = false;
-    bool src1_uma = false;
-    bool ids_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
-        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
-        ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
-        src0_uma = d_Qx != nullptr;
-        src1_uma = d_Qy != nullptr;
-        ids_uma = d_ids != nullptr;
-    }
-
-    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
-    const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
-
-    const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
-
-    const bool qx_needs_dequant = x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
-
-    // Not implemented
-    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
-
-    const uint64_t x_ne = ne01 * ne00;
-    const uint64_t y_ne = ne11 * ne10;
-    const uint64_t d_ne = ne21 * ne20;
-
-    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
-    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
-    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
-    const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
-    const uint64_t ids_sz = nbi2;
-    const uint64_t d_sz = sizeof(float) * d_ne;
-
-    vk_pipeline to_fp16_vk_0 = nullptr;
-    vk_pipeline to_fp16_vk_1 = nullptr;
-    if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
-    }
-    if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
-    } else {
-        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
-    }
-    vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
-    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
-    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
-    GGML_ASSERT(dmmv != nullptr);
-
-    if (dryrun) {
-        const uint64_t x_sz_upd = x_sz * ne02 * ne03;
-        const uint64_t y_sz_upd = y_sz * ne12 * ne13;
-        if (
-                (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
-                (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
-            GGML_ABORT("Requested preallocation size is too large");
-        }
-        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
-            ctx->prealloc_size_x = x_sz_upd;
-        }
-        if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
-            ctx->prealloc_size_y = y_sz_upd;
-        }
-
-        // Request descriptor sets
-        if (qx_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
-        }
-        if (qy_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
-        }
-        ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1);
-        return;
-    }
-
-    vk_buffer d_D = dst_buf_ctx->dev_buffer;
-    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
-    GGML_ASSERT(d_D != nullptr);
-    vk_buffer d_X;
-    uint64_t x_buf_offset = 0;
-    vk_buffer d_Y;
-    uint64_t y_buf_offset = 0;
-    if(!src0_uma) {
-        d_Qx = src0_buf_ctx->dev_buffer;
-        qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
-        GGML_ASSERT(d_Qx != nullptr);
-    }
-    if(!src1_uma) {
-        d_Qy = src1_buf_ctx->dev_buffer;
-        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
-        GGML_ASSERT(d_Qy != nullptr);
-    }
-    if(!ids_uma) {
-        d_ids = ids_buf_ctx->dev_buffer;
-        ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
-        GGML_ASSERT(d_ids != nullptr);
-    }
-    if (qx_needs_dequant) {
-        d_X = ctx->prealloc_x;
-    } else {
-        d_X = d_Qx;
-        x_buf_offset = qx_buf_offset;
-        GGML_ASSERT(qx_sz == x_sz);
-    }
-    if (qy_needs_dequant) {
-        d_Y = ctx->prealloc_y;
-    } else {
-        d_Y = d_Qy;
-        y_buf_offset = qy_buf_offset;
-        GGML_ASSERT(qy_sz == y_sz);
-    }
-
-    if (x_non_contig) {
-        GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
-        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
-    }
-    if (y_non_contig) {
-        GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
-        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
-    }
-
-    uint32_t stride_batch_y = ne10*ne11;
-
-    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
-        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
-    }
-
-    const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
-
-    uint32_t groups_x = ne01;
-    uint32_t groups_z = 1;
-
-    if (ne01 > max_groups_x) {
-        groups_z = 64;
-        groups_x /= groups_z;
-    }
-
-    // compute
-    const vk_mat_vec_id_push_constants pc = {
-        (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
-        (uint32_t)nei0, (uint32_t)ne11,
-    };
-    ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
-        { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
-        vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } },
-        sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z });
-}
-
-static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
-    VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")");
-    if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
-        ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
-    } else {
-        ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
-    }
-}
-
-static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
-    switch (op) {
-    case GGML_OP_GET_ROWS:
-        GGML_ASSERT(src1->type == GGML_TYPE_I32);
-        if (dst->type == GGML_TYPE_F16) {
-            return ctx->device->pipeline_get_rows[src0->type];
-        }
-        if (dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_get_rows_f32[src0->type];
-        }
-        return nullptr;
-    case GGML_OP_ACC:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_acc_f32;
-        }
-        return nullptr;
-    case GGML_OP_ADD:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_add_f32;
-        }
-        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
-            return ctx->device->pipeline_add_f16_f32_f16;
-        }
-        return nullptr;
-    case GGML_OP_MUL:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_mul_f32;
-        }
-        return nullptr;
-    case GGML_OP_DIV:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_div_f32;
-        }
-        return nullptr;
-    case GGML_OP_CONCAT:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_concat_f32;
-        }
-        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-            return ctx->device->pipeline_concat_f16;
-        }
-        if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
-            return ctx->device->pipeline_concat_i32;
-        }
-        return nullptr;
-    case GGML_OP_UPSCALE:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_upscale_f32;
-        }
-        return nullptr;
-    case GGML_OP_SCALE:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_scale_f32;
-        }
-        return nullptr;
-    case GGML_OP_SQR:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_sqr_f32;
-        }
-        return nullptr;
-    case GGML_OP_SIN:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_sin_f32;
-        }
-        return nullptr;
-    case GGML_OP_COS:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_cos_f32;
-        }
-        return nullptr;
-    case GGML_OP_CLAMP:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_clamp_f32;
-        }
-        return nullptr;
-    case GGML_OP_PAD:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_pad_f32;
-        }
-        return nullptr;
-    case GGML_OP_REPEAT:
-        if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
-            return ctx->device->pipeline_repeat_f32;
-        }
-        return nullptr;
-    case GGML_OP_CPY:
-    case GGML_OP_CONT:
-    case GGML_OP_DUP:
-        return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
-    case GGML_OP_NORM:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_norm_f32;
-        }
-        return nullptr;
-    case GGML_OP_GROUP_NORM:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_group_norm_f32;
-        }
-        return nullptr;
-    case GGML_OP_RMS_NORM:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_rms_norm_f32;
-        }
-        return nullptr;
-    case GGML_OP_UNARY:
-        switch (ggml_get_unary_op(dst)) {
-            case GGML_UNARY_OP_SILU:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_silu_f32;
-                }
-                break;
-            case GGML_UNARY_OP_GELU:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_gelu_f32;
-                }
-                break;
-            case GGML_UNARY_OP_GELU_QUICK:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_gelu_quick_f32;
-                }
-                break;
-            case GGML_UNARY_OP_RELU:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_relu_f32;
-                }
-                break;
-            case GGML_UNARY_OP_TANH:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_tanh_f32;
-                }
-                break;
-            default:
-                break;
-        }
-        return nullptr;
-    case GGML_OP_DIAG_MASK_INF:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_diag_mask_inf_f32;
-        }
-        return nullptr;
-    case GGML_OP_SOFT_MAX:
-        GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
-
-        if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_soft_max_f32;
-        }
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_soft_max_f32_f16;
-        }
-        return nullptr;
-    case GGML_OP_ROPE:
-        {
-            const int mode = ((const int32_t *) dst->op_params)[2];
-            const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-
-            if (is_neox) {
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_rope_neox_f32;
-                }
-                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-                    return ctx->device->pipeline_rope_neox_f16;
-                }
-            } else {
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_rope_norm_f32;
-                }
-                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-                    return ctx->device->pipeline_rope_norm_f16;
-                }
-            }
-            return nullptr;
-        }
-    case GGML_OP_ARGSORT:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
-            return ctx->device->pipeline_argsort_f32;
-        }
-        return nullptr;
-    case GGML_OP_SUM_ROWS:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_sum_rows_f32;
-        }
-        return nullptr;
-    case GGML_OP_IM2COL:
-        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_im2col_f32;
-        }
-        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
-            return ctx->device->pipeline_im2col_f32_f16;
-        }
-        return nullptr;
-    case GGML_OP_TIMESTEP_EMBEDDING:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_timestep_embedding_f32;
-        }
-        return nullptr;
-    case GGML_OP_POOL_2D:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_pool2d_f32;
-        }
-        return nullptr;
-    case GGML_OP_LEAKY_RELU:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_leaky_relu_f32;
-        }
-        return nullptr;
-    default:
-        return nullptr;
-    }
-
-    GGML_UNUSED(src2);
-}
-
-static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
-    switch (op) {
-    case GGML_OP_CPY:
-    case GGML_OP_GET_ROWS:
-    case GGML_OP_ADD:
-    case GGML_OP_MUL:
-    case GGML_OP_DIV:
-    case GGML_OP_CONCAT:
-    case GGML_OP_UPSCALE:
-    case GGML_OP_SQR:
-    case GGML_OP_SIN:
-    case GGML_OP_COS:
-    case GGML_OP_CLAMP:
-    case GGML_OP_PAD:
-    case GGML_OP_REPEAT:
-        return true;
-    default:
-        return false;
-    }
-}
-
-template<typename PC>
-static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc, bool dryrun = false) {
-    VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
-    if (src1 != nullptr) {
-        std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
-    }
-    if (src2 != nullptr) {
-        std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
-    }
-    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
-    std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")");
-    GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type))));  // NOLINT
-    GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0));  // NOLINT
-    GGML_ASSERT(dst->buffer != nullptr);
-    const uint64_t ne00 = src0->ne[0];
-    const uint64_t ne01 = src0->ne[1];
-    const uint64_t ne02 = src0->ne[2];
-    const uint64_t ne03 = src0->ne[3];
-    const uint64_t ne0 = ne00 * ne01;
-
-    const bool use_src1 = src1 != nullptr;
-    const uint64_t ne10 = use_src1 ? src1->ne[0] : 0;
-    const uint64_t ne11 = use_src1 ? src1->ne[1] : 0;
-    const uint64_t ne12 = use_src1 ? src1->ne[2] : 0;
-    const uint64_t ne13 = use_src1 ? src1->ne[3] : 0;
-    const uint64_t ne1 = ne10 * ne11;
-    // const uint64_t nb10 = use_src1 ? src1->nb[0] : 0;
-
-    const bool use_src2 = src2 != nullptr;
-    const uint64_t ne20 = use_src2 ? src2->ne[0] : 0;
-    const uint64_t ne21 = use_src2 ? src2->ne[1] : 0;
-    const uint64_t ne22 = use_src2 ? src2->ne[2] : 0;
-    const uint64_t ne23 = use_src2 ? src2->ne[3] : 0;
-    const uint64_t ne2 = ne20 * ne21;
-
-    const uint64_t ned0 = dst->ne[0];
-    const uint64_t ned1 = dst->ne[1];
-    const uint64_t ned2 = dst->ne[2];
-    const uint64_t ned3 = dst->ne[3];
-    const uint64_t ned = ned0 * ned1;
-
-    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
-
-    if (pipeline == nullptr) {
-        std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type);
-        if (src1 != nullptr) {
-            std::cerr << " and " << ggml_type_name(src1->type);
-        }
-        std::cerr << " to " << ggml_type_name(dst->type) << std::endl;
-        GGML_ABORT("fatal error");
-    }
-
-    if (dryrun) {
-        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
-        return;
-    }
-
-    const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
-
-    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
-    ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr;
-    ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr;
-
-    vk_buffer d_X = nullptr;
-    size_t x_buf_offset = 0;
-    vk_buffer d_Y = nullptr;
-    size_t y_buf_offset = 0;
-    vk_buffer d_Z = nullptr;
-    size_t z_buf_offset = 0;
-
-    bool src0_uma = false;
-    bool src1_uma = false;
-    bool src2_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset);
-        src0_uma = d_X != nullptr;
-        if (use_src1) {
-            ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset);
-            src1_uma = d_Y != nullptr;
-        }
-        if (use_src2) {
-            ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset);
-            src2_uma = d_Z != nullptr;
-        }
-    }
-
-    uint64_t x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0;
-    uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0;
-    uint64_t z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 : 0;
-    uint64_t d_sz = ggml_type_size(dst->type) * ned;
-
-    vk_buffer d_D = dst_buf_ctx->dev_buffer;
-
-    // Workaround for tiny tensor inputs on ROPE
-    if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) {
-        y_sz = VK_WHOLE_SIZE;
-    }
-
-    GGML_ASSERT(d_D != nullptr);
-    uint64_t d_buf_offset = ((vk_tensor_offset(dst) + dst->view_offs) / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
-    GGML_ASSERT(d_buf_offset == vk_tensor_offset(dst) || op == GGML_OP_CPY);  // NOLINT
-    if(!src0_uma) {
-        d_X = src0_buf_ctx->dev_buffer;
-        x_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
-        GGML_ASSERT(d_X != nullptr);
-    }
-    if (use_src1 && !src1_uma) {
-        d_Y = src1_buf_ctx->dev_buffer;
-        y_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
-        GGML_ASSERT(d_Y != nullptr);
-    }
-    if (use_src2 && !src2_uma) {
-        d_Z = src2_buf_ctx->dev_buffer;
-        z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
-        GGML_ASSERT(d_Z != nullptr);
-    }
-
-    if (op_supports_incontiguous) {
-        x_sz = ggml_nbytes(src0);
-        y_sz = use_src1 ? ggml_nbytes(src1) : 0;
-        z_sz = use_src2 ? ggml_nbytes(src2) : 0;
-        d_sz = ggml_nbytes(dst);
-
-        if (x_buf_offset + x_sz >= d_X->size) {
-            x_sz = VK_WHOLE_SIZE;
-        }
-        if (use_src1 && y_buf_offset + y_sz >= d_Y->size) {
-            y_sz = VK_WHOLE_SIZE;
-        }
-        if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
-            z_sz = VK_WHOLE_SIZE;
-        }
-        if (d_buf_offset + d_sz >= d_D->size) {
-            d_sz = VK_WHOLE_SIZE;
-        }
-    }
-
-    std::array<uint32_t, 3> elements;
-
-    // Single call if dimension 2 is contiguous
-    GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))));
-
-    switch (op) {
-    case GGML_OP_NORM:
-    case GGML_OP_RMS_NORM:
-    case GGML_OP_SOFT_MAX:
-    case GGML_OP_SUM_ROWS:
-        {
-            const uint32_t nr = ggml_nrows(src0);
-            if (nr > 262144) {
-                elements = { 512, 512, CEIL_DIV(nr, 262144) };
-            } else if (nr > 512) {
-                elements = { 512, CEIL_DIV(nr, 512), 1 };
-            } else {
-                elements = { nr, 1, 1 };
-            }
-        } break;
-    case GGML_OP_GROUP_NORM:
-        {
-            const uint32_t num_groups = dst->op_params[0];
-            elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
-        } break;
-    case GGML_OP_DIAG_MASK_INF:
-    case GGML_OP_ROPE:
-        elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
-        break;
-    case GGML_OP_GET_ROWS:
-        elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
-        break;
-    case GGML_OP_ARGSORT:
-        elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
-        break;
-    case GGML_OP_IM2COL:
-        {
-            const bool is_2D = dst->op_params[6] == 1;
-
-            const uint32_t IC = src1->ne[is_2D ? 2 : 1];
-
-            const uint32_t KH = is_2D ? src0->ne[1] : 1;
-            const uint32_t KW =         src0->ne[0];
-
-            const uint32_t OH = is_2D ? dst->ne[2] : 1;
-            const uint32_t OW =         dst->ne[1];
-
-            const uint32_t batch = src1->ne[is_2D ? 3 : 2];
-
-            elements = { OW * KW * KH, OH, batch * IC };
-        } break;
-    case GGML_OP_TIMESTEP_EMBEDDING:
-        {
-            const uint32_t dim = dst->op_params[0];
-            uint32_t half_ceil = (dim + 1) / 2;
-            elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
-        } break;
-    case GGML_OP_POOL_2D:
-        {
-            const uint32_t N = dst->ne[3];
-            const uint32_t OC = dst->ne[2];
-            const uint32_t OH = dst->ne[1];
-            const uint32_t OW = dst->ne[0];
-            elements = { N * OC * OH * OW, 1, 1};
-        } break;
-    case GGML_OP_ADD:
-    case GGML_OP_DIV:
-    case GGML_OP_MUL:
-    case GGML_OP_SCALE:
-    case GGML_OP_SQR:
-    case GGML_OP_SIN:
-    case GGML_OP_COS:
-    case GGML_OP_CLAMP:
-    case GGML_OP_PAD:
-    case GGML_OP_REPEAT:
-    case GGML_OP_CPY:
-    case GGML_OP_CONCAT:
-    case GGML_OP_UPSCALE:
-    case GGML_OP_UNARY:
-        {
-            const uint32_t ne = ggml_nelements(dst);
-            if (ne > 262144) {
-                elements = { 512, 512, CEIL_DIV(ne, 262144) };
-            } else if (ne > 512) {
-                elements = { 512, CEIL_DIV(ne, 512), 1 };
-            } else {
-                elements = { ne, 1, 1 };
-            }
-        } break;
-    default:
-        elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
-        break;
-    }
-
-    if (!op_supports_incontiguous) {
-        if (x_sz != VK_WHOLE_SIZE) {
-            x_sz *= ne02 * ne03;
-        }
-        if (use_src1 && y_sz != VK_WHOLE_SIZE) {
-            y_sz *= ne12 * ne13;
-        }
-        if (use_src2 && z_sz != VK_WHOLE_SIZE) {
-            z_sz *= ne22 * ne23;
-        }
-        if (d_sz != VK_WHOLE_SIZE) {
-            d_sz *= ned2 * ned3;
-        }
-    }
-
-    if (op == GGML_OP_SOFT_MAX) {
-        // Empty src1 is possible in soft_max, but the shader needs a buffer
-        vk_subbuffer subbuf_y;
-        if (use_src1) {
-            subbuf_y = { d_Y, y_buf_offset, y_sz };
-        } else {
-            subbuf_y = { d_X, 0, x_sz };
-        }
-
-        ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-    } else if (op == GGML_OP_ROPE) {
-        // Empty src2 is possible in rope, but the shader needs a buffer
-        vk_subbuffer subbuf_z;
-        if (use_src2) {
-            subbuf_z = { d_Z, z_buf_offset, z_sz };
-        } else {
-            subbuf_z = { d_X, 0, x_sz };
-        }
-
-        ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-    } else if (op == GGML_OP_IM2COL) {
-        // im2col uses only src1 and dst buffers
-        ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-    } else if (use_src2) {
-        ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-    } else if (use_src1) {
-        ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-    } else {
-        ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-    }
-}
-
-static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t src1_type_size = ggml_type_size(src1->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
-
-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, {
-        (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
-        0,
-        0.0f, 0.0f, 0,
-    }, dryrun);
-}
-
-static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t src1_type_size = ggml_type_size(src1->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
-    const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
-
-    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
-    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
-    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
-    int offset = dst->op_params[3] / 4; // offset in bytes
-
-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, {
-        (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] /  dst_type_size,
-        d_offset,
-        0.0f, 0.0f, offset,
-    }, dryrun);
-}
-
-static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t src1_type_size = ggml_type_size(src1->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
-
-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, {
-        (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
-        0,
-        0.0f, 0.0f, 0,
-    }, dryrun);
-}
-
-static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t src1_type_size = ggml_type_size(src1->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
-
-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, {
-        (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
-        0,
-        0.0f, 0.0f, 0,
-    }, dryrun);
-}
-
-static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t src1_type_size = ggml_type_size(src1->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
-
-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, {
-        (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
-        0,
-        0.0f, 0.0f, 0,
-    }, dryrun);
-}
-
-static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    int * op_params = (int *)dst->op_params;
-
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t src1_type_size = ggml_type_size(src1->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
-
-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, {
-        (uint32_t)ggml_nelements(dst),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
-        0,
-        0.0f, 0.0f, op_params[0],
-    }, dryrun);
-}
-
-static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-
-    const float sf0 = (float)dst->ne[0] / src0->ne[0];
-    const float sf1 = (float)dst->ne[1] / src0->ne[1];
-    const float sf2 = (float)dst->ne[2] / src0->ne[2];
-    const float sf3 = (float)dst->ne[3] / src0->ne[3];
-
-    ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
-        (uint32_t)ggml_nelements(dst), 0,
-        (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
-        sf0, sf1, sf2, sf3,
-    }, dryrun);
-}
-
-static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    float * op_params = (float *)dst->op_params;
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
-
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
-        (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
-        0,
-        op_params[0], 0.0f
-    }, dryrun);
-}
-
-static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
-
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
-        (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
-        0,
-        0.0f, 0.0f,
-    }, dryrun);
-}
-
-static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
-
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
-        (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
-        0,
-        0.0f, 0.0f,
-    }, dryrun);
-}
-
-static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
-
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
-        (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
-        0,
-        0.0f, 0.0f,
-    }, dryrun);
-}
-
-static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    float * op_params = (float *)dst->op_params;
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
-
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
-        (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
-        0,
-        op_params[0], op_params[1],
-    }, dryrun);
-}
-
-static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
-
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
-        (uint32_t)ggml_nelements(dst),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
-        0,
-        0.0f, 0.0f,
-    }, dryrun);
-}
-
-static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
-
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
-        (uint32_t)ggml_nelements(dst),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
-        0,
-        0.0f, 0.0f,
-    }, dryrun);
-}
-
-static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
-    const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
-
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
-        (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
-        d_offset,
-        0.0f, 0.0f,
-    }, dryrun);
-}
-
-static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    float * op_params = (float *)dst->op_params;
-
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
-}
-
-static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    const int * int_op_params = (const int *)dst->op_params;
-    const float * float_op_params = (const float *)dst->op_params;
-
-    const uint32_t num_groups = int_op_params[0];
-    const float eps = float_op_params[1];
-    const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
-}
-
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
-}
-
-static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
-}
-
-static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    int32_t * op_params = (int32_t *)dst->op_params;
-    ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
-}
-
-static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    float * op_params = (float *)dst->op_params;
-
-    float scale = op_params[0];
-    float max_bias = op_params[1];
-
-    const uint32_t ncols =   (uint32_t)src0->ne[0];
-    const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
-    const uint32_t nrows_y = (uint32_t)src0->ne[1];
-
-    const uint32_t n_head_kv   = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
-        ncols,
-        src1 != nullptr ? nrows_y : (uint32_t)0,
-        scale, max_bias,
-        m0, m1,
-        n_head_log2,
-    }, dryrun);
-}
-
-static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
-    const int n_dims        = ((int32_t *) dst->op_params)[1];
-    // const int mode          = ((int32_t *) dst->op_params)[2];
-    // const int n_ctx         = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig    = ((int32_t *) dst->op_params)[4];
-    const float freq_base   = ((float *)   dst->op_params)[5];
-    const float freq_scale  = ((float *)   dst->op_params)[6];
-    const float ext_factor  = ((float *)   dst->op_params)[7];
-    const float attn_factor = ((float *)   dst->op_params)[8];
-    const float beta_fast   = ((float *)   dst->op_params)[9];
-    const float beta_slow   = ((float *)   dst->op_params)[10];
-
-    float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
-        (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
-        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
-        src2 != nullptr,
-    }, dryrun);
-}
-
-static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    int32_t * op_params = (int32_t *)dst->op_params;
-
-    uint32_t ncols = src0->ne[0];
-
-    uint32_t ncols_pad = 1;
-    while (ncols_pad < ncols) {
-        ncols_pad *= 2;
-    }
-
-    GGML_ASSERT(ncols_pad <= 1024);
-
-    ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
-        ncols,
-        ncols_pad,
-        op_params[0],
-    }, dryrun);
-}
-
-static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
-}
-
-static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    const int32_t s0 = dst->op_params[0];
-    const int32_t s1 = dst->op_params[1];
-    const int32_t p0 = dst->op_params[2];
-    const int32_t p1 = dst->op_params[3];
-    const int32_t d0 = dst->op_params[4];
-    const int32_t d1 = dst->op_params[5];
-
-    const bool is_2D = dst->op_params[6] == 1;
-
-    const uint32_t IC = src1->ne[is_2D ? 2 : 1];
-    const uint32_t IH = is_2D ? src1->ne[1] : 1;
-    const uint32_t IW =         src1->ne[0];
-
-    const uint32_t KH = is_2D ? src0->ne[1] : 1;
-    const uint32_t KW =         src0->ne[0];
-
-    const uint32_t OH = is_2D ? dst->ne[2] : 1;
-    const uint32_t OW =         dst->ne[1];
-
-    const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-    const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
-
-    const uint32_t pelements = OW * KW * KH;
-
-    ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
-        batch_offset, offset_delta,
-        IC, IW, IH, OW, OH, KW, KH,
-        pelements,
-        IC * KH * KW,
-        s0, s1, p0, p1, d0, d1,
-    }, dryrun);
-}
-
-static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    const uint32_t dim = dst->op_params[0];
-    const uint32_t max_period = dst->op_params[1];
-    const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type);
-
-    ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {
-        nb1, dim, max_period,
-    }, dryrun);
-}
-
-static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
-    const int32_t k1 = dst->op_params[1];
-    const int32_t k0 = dst->op_params[2];
-    const int32_t s1 = dst->op_params[3];
-    const int32_t s0 = dst->op_params[4];
-    const int32_t p1 = dst->op_params[5];
-    const int32_t p0 = dst->op_params[6];
-
-    const uint32_t IH = src0->ne[1];
-    const uint32_t IW = src0->ne[0];
-
-    const uint32_t N = dst->ne[3];
-
-    const uint32_t OC = dst->ne[2];
-    const uint32_t OH = dst->ne[1];
-    const uint32_t OW = dst->ne[0];
-
-    const uint32_t parallel_elements = N * OC * OH * OW;
-
-    ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
-        IW, IH, OW, OH, OC,
-        parallel_elements,
-        op,
-        k0, k1, s0, s1, p0, p1,
-    }, dryrun);
-}
-
-static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    const float * op_params = (const float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
-}
-
-#ifdef GGML_VULKAN_RUN_TESTS
-static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) {
-    if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) {
-        return;
-    }
-    i0 = std::max(i0, 5);
-    i1 = std::max(i1, 5);
-    i2 = std::max(i2, 0);
-    fprintf(stderr, "         ");
-    for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
-        fprintf(stderr, "%7d ", idx1);
-    }
-    fprintf(stderr, "\n");
-    for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
-        fprintf(stderr, "%7d: ", idx0);
-        for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
-            if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) {
-                float val;
-                if (type == GGML_TYPE_F32) {
-                    val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0);
-                } else if (type == GGML_TYPE_F16) {
-                    val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0));
-                } else {
-                    GGML_ABORT("fatal error");
-                }
-                fprintf(stderr, "% 7.2f ", val);
-            } else {
-                fprintf(stderr, "        ");
-            }
-        }
-        fprintf(stderr, "\n");
-    }
-}
-
-template <typename X_TYPE, typename Y_TYPE>
-static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) {
-    VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")");
-    const size_t x_ne = m * k * batch;
-    const size_t y_ne = k * n * batch;
-    const size_t d_ne = m * n * batch;
-
-    vk_pipeline p;
-    std::string shname;
-    if (shader_size == 0) {
-        if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-            p = ctx->device->pipeline_matmul_f32->a_s;
-            shname = "F32_ALIGNED_S";
-        } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-            p = ctx->device->pipeline_matmul_f32_f16->a_s;
-            shname = "F32_F16_ALIGNED_S";
-        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-            p = ctx->device->pipeline_matmul_f16_f32->a_s;
-            shname = "F16_F32_ALIGNED_S";
-        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-            p = ctx->device->pipeline_matmul_f16->a_s;
-            shname = "F16_ALIGNED_S";
-        } else {
-            GGML_ABORT("fatal error");
-        }
-    } else if (shader_size == 1) {
-        if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-            p = ctx->device->pipeline_matmul_f32->a_m;
-            shname = "F32_ALIGNED_M";
-        } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-            p = ctx->device->pipeline_matmul_f32_f16->a_m;
-            shname = "F32_F16_ALIGNED_M";
-        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-            p = ctx->device->pipeline_matmul_f16_f32->a_m;
-            shname = "F16_F32_ALIGNED_M";
-        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-            p = ctx->device->pipeline_matmul_f16->a_m;
-            shname = "F16_ALIGNED_M";
-        } else {
-            GGML_ABORT("fatal error");
-        }
-    } else if (shader_size == 2) {
-        if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-            p = ctx->device->pipeline_matmul_f32->a_l;
-            shname = "F32_ALIGNED_L";
-        } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-            p = ctx->device->pipeline_matmul_f32_f16->a_l;
-            shname = "F32_F16_ALIGNED_L";
-        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-            p = ctx->device->pipeline_matmul_f16_f32->a_l;
-            shname = "F16_F32_ALIGNED_L";
-        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-            p = ctx->device->pipeline_matmul_f16->a_l;
-            shname = "F16_ALIGNED_L";
-        } else {
-            GGML_ABORT("fatal error");
-        }
-    } else {
-        GGML_ASSERT(0);
-    }
-
-    const size_t kpad = ggml_vk_align_size(k, p->align);
-
-    if (k != kpad) {
-        if (shader_size == 0) {
-            if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-                p = ctx->device->pipeline_matmul_f32->s;
-                shname = "F32_S";
-            } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-                p = ctx->device->pipeline_matmul_f32_f16->s;
-                shname = "F32_F16_S";
-            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-                p = ctx->device->pipeline_matmul_f16_f32->s;
-                shname = "F16_F32_S";
-            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-                p = ctx->device->pipeline_matmul_f16->s;
-                shname = "F16_S";
-            }
-        } else if (shader_size == 1) {
-            if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-                p = ctx->device->pipeline_matmul_f32->m;
-                shname = "F32_M";
-            } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-                p = ctx->device->pipeline_matmul_f32_f16->m;
-                shname = "F32_F16_M";
-            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-                p = ctx->device->pipeline_matmul_f16_f32->m;
-                shname = "F16_F32_M";
-            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-                p = ctx->device->pipeline_matmul_f16->m;
-                shname = "F16_M";
-            }
-        } else if (shader_size == 2) {
-            if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-                p = ctx->device->pipeline_matmul_f32->l;
-                shname = "F32_L";
-            } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-                p = ctx->device->pipeline_matmul_f32_f16->l;
-                shname = "F32_F16_L";
-            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-                p = ctx->device->pipeline_matmul_f16_f32->l;
-                shname = "F16_F32_L";
-            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-                p = ctx->device->pipeline_matmul_f16->l;
-                shname = "F16_L";
-            }
-        }
-    }
-
-    ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
-    if (split_k > 1) {
-        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
-
-        if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
-            // Resize buffer
-            if (ctx->prealloc_split_k != nullptr) {
-                ggml_vk_destroy_buffer(ctx->prealloc_split_k);
-            }
-            ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
-        }
-    }
-
-    ggml_pipeline_allocate_descriptor_sets(ctx->device);
-
-    vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
-    vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
-    vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
-
-    X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne);
-    Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne);
-    float* d = (float *) malloc(sizeof(float) * d_ne);
-
-    for (size_t i = 0; i < x_ne; i++) {
-        if (std::is_same<float, X_TYPE>()) {
-            x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
-        } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
-            x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
-        } else {
-            GGML_ABORT("fatal error");
-        }
-    }
-    for (size_t i = 0; i < y_ne; i++) {
-        if (std::is_same<float, Y_TYPE>()) {
-            // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
-            y[i] = (i % k == i / k) ? 1.0f : 0.0f;
-        } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
-            // y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
-            y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
-        } else {
-            GGML_ABORT("fatal error");
-        }
-    }
-
-    ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
-    ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
-
-    vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
-    for (size_t i = 0; i < num_it; i++) {
-        ggml_vk_ctx_begin(ctx->device, subctx);
-        ggml_vk_matmul(
-            ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
-            m, n, k,
-            k, k, m, k*m, k*n, m*n,
-            split_k, batch, batch, batch, 1, 1
-        );
-        ggml_vk_ctx_end(subctx);
-    }
-
-    auto begin = std::chrono::high_resolution_clock::now();
-    ggml_vk_submit(subctx, ctx->fence);
-    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
-    ctx->device->device.resetFences({ ctx->fence });
-
-    auto end = std::chrono::high_resolution_clock::now();
-    double time = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
-
-    // copy dst to host
-    ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne);
-
-    float * d_chk = (float *) malloc(sizeof(float) * d_ne);
-
-    ggml_init_params iparams = {
-        /*.mem_size   =*/ 1024*1024*1024,
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true,
-    };
-
-    ggml_context * ggml_ctx = ggml_init(iparams);
-
-    ggml_type src0_type;
-    ggml_type src1_type;
-
-    if (std::is_same<float, X_TYPE>()) {
-        src0_type = GGML_TYPE_F32;
-    } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
-        src0_type = GGML_TYPE_F16;
-    } else {
-        GGML_ABORT("fatal error");
-    }
-    if (std::is_same<float, Y_TYPE>()) {
-        src1_type = GGML_TYPE_F32;
-    } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
-        src1_type = GGML_TYPE_F16;
-    } else {
-        GGML_ABORT("fatal error");
-    }
-
-    ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, m, batch);
-    ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch);
-    ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
-
-    src0_ggml->data = x;
-    src1_ggml->data = y;
-    tensor_ggml->data = d_chk;
-
-    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
-    ggml_build_forward_expand(cgraph, tensor_ggml);
-
-    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
-
-    ggml_free(ggml_ctx);
-
-    double avg_err = 0.0;
-    int first_err_n = -1;
-    int first_err_m = -1;
-    int first_err_b = -1;
-
-    for (size_t i = 0; i < m*n*batch; i++) {
-        double err = std::fabs(d[i] - d_chk[i]);
-        avg_err += err;
-
-        if (err > 0.05f && first_err_n == -1) {
-            first_err_b = i / (m * n);
-            first_err_n = (i % (m * n)) / m;
-            first_err_m = (i % (m * n)) % m;
-        }
-    }
-
-    avg_err /= m * n;
-
-    double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
-
-    std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
-
-    if (avg_err > 0.1) {
-        std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
-        std::cerr << "Actual result: " << std::endl << std::endl;
-        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-        std::cerr << std::endl;
-        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
-        std::cerr << "Expected result: " << std::endl << std::endl;
-        ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
-        if (split_k > 1) {
-            float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
-            ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
-
-            std::cerr << "d_buf0: " << std::endl << std::endl;
-            ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
-            std::cerr << "d_buf1: " << std::endl << std::endl;
-            ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
-            std::cerr << "d_buf2: " << std::endl << std::endl;
-            ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
-            std::cerr << "d_buf3: " << std::endl << std::endl;
-            ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
-            free(split_k_buf);
-        }
-    }
-
-    free(d_chk);
-
-    ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue);
-    ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue);
-
-    ggml_vk_destroy_buffer(d_X);
-    ggml_vk_destroy_buffer(d_Y);
-    ggml_vk_destroy_buffer(d_D);
-
-    ggml_pipeline_cleanup(p);
-    ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce);
-
-    free(x);
-    free(y);
-    free(d);
-}
-
-static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
-    if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) {
-        return;
-    }
-    i0 = std::max(i0, 5);
-    i1 = std::max(i1, 5);
-    i2 = std::max(i2, 0);
-    i3 = std::max(i3, 0);
-    fprintf(stderr, "         ");
-    for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
-        fprintf(stderr, "%7d ", idx1);
-    }
-    fprintf(stderr, "\n");
-    for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
-        fprintf(stderr, "%7d: ", idx0);
-        for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
-            if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {
-                float val;
-                if (tensor->type == GGML_TYPE_F32) {
-                    val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
-                } else if (tensor->type == GGML_TYPE_F16) {
-                    val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
-                } else {
-                    GGML_ABORT("fatal error");
-                }
-                fprintf(stderr, "% 7.2f ", val);
-            } else {
-                fprintf(stderr, "        ");
-            }
-        }
-        fprintf(stderr, "\n");
-    }
-}
-
-static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) {
-    ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr);
-}
-
-static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) {
-    if (quant == GGML_TYPE_F32) {
-        memcpy(to, from, sizeof(float) * ne);
-        return;
-    }
-
-    const auto * tt = ggml_get_type_traits(quant);
-
-    ggml_to_float_t dequant_fn = tt->to_float;
-
-    dequant_fn(from, to, ne);
-}
-
-static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
-    VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")");
-    const size_t x_sz = sizeof(float) * ne;
-    const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne;
-    const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
-    float * x = (float *) malloc(x_sz);
-    void * qx = malloc(qx_sz);
-    vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
-    vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal);
-    float * x_ref = (float *) malloc(x_sz);
-    ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
-
-    for (size_t i = 0; i < ne; i++) {
-        x[i] = rand() / (float)RAND_MAX;
-    }
-
-    vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant);
-
-    ggml_vk_quantize_data(x, qx, ne, quant);
-    ggml_vk_dequantize_data(qx, x_ref, ne, quant);
-
-    ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
-
-    ggml_pipeline_allocate_descriptor_sets(ctx->device);
-
-    ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
-
-    vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
-    ggml_vk_ctx_begin(ctx->device, subctx);
-    const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
-    ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
-    ggml_vk_ctx_end(subctx);
-
-    auto begin = std::chrono::high_resolution_clock::now();
-
-    ggml_vk_submit(subctx, ctx->fence);
-    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
-    ctx->device->device.resetFences({ ctx->fence });
-
-    auto end = std::chrono::high_resolution_clock::now();
-
-    double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
-    ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16);
-
-    int first_err = -1;
-
-    double avg_err = 0.0;
-    for (size_t i = 0; i < ne; i++) {
-        double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i]));
-        avg_err += error;
-
-        if (first_err < 0 && error > 0.05) {
-            first_err = i;
-        }
-    }
-
-    avg_err /= ne;
-
-    std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl;
-
-    if (avg_err > 0.1) {
-        std::cerr << "first_error = " << first_err << std::endl;
-        std::cerr << "Actual result: " << std::endl << std::endl;
-        for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
-            std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", ";
-        }
-        std::cerr << std::endl << "Expected result: " << std::endl << std::endl;
-        for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
-            std::cerr << x_ref[i] << ", ";
-        }
-        std::cerr << std::endl;
-    }
-
-    ggml_vk_destroy_buffer(x_buf);
-    ggml_vk_destroy_buffer(qx_buf);
-
-    free(x);
-    free(qx);
-    free(x_ref);
-    free(x_chk);
-}
-
-static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
-    VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
-    const size_t x_ne = m * k * batch;
-    const size_t y_ne = k * n * batch;
-    const size_t d_ne = m * n * batch;
-
-    vk_pipeline p;
-    std::string shname;
-    if (shader_size == 0) {
-        p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_s;
-        shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
-    } else if (shader_size == 1) {
-        p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_m;
-        shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
-    } else if (shader_size == 2) {
-        p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_l;
-        shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
-    } else {
-        GGML_ASSERT(0);
-    }
-
-    const size_t kpad = ggml_vk_align_size(k, p->align);
-
-    if (k != kpad) {
-        if (shader_size == 0) {
-            p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->s;
-            shname = std::string(ggml_type_name(quant)) + "_S";
-        } else if (shader_size == 1) {
-            p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->m;
-            shname = std::string(ggml_type_name(quant)) + "_M";
-        } else if (shader_size == 2) {
-            p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->l;
-            shname = std::string(ggml_type_name(quant)) + "_L";
-        } else {
-            GGML_ASSERT(0);
-        }
-    }
-
-    const size_t x_sz = sizeof(float) * x_ne;
-    const size_t y_sz = sizeof(float) * y_ne;
-    const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
-    const size_t d_sz = sizeof(float) * d_ne;
-    float * x = (float *) malloc(x_sz);
-    float * y = (float *) malloc(y_sz);
-    void * qx = malloc(qx_sz);
-    vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
-    vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
-    vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
-    float * d = (float *) malloc(d_sz);
-    float * d_chk = (float *) malloc(d_sz);
-
-    for (size_t i = 0; i < x_ne; i++) {
-        x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
-    }
-
-    ggml_vk_quantize_data(x, qx, x_ne, quant);
-
-    for (size_t i = 0; i < y_ne; i++) {
-        // y[i] = rand() / (float)RAND_MAX;
-        y[i] = (i % k == i / k) ? 1.0f : 0.0f;
-    }
-
-    ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
-    if (split_k > 1) {
-        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
-
-        if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
-            // Resize buffer
-            if (ctx->prealloc_split_k != nullptr) {
-                ggml_vk_destroy_buffer(ctx->prealloc_split_k);
-            }
-            ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
-        }
-    }
-
-    ggml_pipeline_allocate_descriptor_sets(ctx->device);
-
-    ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
-    ggml_vk_buffer_write(y_buf, 0, y, y_sz);
-
-    vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
-    for (size_t i = 0; i < num_it; i++) {
-        ggml_vk_ctx_begin(ctx->device, subctx);
-        ggml_vk_matmul(
-            ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
-            m, n, k,
-            k, k, m, k*m, k*n, m*n,
-            split_k, batch, batch, batch, 1, 1
-        );
-        ggml_vk_ctx_end(subctx);
-    }
-
-    auto begin = std::chrono::high_resolution_clock::now();
-
-    ggml_vk_submit(subctx, ctx->fence);
-    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
-    ctx->device->device.resetFences({ ctx->fence });
-
-    auto end = std::chrono::high_resolution_clock::now();
-
-    double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
-    ggml_vk_buffer_read(d_buf, 0, d, d_sz);
-
-    ggml_init_params iparams = {
-        /*.mem_size   =*/ 1024*1024*1024,
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true,
-    };
-
-    ggml_context * ggml_ctx = ggml_init(iparams);
-
-    ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch);
-    ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch);
-    ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
-
-    src0_ggml->data = qx;
-    src1_ggml->data = y;
-    tensor_ggml->data = d_chk;
-
-    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
-    ggml_build_forward_expand(cgraph, tensor_ggml);
-
-    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
-
-    ggml_free(ggml_ctx);
-
-    double avg_err = 0.0;
-    int first_err_n = -1;
-    int first_err_m = -1;
-    int first_err_b = -1;
-
-    for (size_t i = 0; i < m*n*batch; i++) {
-        double err = std::fabs(d[i] - d_chk[i]);
-        avg_err += err;
-
-        if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
-            first_err_b = i / (m * n);
-            first_err_n = (i % (m * n)) / m;
-            first_err_m = (i % (m * n)) % m;
-        }
-    }
-
-    avg_err /= m * n;
-
-    double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
-
-    std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
-
-    if (avg_err > 0.01 || std::isnan(avg_err)) {
-        std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
-        std::cerr << "Actual result: " << std::endl << std::endl;
-        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-        std::cerr << std::endl;
-        std::cerr << "Expected result: " << std::endl << std::endl;
-        ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
-        if (split_k > 1) {
-            float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
-            ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
-
-            std::cerr << "d_buf0: " << std::endl << std::endl;
-            ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
-            std::cerr << "d_buf1: " << std::endl << std::endl;
-            ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
-            std::cerr << "d_buf2: " << std::endl << std::endl;
-            ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
-            std::cerr << "d_buf3: " << std::endl << std::endl;
-            ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-
-            free(split_k_buf);
-        }
-    }
-
-    ggml_vk_destroy_buffer(qx_buf);
-    ggml_vk_destroy_buffer(y_buf);
-    ggml_vk_destroy_buffer(d_buf);
-
-    free(x);
-    free(qx);
-    free(y);
-    free(d);
-    free(d_chk);
-}
-#endif
-
-static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
-#if defined(GGML_VULKAN_RUN_TESTS)
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_IQ4_NL);
-
-    ggml_vk_test_matmul<ggml_fp16_t, ggml_fp16_t>(ctx, 512, 512, 100, 32, 100, 1, 2);
-
-    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
-    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
-    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
-    // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
-    // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
-    // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q2_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q2_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q2_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q2_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q2_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q2_K);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q3_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q3_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q3_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q3_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q3_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q3_K);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_K);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_K);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q6_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q6_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q6_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q6_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q6_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q6_K);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_IQ4_NL);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_IQ4_NL);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_IQ4_NL);
-
-    std::cerr << std::endl;
-
-    const std::vector<size_t> vals {
-        8, 8, 8,
-        100, 46, 576,
-        623, 111, 128,
-        100, 46, 558,
-        512, 1, 256,
-        128, 110, 622,
-        511, 511, 127,
-        511, 511, 7,
-        511, 511, 17,
-        49, 49, 128,
-        128, 49, 49,
-        4096, 49, 4096,
-        11008, 49, 4096,
-        4096, 49, 11008,
-        32000, 49, 4096,
-        512, 512, 128,
-        128, 512, 512,
-        4096, 512, 4096,
-        11008, 512, 4096,
-        4096, 512, 11008,
-        32000, 512, 4096,
-    };
-    const size_t num_it = 1;
-    for (size_t i = 0; i < vals.size(); i += 3) {
-        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
-        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
-        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
-        // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
-        // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
-        // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
-        std::cerr << std::endl;
-    }
-
-    GGML_ABORT("fatal error");
-#endif
-
-    if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {
-        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")");
-        // Resize buffer
-        if (ctx->prealloc_x != nullptr) {
-            ggml_vk_destroy_buffer(ctx->prealloc_x);
-        }
-        ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x);
-    }
-    if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) {
-        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")");
-        // Resize buffer
-        if (ctx->prealloc_y != nullptr) {
-            ggml_vk_destroy_buffer(ctx->prealloc_y);
-        }
-        ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
-    }
-    if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
-        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
-        // Resize buffer
-        if (ctx->prealloc_split_k != nullptr) {
-            ggml_vk_destroy_buffer(ctx->prealloc_split_k);
-        }
-        ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
-    }
-}
-
-static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence);
-
-// Returns true if node has enqueued work into the queue, false otherwise
-// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
-static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){
-    if (ggml_is_empty(node) || !node->buffer) {
-        return false;
-    }
-
-    VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
-    ctx->semaphore_idx = 0;
-
-    const ggml_tensor * src0 = node->src[0];
-    const ggml_tensor * src1 = node->src[1];
-    const ggml_tensor * src2 = node->src[2];
-
-    switch (node->op) {
-    // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
-    case GGML_OP_RESHAPE:
-    case GGML_OP_VIEW:
-    case GGML_OP_PERMUTE:
-    case GGML_OP_TRANSPOSE:
-    case GGML_OP_NONE:
-        return false;
-    case GGML_OP_UNARY:
-        switch (ggml_get_unary_op(node)) {
-        case GGML_UNARY_OP_SILU:
-        case GGML_UNARY_OP_GELU:
-        case GGML_UNARY_OP_GELU_QUICK:
-        case GGML_UNARY_OP_RELU:
-        case GGML_UNARY_OP_TANH:
-            break;
-        default:
-            return false;
-        }
-        break;
-    case GGML_OP_REPEAT:
-    case GGML_OP_GET_ROWS:
-    case GGML_OP_ADD:
-    case GGML_OP_ACC:
-    case GGML_OP_MUL:
-    case GGML_OP_DIV:
-    case GGML_OP_CONCAT:
-    case GGML_OP_UPSCALE:
-    case GGML_OP_SCALE:
-    case GGML_OP_SQR:
-    case GGML_OP_SIN:
-    case GGML_OP_COS:
-    case GGML_OP_CLAMP:
-    case GGML_OP_PAD:
-    case GGML_OP_CPY:
-    case GGML_OP_CONT:
-    case GGML_OP_DUP:
-    case GGML_OP_NORM:
-    case GGML_OP_GROUP_NORM:
-    case GGML_OP_RMS_NORM:
-    case GGML_OP_DIAG_MASK_INF:
-    case GGML_OP_SOFT_MAX:
-    case GGML_OP_ROPE:
-    case GGML_OP_MUL_MAT:
-    case GGML_OP_MUL_MAT_ID:
-    case GGML_OP_ARGSORT:
-    case GGML_OP_SUM_ROWS:
-    case GGML_OP_IM2COL:
-    case GGML_OP_TIMESTEP_EMBEDDING:
-    case GGML_OP_POOL_2D:
-    case GGML_OP_LEAKY_RELU:
-        break;
-    default:
-        std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
-        GGML_ABORT("fatal error");
-        return false;
-    }
-
-    vk_context compute_ctx;
-
-    if (!dryrun) {
-        if (ctx->compute_ctx.expired()) {
-            compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
-            ctx->compute_ctx = compute_ctx;
-            ggml_vk_ctx_begin(ctx->device, compute_ctx);
-        } else {
-            compute_ctx = ctx->compute_ctx.lock();
-        }
-    }
-
-    switch (node->op) {
-    case GGML_OP_REPEAT:
-        ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_ACC:
-        ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun);
-
-        break;
-    case GGML_OP_GET_ROWS:
-        ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun);
-
-        break;
-    case GGML_OP_ADD:
-        ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
-
-        break;
-    case GGML_OP_MUL:
-        ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun);
-
-        break;
-    case GGML_OP_DIV:
-        ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);
-
-        break;
-    case GGML_OP_CONCAT:
-        ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);
-
-        break;
-    case GGML_OP_UPSCALE:
-        ggml_vk_upscale(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_SCALE:
-        ggml_vk_scale(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_SQR:
-        ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_SIN:
-        ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_COS:
-        ggml_vk_cos(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_CLAMP:
-        ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_PAD:
-        ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_CPY:
-    case GGML_OP_CONT:
-    case GGML_OP_DUP:
-        ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_NORM:
-        ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_GROUP_NORM:
-        ggml_vk_group_norm(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_RMS_NORM:
-        ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_UNARY:
-        switch (ggml_get_unary_op(node)) {
-        case GGML_UNARY_OP_SILU:
-        case GGML_UNARY_OP_GELU:
-        case GGML_UNARY_OP_GELU_QUICK:
-        case GGML_UNARY_OP_RELU:
-        case GGML_UNARY_OP_TANH:
-            ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
-            break;
-        default:
-            return false;
-        }
-        break;
-    case GGML_OP_DIAG_MASK_INF:
-        ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_SOFT_MAX:
-        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
-
-        break;
-    case GGML_OP_ROPE:
-        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun);
-
-        break;
-    case GGML_OP_ARGSORT:
-        ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_SUM_ROWS:
-        ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_IM2COL:
-        ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
-
-        break;
-    case GGML_OP_TIMESTEP_EMBEDDING:
-        ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_POOL_2D:
-        ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_LEAKY_RELU:
-        ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
-
-        break;
-    case GGML_OP_MUL_MAT:
-        ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun);
-
-        break;
-    case GGML_OP_MUL_MAT_ID:
-        ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
-
-        break;
-    default:
-        return false;
-    }
-
-    if (dryrun) {
-        return false;
-    }
-
-    ctx->tensor_ctxs[node_idx] = compute_ctx;
-
-#if defined(GGML_VULKAN_CHECK_RESULTS) || defined(GGML_VULKAN_PERF)
-    // Force context reset on each node so that each tensor ends up in its own context
-    // and can be run and compared to its CPU equivalent separately
-    last_node = true;
-#endif
-
-    if (submit || last_node) {
-        ggml_vk_ctx_end(compute_ctx);
-
-        // TODO probably it'd be better to pass a exit_node flag to ggml_vk_compute_forward
-        if (last_node) {
-            compute_ctx->exit_tensor_idx = node_idx_begin;
-        }
-        else {
-            compute_ctx->exit_tensor_idx = -1;
-        }
-
-        ctx->compute_ctx.reset();
-
-        bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false);
-        if (!ok) {
-            if (node->op == GGML_OP_UNARY) {
-                std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
-            }
-            else {
-                std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
-            }
-        }
-
-    }
-    return true;
-}
-
-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){
-    ggml_backend_buffer * buf = nullptr;
-
-    switch (tensor->op) {
-    case GGML_OP_ADD:
-    case GGML_OP_ACC:
-    case GGML_OP_GET_ROWS:
-    case GGML_OP_MUL:
-    case GGML_OP_DIV:
-    case GGML_OP_CONCAT:
-    case GGML_OP_UPSCALE:
-    case GGML_OP_SCALE:
-    case GGML_OP_SQR:
-    case GGML_OP_SIN:
-    case GGML_OP_COS:
-    case GGML_OP_CLAMP:
-    case GGML_OP_PAD:
-    case GGML_OP_CPY:
-    case GGML_OP_CONT:
-    case GGML_OP_DUP:
-    case GGML_OP_NORM:
-    case GGML_OP_GROUP_NORM:
-    case GGML_OP_RMS_NORM:
-    case GGML_OP_DIAG_MASK_INF:
-    case GGML_OP_SOFT_MAX:
-    case GGML_OP_ROPE:
-    case GGML_OP_RESHAPE:
-    case GGML_OP_VIEW:
-    case GGML_OP_PERMUTE:
-    case GGML_OP_TRANSPOSE:
-    case GGML_OP_NONE:
-    case GGML_OP_ARGSORT:
-    case GGML_OP_SUM_ROWS:
-    case GGML_OP_IM2COL:
-    case GGML_OP_TIMESTEP_EMBEDDING:
-    case GGML_OP_POOL_2D:
-    case GGML_OP_LEAKY_RELU:
-    case GGML_OP_REPEAT:
-        buf = tensor->buffer;
-
-        break;
-    case GGML_OP_UNARY:
-        switch (ggml_get_unary_op(tensor)) {
-        case GGML_UNARY_OP_SILU:
-        case GGML_UNARY_OP_GELU:
-        case GGML_UNARY_OP_GELU_QUICK:
-        case GGML_UNARY_OP_RELU:
-        case GGML_UNARY_OP_TANH:
-            buf = tensor->buffer;
-            break;
-        default:
-            return false;
-        }
-        break;
-    case GGML_OP_MUL_MAT:
-    case GGML_OP_MUL_MAT_ID:
-        buf = tensor->buffer;
-
-        break;
-    default:
-        return false;
-    }
-
-    if (buf == nullptr) {
-        return false;
-    }
-
-    VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
-
-    vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock();
-
-    // always wait for the GPU work to be done for the last submit
-    if (tensor_idx == subctx->exit_tensor_idx) {
-        use_fence = true;
-    }
-
-    // Only run if ctx hasn't been submitted yet
-    if (!subctx->seqs.empty()) {
-#ifdef GGML_VULKAN_CHECK_RESULTS
-        ggml_vk_check_results_0(tensor);
-        use_fence = true;
-#endif
-
-        // Do staging buffer copies
-        for (auto& cpy : subctx->in_memcpys) {
-            memcpy(cpy.dst, cpy.src, cpy.n);
-        }
-
-        ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{});
-
-        if (use_fence) {
-            VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
-
-            ctx->device->device.resetFences({ ctx->fence });
-        }
-#ifdef GGML_VULKAN_CHECK_RESULTS
-        ggml_vk_check_results_1(tensor);
-#endif
-    }
-
-    if (tensor_idx == subctx->exit_tensor_idx) {
-        // Do staging buffer copies
-        for (auto& cpy : subctx->out_memcpys) {
-            memcpy(cpy.dst, cpy.src, cpy.n);
-        }
-        subctx->in_memcpys.clear();
-        subctx->out_memcpys.clear();
-    }
-
-    return true;
-}
-
-// Clean up after graph processing is done
-static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
-    VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
-    for (auto& buffer : ctx->gc.temp_buffers) {
-        ggml_vk_pool_free(ctx, buffer);
-    }
-    ctx->gc.temp_buffers.clear();
-
-    for (auto& dsr : ctx->device->pipeline_descriptor_set_requirements) {
-        vk_pipeline_ref plr = ctx->device->pipelines[dsr.first];
-
-        if (plr.expired()) {
-            continue;
-        }
-
-        vk_pipeline pl = plr.lock();
-        ggml_pipeline_cleanup(pl);
-    }
-
-    ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue);
-    ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue);
-
-    for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
-        ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
-    }
-    ctx->gc.semaphores.clear();
-
-    for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) {
-        ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
-    }
-    ctx->gc.tl_semaphores.clear();
-    ctx->semaphore_idx = 0;
-
-    ctx->event_idx = 0;
-
-    for (auto& event : ctx->gc.events) {
-        ctx->device->device.resetEvent(event);
-    }
-
-    ctx->tensor_ctxs.clear();
-    ctx->gc.contexts.clear();
-    ctx->device->pipeline_descriptor_set_requirements.clear();
-}
-
-// Clean up on backend free
-static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
-    VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")");
-    ggml_vk_graph_cleanup(ctx);
-
-    ggml_vk_destroy_buffer(ctx->prealloc_x);
-    ggml_vk_destroy_buffer(ctx->prealloc_y);
-    ggml_vk_destroy_buffer(ctx->prealloc_split_k);
-
-    for (auto& buffer : ctx->buffer_pool) {
-        ggml_vk_destroy_buffer(buffer);
-    }
-
-    ctx->prealloc_size_x = 0;
-    ctx->prealloc_size_y = 0;
-    ctx->prealloc_size_split_k = 0;
-
-    for (auto& event : ctx->gc.events) {
-        ctx->device->device.destroyEvent(event);
-    }
-    ctx->gc.events.clear();
-
-    ctx->device->device.destroyFence(ctx->fence);
-}
-
-static int ggml_vk_get_device_count() {
-    ggml_vk_instance_init();
-
-    return vk_instance.device_indices.size();
-}
-
-static void ggml_vk_get_device_description(int device, char * description, size_t description_size) {
-    ggml_vk_instance_init();
-
-    std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
-
-    vk::PhysicalDeviceProperties props;
-    devices[device].getProperties(&props);
-
-    snprintf(description, description_size, "%s", props.deviceName.data());
-}
-
-// backend interface
-
-#define UNUSED GGML_UNUSED
-
-// device backend
-
-static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) {
-    return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name;
-}
-
-static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()");
-    ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
-    ggml_vk_destroy_buffer(ctx->dev_buffer);
-    delete ctx;
-}
-
-static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
-    return vk_ptr_base;
-
-    UNUSED(buffer);
-}
-
-static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
-    VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")");
-    if (tensor->view_src != nullptr) {
-        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
-    }
-}
-
-static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
-    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
-    vk_buffer buf = buf_ctx->dev_buffer;
-
-    ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
-}
-
-static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
-    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
-
-    vk_buffer buf = buf_ctx->dev_buffer;
-
-    ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
-}
-
-static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
-    if (ggml_backend_buffer_is_vk(src->buffer)) {
-        ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
-        ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-
-        vk_buffer src_buf = src_buf_ctx->dev_buffer;
-        vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
-
-        ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
-
-        return true;
-    }
-    return false;
-
-    UNUSED(buffer);
-}
-
-static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
-    ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
-
-    ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size);
-}
-
-static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
-    /* .free_buffer     = */ ggml_backend_vk_buffer_free_buffer,
-    /* .get_base        = */ ggml_backend_vk_buffer_get_base,
-    /* .init_tensor     = */ ggml_backend_vk_buffer_init_tensor,
-    /* .memset_tensor   = */ NULL,
-    /* .set_tensor      = */ ggml_backend_vk_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_vk_buffer_get_tensor,
-    /* .cpy_tensor      = */ ggml_backend_vk_buffer_cpy_tensor,
-    /* .clear           = */ ggml_backend_vk_buffer_clear,
-    /* .reset           = */ NULL,
-};
-
-// vk buffer type
-static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) {
-    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
-
-    return ctx->name.c_str();
-}
-
-static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")");
-    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
-
-    vk_buffer dev_buffer = nullptr;
-    try {
-        dev_buffer = ggml_vk_create_buffer_device(ctx->device, size);
-    } catch (const vk::SystemError& e) {
-        return nullptr;
-    }
-
-    ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name);
-
-    return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size);
-}
-
-static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
-    return ctx->device->properties.limits.minStorageBufferOffsetAlignment;
-}
-
-static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
-    return ctx->device->max_memory_allocation_size;
-}
-
-static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
-    return ggml_nbytes(tensor);
-
-    UNUSED(buft);
-}
-
-ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) {
-    ggml_vk_instance_init();
-
-    VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")");
-
-    vk_device dev = ggml_vk_get_device(dev_num);
-
-    return &dev->buffer_type;
-}
-
-// host buffer type
-
-static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
-    return GGML_VK_NAME "_Host";
-
-    UNUSED(buft);
-}
-
-static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) {
-    return GGML_VK_NAME "_Host";
-
-    UNUSED(buffer);
-}
-
-static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
-    ggml_vk_host_free(vk_instance.devices[0], buffer->context);
-}
-
-static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")");
-
-    size += 32;  // Behave like the CPU buffer type
-    void * ptr = nullptr;
-    try {
-        ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);
-    } catch (vk::SystemError& e) {
-        std::cerr << "ggml_vulkan: Failed to allocate pinned memory." << std::endl;
-        std::cerr << "ggml_vulkan: " << e.what() << std::endl;
-        // fallback to cpu buffer
-        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
-    }
-
-    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
-    buffer->buft = buft;
-    buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer;
-
-    return buffer;
-
-    UNUSED(buft);
-}
-
-static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment;
-
-    UNUSED(buft);
-}
-
-// Should be changed to return device-specific host buffer type
-// but that probably requires changes in llama.cpp
-ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
-    static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = {
-        /* .iface    = */ {
-            /* .get_name         = */ ggml_backend_vk_host_buffer_type_name,
-            /* .alloc_buffer     = */ ggml_backend_vk_host_buffer_type_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_vk_host_buffer_type_get_alignment,
-            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
-            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
-            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
-        },
-        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),
-        /* .context  = */ nullptr,
-    };
-
-    // Make sure device 0 is initialized
-    ggml_vk_instance_init();
-    ggml_vk_get_device(0);
-
-    return &ggml_backend_vk_buffer_type_host;
-}
-
-
-// backend
-
-static const char * ggml_backend_vk_name(ggml_backend_t backend) {
-    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-
-    return ctx->name.c_str();
-}
-
-static void ggml_backend_vk_free(ggml_backend_t backend) {
-    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-    VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")");
-
-    ggml_vk_cleanup(ctx);
-
-    delete ctx;
-    delete backend;
-}
-
-static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) {
-    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-
-    return &ctx->device->buffer_type;
-}
-
-static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")");
-    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-    GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
-
-    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
-
-    vk_context transfer_ctx;
-
-    if (ctx->transfer_ctx.expired()) {
-        // Initialize new transfer context
-        transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
-        ctx->transfer_ctx = transfer_ctx;
-        ggml_vk_ctx_begin(ctx->device, transfer_ctx);
-    } else {
-        transfer_ctx = ctx->transfer_ctx.lock();
-    }
-
-    vk_buffer buf = buf_ctx->dev_buffer;
-
-    ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
-}
-
-static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")");
-    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-    GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
-
-    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
-
-    vk_context transfer_ctx;
-
-    if (ctx->transfer_ctx.expired()) {
-        // Initialize new transfer context
-        transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
-        ctx->transfer_ctx = transfer_ctx;
-        ggml_vk_ctx_begin(ctx->device, transfer_ctx);
-    } else {
-        transfer_ctx = ctx->transfer_ctx.lock();
-    }
-
-    vk_buffer buf = buf_ctx->dev_buffer;
-
-    ggml_vk_buffer_read_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
-}
-
-static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
-    VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
-    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-    if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
-        ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
-        ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-
-        vk_context transfer_ctx;
-
-        if (ctx->transfer_ctx.expired()) {
-            // Initialize new transfer context
-            transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
-            ctx->transfer_ctx = transfer_ctx;
-            ggml_vk_ctx_begin(ctx->device, transfer_ctx);
-        } else {
-            transfer_ctx = ctx->transfer_ctx.lock();
-        }
-
-        vk_buffer src_buf = src_buf_ctx->dev_buffer;
-        vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
-
-        ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
-        return true;
-    }
-
-    return false;
-}
-
-static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
-    VK_LOG_DEBUG("ggml_backend_vk_synchronize()");
-    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-    if(ctx->transfer_ctx.expired()) {
-        return;
-    }
-
-    vk_context transfer_ctx = ctx->transfer_ctx.lock();
-
-    ggml_vk_ctx_end(transfer_ctx);
-
-    for (auto& cpy : transfer_ctx->in_memcpys) {
-        memcpy(cpy.dst, cpy.src, cpy.n);
-    }
-
-    ggml_vk_submit(transfer_ctx, ctx->fence);
-    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
-    ctx->device->device.resetFences({ ctx->fence });
-
-    for (auto& cpy : transfer_ctx->out_memcpys) {
-        memcpy(cpy.dst, cpy.src, cpy.n);
-    }
-
-    ctx->transfer_ctx.reset();
-}
-
-static bool ggml_vk_is_empty(ggml_tensor * node) {
-    return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
-}
-
-static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-    VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
-    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
-    }
-    ggml_vk_preallocate_buffers(ctx);
-    ggml_pipeline_allocate_descriptor_sets(ctx->device);
-
-    int last_node = cgraph->n_nodes - 1;
-
-    // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
-    while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) {
-        last_node -= 1;
-    }
-
-    // Reserve tensor context space for all nodes
-    ctx->tensor_ctxs.resize(cgraph->n_nodes);
-
-    bool first_node_in_batch = true; // true if next node will be first node in a batch
-    int submit_node_idx = 0; // index to first node in a batch
-
-    // submit work every submit_count node to overlap CPU cmdbuffer generation with GPU execution
-    constexpr int submit_count = 100;
-    int submitted_nodes = 0;
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (first_node_in_batch) {
-            submit_node_idx = i;
-        }
-
-        bool submit = (submitted_nodes >= submit_count) || (i == last_node);
-
-
-        bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
-
-        if (enqueued) {
-            ++submitted_nodes;
-
-#ifndef GGML_VULKAN_CHECK_RESULTS
-            if (first_node_in_batch) {
-                first_node_in_batch = false;
-            }
-#endif
-        }
-
-        if (submit) {
-            first_node_in_batch = true;
-            submitted_nodes = 0;
-        }
-    }
-
-#ifdef GGML_VULKAN_PERF
-    ctx->device->perf_logger->print_timings();
-#endif
-
-    ggml_vk_graph_cleanup(ctx);
-
-    return GGML_STATUS_SUCCESS;
-
-    UNUSED(backend);
-}
-
-// TODO: enable async and synchronize
-static ggml_backend_i ggml_backend_vk_interface = {
-    /* .get_name                = */ ggml_backend_vk_name,
-    /* .free                    = */ ggml_backend_vk_free,
-    /* .set_tensor_async        = */ NULL,  // ggml_backend_vk_set_tensor_async,
-    /* .get_tensor_async        = */ NULL,  // ggml_backend_vk_get_tensor_async,
-    /* .cpy_tensor_async        = */ NULL,  // ggml_backend_vk_cpy_tensor_async,
-    /* .synchronize             = */ NULL,  // ggml_backend_vk_synchronize,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_vk_graph_compute,
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_vk_guid() {
-    static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
-    return &guid;
-}
-
-ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
-    VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
-
-    ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
-    ggml_vk_init(ctx, dev_num);
-
-    ggml_backend_t vk_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_vk_guid(),
-        /* .interface = */ ggml_backend_vk_interface,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
-        /* .context   = */ ctx,
-    };
-
-    return vk_backend;
-}
-
-bool ggml_backend_is_vk(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
-}
-
-int ggml_backend_vk_get_device_count() {
-    return ggml_vk_get_device_count();
-}
-
-void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
-    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
-    int dev_idx = vk_instance.device_indices[device];
-    ggml_vk_get_device_description(dev_idx, description, description_size);
-}
-
-void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
-    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
-
-    vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
-
-    vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
-
-    for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
-        if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
-            *total = heap.size;
-            *free = heap.size;
-            break;
-        }
-    }
-}
-
-//////////////////////////
-
-struct ggml_backend_vk_device_context {
-    size_t device;
-    std::string name;
-    std::string description;
-};
-
-static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
-    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-    return ctx->name.c_str();
-}
-
-static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) {
-    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-    return ctx->description.c_str();
-}
-
-static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
-    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
-    ggml_backend_vk_get_device_memory(ctx->device, free, total);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
-    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-    return ggml_backend_vk_buffer_type(ctx->device);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {
-    UNUSED(dev);
-    return ggml_backend_vk_host_buffer_type();
-}
-
-static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
-    UNUSED(dev);
-    return GGML_BACKEND_DEVICE_TYPE_GPU;
-}
-
-static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
-    props->name        = ggml_backend_vk_device_get_name(dev);
-    props->description = ggml_backend_vk_device_get_description(dev);
-    props->type        = ggml_backend_vk_device_get_type(dev);
-    ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
-    props->caps = {
-        /* .async                 = */ false,
-        /* .host_buffer           = */ true,
-        /* .buffer_from_host_ptr  = */ false,
-        /* .events                = */ false,
-    };
-}
-
-static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
-    UNUSED(params);
-    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-    return ggml_backend_vk_init(ctx->device);
-}
-
-static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
-    switch (op->op) {
-        case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_GELU:
-                case GGML_UNARY_OP_GELU_QUICK:
-                case GGML_UNARY_OP_SILU:
-                case GGML_UNARY_OP_RELU:
-                case GGML_UNARY_OP_TANH:
-                    return ggml_is_contiguous(op->src[0]);
-                default:
-                    return false;
-            }
-            break;
-        case GGML_OP_MUL_MAT:
-        case GGML_OP_MUL_MAT_ID:
-            {
-                switch (op->src[0]->type) {
-                    case GGML_TYPE_F32:
-                    case GGML_TYPE_F16:
-                    case GGML_TYPE_Q4_0:
-                    case GGML_TYPE_Q4_1:
-                    case GGML_TYPE_Q5_0:
-                    case GGML_TYPE_Q5_1:
-                    case GGML_TYPE_Q8_0:
-                    case GGML_TYPE_Q2_K:
-                    case GGML_TYPE_Q3_K:
-                    case GGML_TYPE_Q4_K:
-                    case GGML_TYPE_Q5_K:
-                    case GGML_TYPE_Q6_K:
-                    case GGML_TYPE_IQ4_NL:
-                        break;
-                    default:
-                        return false;
-                }
-                struct ggml_tensor * a;
-                struct ggml_tensor * b;
-                if (op->op == GGML_OP_MUL_MAT) {
-                    a = op->src[0];
-                    b = op->src[1];
-                } else {
-                    a = op->src[2];
-                    b = op->src[1];
-                }
-                if (a->ne[3] != b->ne[3]) {
-                    return false;
-                }
-                if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
-                    !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
-                    return false;
-                }
-
-                return true;
-            } break;
-        case GGML_OP_GET_ROWS:
-            {
-                switch (op->src[0]->type) {
-                    case GGML_TYPE_F32:
-                    case GGML_TYPE_F16:
-                    case GGML_TYPE_Q4_0:
-                    case GGML_TYPE_Q4_1:
-                    case GGML_TYPE_Q5_0:
-                    case GGML_TYPE_Q5_1:
-                    case GGML_TYPE_Q8_0:
-                    case GGML_TYPE_IQ4_NL:
-                        return true;
-                    default:
-                        return false;
-                }
-            } break;
-        case GGML_OP_CONT:
-        case GGML_OP_CPY:
-        case GGML_OP_DUP:
-            {
-                ggml_type src0_type = op->src[0]->type;
-                ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type;
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
-                    return true;
-                }
-                return false;
-            } break;
-        case GGML_OP_REPEAT:
-            return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
-        case GGML_OP_ROPE:
-            return ggml_is_contiguous(op->src[0]);
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_TRANSPOSE:
-        case GGML_OP_NORM:
-        case GGML_OP_GROUP_NORM:
-        case GGML_OP_RMS_NORM:
-        case GGML_OP_ADD:
-        case GGML_OP_ACC:
-        case GGML_OP_MUL:
-        case GGML_OP_DIV:
-        case GGML_OP_CONCAT:
-        case GGML_OP_UPSCALE:
-        case GGML_OP_SCALE:
-        case GGML_OP_SQR:
-        case GGML_OP_SIN:
-        case GGML_OP_COS:
-        case GGML_OP_CLAMP:
-        case GGML_OP_PAD:
-        case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_SOFT_MAX:
-        case GGML_OP_ARGSORT:
-        case GGML_OP_SUM_ROWS:
-        case GGML_OP_IM2COL:
-        case GGML_OP_TIMESTEP_EMBEDDING:
-        case GGML_OP_POOL_2D:
-        case GGML_OP_LEAKY_RELU:
-            return true;
-        default:
-            return false;
-    }
-
-    UNUSED(dev);
-}
-
-static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
-        return false;
-    }
-
-    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-    ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
-
-    return buft_ctx->device->idx == ctx->device;
-}
-
-static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
-    const int min_batch_size = 32;
-
-    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
-           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
-
-    UNUSED(dev);
-}
-
-static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
-    /* .get_name             = */ ggml_backend_vk_device_get_name,
-    /* .get_description      = */ ggml_backend_vk_device_get_description,
-    /* .get_memory           = */ ggml_backend_vk_device_get_memory,
-    /* .get_type             = */ ggml_backend_vk_device_get_type,
-    /* .get_props            = */ ggml_backend_vk_device_get_props,
-    /* .init_backend         = */ ggml_backend_vk_device_init,
-    /* .get_buffer_type      = */ ggml_backend_vk_device_get_buffer_type,
-    /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
-    /* .buffer_from_host_ptr = */ NULL,
-    /* .supports_op          = */ ggml_backend_vk_device_supports_op,
-    /* .supports_buft        = */ ggml_backend_vk_device_supports_buft,
-    /* .offload_op           = */ ggml_backend_vk_device_offload_op,
-    /* .event_new            = */ NULL,
-    /* .event_free           = */ NULL,
-    /* .event_synchronize    = */ NULL,
-};
-
-static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
-    UNUSED(reg);
-    return GGML_VK_NAME;
-}
-
-static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) {
-    UNUSED(reg);
-    return ggml_backend_vk_get_device_count();
-}
-
-static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) {
-    static std::vector<ggml_backend_dev_t> devices;
-
-    static bool initialized = false;
-
-    {
-        static std::mutex mutex;
-        std::lock_guard<std::mutex> lock(mutex);
-        if (!initialized) {
-            for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
-                ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
-                char desc[256];
-                ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
-                ctx->device = i;
-                ctx->name = GGML_VK_NAME + std::to_string(i);
-                ctx->description = desc;
-                devices.push_back(new ggml_backend_device {
-                    /* .iface   = */ ggml_backend_vk_device_i,
-                    /* .reg     = */ reg,
-                    /* .context = */ ctx,
-                });
-            }
-            initialized = true;
-        }
-    }
-
-    GGML_ASSERT(device < devices.size());
-    return devices[device];
-}
-
-static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
-    /* .get_name         = */ ggml_backend_vk_reg_get_name,
-    /* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
-    /* .get_device       = */ ggml_backend_vk_reg_get_device,
-    /* .get_proc_address = */ NULL,
-};
-
-ggml_backend_reg_t ggml_backend_vk_reg() {
-    static ggml_backend_reg reg = {
-        /* .iface   = */ ggml_backend_vk_reg_i,
-        /* .context = */ nullptr,
-    };
-
-    return &reg;
-}
-
-// Extension availability
-static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
-#ifdef GGML_VULKAN_VALIDATE
-    bool portability_enumeration_ext = false;
-    // Check for portability enumeration extension for MoltenVK support
-    for (const auto& properties : instance_extensions) {
-        if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
-            return true;
-        }
-    }
-    if (!portability_enumeration_ext) {
-        std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
-    }
-#endif
-    return false;
-
-    UNUSED(instance_extensions);
-}
-static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
-#ifdef __APPLE__
-    bool portability_enumeration_ext = false;
-    // Check for portability enumeration extension for MoltenVK support
-    for (const auto& properties : instance_extensions) {
-        if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
-            return true;
-        }
-    }
-    if (!portability_enumeration_ext) {
-        std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
-    }
-#endif
-    return false;
-
-    UNUSED(instance_extensions);
-}
-
-// checks
-
-#ifdef GGML_VULKAN_CHECK_RESULTS
-static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector<const ggml_tensor *>& done, int level = 0) {
-    if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) {
-        return;
-    }
-    for (int j = 0; j < level; j++) {
-        std::cerr << " ";
-    }
-    std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl;
-
-    done.push_back(tensor);
-
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (tensor->src[i] != nullptr) {
-            ggml_vk_print_graph_origin(tensor->src[i], done, level + 1);
-        }
-    }
-}
-
-static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) {
-    if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) {
-        return;
-    }
-    i0 = std::max(i0, 5);
-    i1 = std::max(i1, 5);
-    i2 = std::max(i2, 0);
-    i3 = std::max(i3, 0);
-    fprintf(stderr, "         ");
-    for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
-        fprintf(stderr, "%7d ", idx1);
-    }
-    fprintf(stderr, "\n");
-    for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
-        fprintf(stderr, "%7d: ", idx0);
-        for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
-            if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {
-                float val;
-                if (tensor->type == GGML_TYPE_F32) {
-                    val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
-                } else if (tensor->type == GGML_TYPE_F16) {
-                    val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
-                } else if (tensor->type == GGML_TYPE_I32) {
-                    val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
-                } else {
-                    GGML_ABORT("fatal error");
-                }
-                fprintf(stderr, "% 7.2f ", val);
-            } else {
-                fprintf(stderr, "        ");
-            }
-        }
-        fprintf(stderr, "\n");
-    }
-}
-
-static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) {
-    void * tensor_data = tensor->data;
-
-    const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer);
-
-    if (is_gpu) {
-        const size_t tensor_size = ggml_nbytes(tensor);
-        tensor_data = malloc(tensor_size);
-
-        ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
-
-        vk_buffer buffer_gpu = buf_ctx->dev_buffer;
-        ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size);
-    }
-
-    std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl;
-    std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl;
-    if (tensor->src[0] != nullptr) {
-        std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl;
-    }
-    if (tensor->src[1] != nullptr) {
-        std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl;
-    }
-    std::cerr << std::endl << "Result:" << std::endl;
-    ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
-    std::cerr << std::endl;
-    std::vector<const ggml_tensor *> done;
-    ggml_vk_print_graph_origin(tensor, done);
-
-    if (is_gpu) {
-        free(tensor_data);
-    }
-}
-
-void * comp_result;
-size_t comp_size;
-size_t comp_nb[GGML_MAX_DIMS];
-size_t check_counter = 0;
-static void ggml_vk_check_results_0(ggml_tensor * tensor) {
-    if (tensor->op == GGML_OP_TRANSPOSE) {
-        return;
-    }
-
-    check_counter++;
-    if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
-        return;
-    }
-
-    VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")");
-
-    ggml_tensor * src0 = tensor->src[0];
-    ggml_tensor * src1 = tensor->src[1];
-    ggml_tensor * src2 = tensor->src[2];
-
-    struct ggml_init_params iparams = {
-        /*.mem_size   =*/ 2ul*1024ul*1024ul*1024ul,
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ false,
-    };
-
-    struct ggml_context * ggml_ctx = ggml_init(iparams);
-
-    struct ggml_tensor * src0_clone = nullptr;
-    struct ggml_tensor * src1_clone = nullptr;
-    struct ggml_tensor * src2_clone = nullptr;
-    struct ggml_tensor * tensor_clone = nullptr;
-
-    size_t src0_size;
-    size_t src1_size;
-    size_t src2_size;
-
-    void * src0_buffer = nullptr;
-    void * src1_buffer = nullptr;
-    void * src2_buffer = nullptr;
-
-    if (src0 != nullptr) {
-        src0_clone = ggml_dup_tensor(ggml_ctx, src0);
-
-        src0_size = ggml_nbytes(src0);
-
-        src0_buffer = malloc(src0_size);
-        src0_clone->data = src0_buffer;
-        if (ggml_backend_buffer_is_host(src0->buffer)) {
-            memcpy(src0_clone->data, src0->data, src0_size);
-            memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
-        } else if (ggml_backend_buffer_is_vk(src0->buffer)) {
-            ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
-            vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
-            uint64_t offset = vk_tensor_offset(src0) + src0->view_offs;
-            if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
-                for (int i3 = 0; i3 < src0->ne[3]; i3++) {
-                    for (int i2 = 0; i2 < src0->ne[2]; i2++) {
-                        const int idx = i3*src0->ne[2] + i2;
-                        ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
-                    }
-                }
-
-                src0_clone->nb[0] = src0->nb[0];
-                src0_clone->nb[1] = src0->nb[1];
-                for (int i = 2; i < GGML_MAX_DIMS; i++) {
-                    src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1];
-                }
-            } else {
-                if (offset + src0_size >= buffer_gpu->size) {
-                    src0_size = buffer_gpu->size - offset;
-                }
-                ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size);
-                memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
-            }
-        } else {
-            GGML_ABORT("fatal error");
-        }
-
-        if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(src0, "src0");
-        }
-    }
-    if (src1 != nullptr) {
-        src1_clone = ggml_dup_tensor(ggml_ctx, src1);
-
-        src1_size = ggml_nbytes(src1);
-
-        src1_buffer = malloc(src1_size);
-        src1_clone->data = src1_buffer;
-        if (ggml_backend_buffer_is_host(src1->buffer)) {
-            memcpy(src1_clone->data, src1->data, src1_size);
-            memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
-        } else if (ggml_backend_buffer_is_vk(src1->buffer)) {
-            ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
-            vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
-            uint64_t offset = vk_tensor_offset(src1) + src1->view_offs;
-            if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
-                for (int i3 = 0; i3 < src1->ne[3]; i3++) {
-                    for (int i2 = 0; i2 < src1->ne[2]; i2++) {
-                        const int idx = i3*src1->ne[2] + i2;
-                        ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
-                    }
-                }
-
-                src1_clone->nb[0] = src1->nb[0];
-                src1_clone->nb[1] = src1->nb[1];
-                for (int i = 2; i < GGML_MAX_DIMS; i++) {
-                    src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1];
-                }
-            } else {
-                if (offset + src1_size >= buffer_gpu->size) {
-                    src1_size = buffer_gpu->size - offset;
-                }
-                ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size);
-                memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
-            }
-        } else {
-            GGML_ABORT("fatal error");
-        }
-
-        if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(src1, "src1");
-        }
-    }
-    if (src2 != nullptr) {
-        src2_clone = ggml_dup_tensor(ggml_ctx, src2);
-
-        src2_size = ggml_nbytes(src2);
-
-        src2_buffer = malloc(src2_size);
-        src2_clone->data = src2_buffer;
-        if (ggml_backend_buffer_is_host(src2->buffer)) {
-            memcpy(src2_clone->data, src2->data, src2_size);
-            memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
-        } else if (ggml_backend_buffer_is_vk(src2->buffer)) {
-            ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context;
-            vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
-            uint64_t offset = vk_tensor_offset(src2) + src2->view_offs;
-            if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
-                for (int i3 = 0; i3 < src2->ne[3]; i3++) {
-                    for (int i2 = 0; i2 < src2->ne[2]; i2++) {
-                        const int idx = i3*src2->ne[2] + i2;
-                        ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
-                    }
-                }
-
-                src2_clone->nb[0] = src2->nb[0];
-                src2_clone->nb[1] = src2->nb[1];
-                for (int i = 2; i < GGML_MAX_DIMS; i++) {
-                    src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1];
-                }
-            } else {
-                if (offset + src2_size >= buffer_gpu->size) {
-                    src2_size = buffer_gpu->size - offset;
-                }
-                ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size);
-                memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
-            }
-        } else {
-            GGML_ABORT("fatal error");
-        }
-
-        if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(src2, "src2");
-        }
-    }
-
-    if (tensor->op == GGML_OP_MUL_MAT) {
-        tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
-    } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
-        tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
-    } else if (tensor->op == GGML_OP_MUL) {
-        tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
-    } else if (tensor->op == GGML_OP_DIV) {
-        tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone);
-    } else if (tensor->op == GGML_OP_CONCAT) {
-        tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params);
-    } else if (tensor->op == GGML_OP_UPSCALE) {
-        tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
-    } else if (tensor->op == GGML_OP_SCALE) {
-        tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
-    } else if (tensor->op == GGML_OP_SQR) {
-        tensor_clone = ggml_sqr(ggml_ctx, src0_clone);
-    } else if (tensor->op == GGML_OP_SIN) {
-        tensor_clone = ggml_sin(ggml_ctx, src0_clone);
-    } else if (tensor->op == GGML_OP_COS) {
-        tensor_clone = ggml_cos(ggml_ctx, src0_clone);
-    } else if (tensor->op == GGML_OP_CLAMP) {
-        tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
-    } else if (tensor->op == GGML_OP_PAD) {
-        tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]);
-    } else if (tensor->op == GGML_OP_REPEAT) {
-        tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor);
-    } else if (tensor->op == GGML_OP_ADD) {
-        tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
-    } else if (tensor->op == GGML_OP_ACC) {
-        tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
-    } else if (tensor->op == GGML_OP_NORM) {
-        tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
-    } else if (tensor->op == GGML_OP_GROUP_NORM) {
-        tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
-    } else if (tensor->op == GGML_OP_RMS_NORM) {
-        tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
-    } else if (tensor->op == GGML_OP_SOFT_MAX) {
-        if (src1 != nullptr) {
-            tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
-        } else {
-            tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
-        }
-    } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
-        tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params);
-    } else if (tensor->op == GGML_OP_ROPE) {
-        const int n_dims      = ((int32_t *) tensor->op_params)[1];
-        const int mode        = ((int32_t *) tensor->op_params)[2];
-        //const int n_ctx_ggml       = ((int32_t *) tensor->op_params)[3];
-        const int n_ctx_orig_ggml  = ((int32_t *) tensor->op_params)[4];
-        const float freq_base       = ((float *) tensor->op_params)[5];
-        const float freq_scale      = ((float *) tensor->op_params)[6];
-        const float ext_factor      = ((float *) tensor->op_params)[7];
-        const float attn_factor     = ((float *) tensor->op_params)[8];
-        const float beta_fast       = ((float *) tensor->op_params)[9];
-        const float beta_slow       = ((float *) tensor->op_params)[10];
-        tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
-    } else if (tensor->op == GGML_OP_UNARY) {
-        switch (ggml_get_unary_op(tensor)) {
-        case GGML_UNARY_OP_SILU:
-            tensor_clone = ggml_silu(ggml_ctx, src0_clone);
-            break;
-        case GGML_UNARY_OP_GELU:
-            tensor_clone = ggml_gelu(ggml_ctx, src0_clone);
-            break;
-        case GGML_UNARY_OP_GELU_QUICK:
-            tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone);
-            break;
-        case GGML_UNARY_OP_RELU:
-            tensor_clone = ggml_relu(ggml_ctx, src0_clone);
-            break;
-        case GGML_UNARY_OP_TANH:
-            tensor_clone = ggml_tanh(ggml_ctx, src0_clone);
-            break;
-        default:
-            std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
-            GGML_ABORT("fatal error");
-        }
-    } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
-        if (src1 == nullptr) {
-            tensor_clone = ggml_dup(ggml_ctx, src0_clone);
-            tensor_clone->type = tensor->type;
-        } else {
-            tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone);
-        }
-    } else if (tensor->op == GGML_OP_CONT) {
-        tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
-    } else if (tensor->op == GGML_OP_RESHAPE) {
-        tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
-    } else if (tensor->op == GGML_OP_VIEW) {
-        tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
-    } else if (tensor->op == GGML_OP_PERMUTE) {
-        int32_t * params = (int32_t *)tensor->op_params;
-        tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]);
-    } else if (tensor->op == GGML_OP_TRANSPOSE) {
-        tensor_clone = ggml_transpose(ggml_ctx, src0_clone);
-    } else if (tensor->op == GGML_OP_GET_ROWS) {
-        tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
-    } else if (tensor->op == GGML_OP_ARGSORT) {
-        tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
-    } else if (tensor->op == GGML_OP_SUM_ROWS) {
-        tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
-    } else if (tensor->op == GGML_OP_IM2COL) {
-        const int32_t s0 = tensor->op_params[0];
-        const int32_t s1 = tensor->op_params[1];
-        const int32_t p0 = tensor->op_params[2];
-        const int32_t p1 = tensor->op_params[3];
-        const int32_t d0 = tensor->op_params[4];
-        const int32_t d1 = tensor->op_params[5];
-
-        const bool is_2D = tensor->op_params[6] == 1;
-        tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
-    } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
-        const int32_t dim = tensor->op_params[0];
-        const int32_t max_period = tensor->op_params[1];
-        tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
-    } else if (tensor->op == GGML_OP_POOL_2D) {
-        enum ggml_op_pool op = static_cast<ggml_op_pool>(dst->op_params[0]);
-        const int32_t k0 = tensor->op_params[1];
-        const int32_t k1 = tensor->op_params[2];
-        const int32_t s0 = tensor->op_params[3];
-        const int32_t s1 = tensor->op_params[4];
-        const int32_t p0 = tensor->op_params[5];
-        const int32_t p1 = tensor->op_params[6];
-
-        tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1);
-    } else if (tensor->op == GGML_OP_LEAKY_RELU) {
-        const float * op_params = (const float *)tensor->op_params;
-        tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
-    } else {
-        std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
-        GGML_ABORT("fatal error");
-    }
-
-    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
-    ggml_build_forward_expand(cgraph, tensor_clone);
-
-    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
-
-    if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-        ggml_vk_print_tensor(tensor_clone, "tensor_clone");
-    }
-
-    comp_size = ggml_nbytes(tensor_clone);
-
-    comp_result = malloc(comp_size);
-    memcpy(comp_result, tensor_clone->data, comp_size);
-    memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
-
-    if (src0 != nullptr) {
-        free(src0_buffer);
-    }
-    if (src1 != nullptr) {
-        free(src1_buffer);
-    }
-
-    ggml_free(ggml_ctx);
-
-    VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
-}
-
-static void ggml_vk_check_results_1(ggml_tensor * tensor) {
-    if (tensor->op == GGML_OP_TRANSPOSE) {
-        return;
-    }
-    if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
-        return;
-    }
-
-    VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")");
-
-    ggml_tensor * src0 = tensor->src[0];
-    ggml_tensor * src1 = tensor->src[1];
-    ggml_tensor * src2 = tensor->src[2];
-
-    void * tensor_data = tensor->data;
-
-    if (ggml_backend_buffer_is_vk(tensor->buffer)) {
-        size_t tensor_size = ggml_nbytes(tensor);
-        tensor_data = malloc(tensor_size);
-
-        ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
-
-        vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
-        uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs;
-        if (offset + tensor_size >= buffer_gpu->size) {
-            tensor_size = buffer_gpu->size - offset;
-        }
-
-        ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size);
-    }
-
-    float first_error_result = -1.0f;
-    float first_error_correct = -1.0f;
-    std::array<int, 4> first_error = { -1, -1, -1, -1 };
-    double avg_err = 0.0;
-    size_t counter = 0;
-
-    for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
-        for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
-            for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
-                for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
-                    const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size;
-                    float correct = 0.0f;
-                    float result = 0.0f;
-
-                    if (buffer_size_fit) {
-                        if (tensor->type == GGML_TYPE_F32) {
-                            correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
-                            result  = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
-                        } else if (tensor->type == GGML_TYPE_F16) {
-                            correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
-                            result  = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
-                        } else if (tensor->type == GGML_TYPE_I32) {
-                            correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
-                            result  = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
-                        } else {
-                            std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
-                        }
-                    } else {
-                        std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl;
-                        GGML_ABORT("fatal error");
-                    }
-
-                    if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) {
-                        std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl;
-                        std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
-                        if (src0 != nullptr) {
-                            std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
-                        }
-                        if (src1 != nullptr) {
-                            std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
-                        }
-                        if (src2 != nullptr) {
-                            std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
-                        }
-                        std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct  << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
-                        std::cerr << std::endl << "Result:" << std::endl;
-                        ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
-                        std::cerr << std::endl << "Correct:" << std::endl;
-                        ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3);
-                        std::cerr << std::endl;
-                        std::vector<const ggml_tensor *> done;
-                        ggml_vk_print_graph_origin(tensor, done);
-                        GGML_ABORT("fatal error");
-                    }
-                    if (first_error[0] == -1 && std::fabs(correct - result) > 0.1f) {
-                        first_error[0] = i0;
-                        first_error[1] = i1;
-                        first_error[2] = i2;
-                        first_error[3] = i3;
-                        first_error_result = result;
-                        first_error_correct = correct;
-                    }
-
-                    // Special case, value is infinite, avoid NaN result in avg_err
-                    // NaN also appears in results, if both are nan error is 0
-                    if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) {
-                        avg_err += std::fabs(correct - result);
-                    }
-                    counter++;
-                }
-            }
-        }
-    }
-
-    avg_err /= counter;
-
-    if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-        std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
-        std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
-        if (src0 != nullptr) {
-            std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
-        }
-        if (src1 != nullptr) {
-            std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
-        }
-        if (src2 != nullptr) {
-            std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
-        }
-        std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct  << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
-        std::cerr << std::endl << "Result:" << std::endl;
-        ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
-        std::cerr << std::endl << "Correct:" << std::endl;
-        ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0);
-        std::cerr << std::endl;
-        std::vector<const ggml_tensor *> done;
-        ggml_vk_print_graph_origin(tensor, done);
-    }
-
-    if (avg_err > 0.05 || std::isnan(avg_err)) {
-        std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
-        std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
-        if (src0 != nullptr) {
-            std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
-        }
-        if (src1 != nullptr) {
-            std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
-        }
-        if (src2 != nullptr) {
-            std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
-        }
-        std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct  << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
-        std::cerr << std::endl << "Result:" << std::endl;
-        ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
-        std::cerr << std::endl << "Correct:" << std::endl;
-        ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]);
-        std::cerr << std::endl;
-        std::vector<const ggml_tensor *> done;
-        ggml_vk_print_graph_origin(tensor, done);
-        GGML_ABORT("fatal error");
-    } else {
-        std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl;
-    }
-
-    free(comp_result);
-    comp_result = nullptr;
-    comp_size = 0;
-
-    if (ggml_backend_buffer_is_vk(tensor->buffer)) {
-        free(tensor_data);
-    }
-
-    VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")");
-}
-#endif
diff --git a/ggml/src/kompute-shaders/common.comp b/ggml/src/kompute-shaders/common.comp
deleted file mode 100644 (file)
index 62d62b0..0000000
+++ /dev/null
@@ -1,102 +0,0 @@
-#extension GL_EXT_shader_16bit_storage: require
-#extension GL_EXT_shader_8bit_storage: require
-#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
-#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
-#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
-#extension GL_EXT_control_flow_attributes: enable
-#extension GL_KHR_shader_subgroup_arithmetic : require
-#extension GL_EXT_debug_printf : enable
-
-#define QK4_0 32
-#define QK4_1 32
-
-#define GELU_COEF_A 0.044715
-#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
-#define TWOPI_F 6.283185307179586f
-
-#define QK_K 256
-
-#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
-#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
-#define u8BufToU32(buf, idx) (((uint32_t u8BufToU16(buf, idx + 2) << 8 | buf[idx + 1]) << 8) | buf[idx])
-#define u8BufToFloat(buf, idx) uintBitsToFloat u8BufToU32(buf, idx)
-
-#define sizeof_block_q4_0 0x12
-struct block_q4_0 {
-    float16_t d;
-    uint8_t qs[QK4_0 / 2];
-};
-mat4 dequantize_q4_0(const block_q4_0 xb, uint il) {
-    const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
-    const float d2 = d1 / 256.f;
-    const float md = -8.f * xb.d;
-    const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
-    const uint16_t mask1 = mask0 << 8;
-
-    mat4 reg;
-    for (int i=0;i<8;i++) {
-        uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
-        reg[i/2][2*(i%2)+0] = d1 * (b & mask0) + md;
-        reg[i/2][2*(i%2)+1] = d2 * (b & mask1) + md;
-    }
-    return reg;
-}
-
-#define sizeof_block_q4_1 0x14
-struct block_q4_1 {
-    float16_t d;
-    float16_t m;
-    uint8_t qs[QK4_1 / 2];
-};
-mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
-    const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
-    const float d2 = d1 / 256.f;
-    const float  m = xb.m;
-    const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
-    const uint16_t mask1 = mask0 << 8;
-
-    mat4 reg;
-    for (int i=0;i<8;i++) {
-        uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
-        reg[i/2][2*(i%2)+0] = ((b & mask0) * d1) + m;
-        reg[i/2][2*(i%2)+1] = ((b & mask1) * d2) + m;
-    }
-    return reg;
-}
-
-#define sizeof_block_q6_k 210
-struct block_q6_k {
-    uint8_t ql[QK_K/2];      // quants, lower 4 bits
-    uint8_t qh[QK_K/4];      // quants, upper 2 bits
-    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
-    float16_t d;             // super-block scale
-};
-mat4 dequantize_q6_k(const block_q6_k xb, uint il) {
-    const float16_t d_all = xb.d;
-
-    const uint qlIndex = 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
-    const uint qhIndex = 32*(il/8) + 16*(il&1);
-    float16_t sc = xb.scales[(il%2) + 2 * ((il/2))];
-    il = (il/2) & 3;
-
-    const uint16_t  kmask1 = il>1 ? uint16_t(il>2 ? 192 : 48) : uint16_t(il>0 ? 12 : 3);
-    const uint16_t  kmask2 = il>1 ? uint8_t(0xF0)             : uint8_t(0x0F);
-    const float16_t coef   = il>1 ? float16_t(1.f/16.f)       : float16_t(1.f);
-    const float16_t ml = float16_t(d_all * sc * 32.f);
-    const float16_t dl = float16_t(d_all * sc * coef);
-    mat4 reg;
-    for (int i = 0; i < 16; ++i) {
-        const float16_t q = (il&1) != 0 ? ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 2))
-                                        : ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 4));
-        reg[i/4][i%4] = dl * q - ml;
-    }
-    return reg;
-}
-
-
-#define QK8_0 32
-// struct block_q8_0 {
-//     float16_t d;         // delta
-//     int8_t    qs[QK8_0]; // quants
-// };
-#define sizeof_block_q8_0 34
diff --git a/ggml/src/kompute-shaders/op_add.comp b/ggml/src/kompute-shaders/op_add.comp
deleted file mode 100644 (file)
index b7b76a7..0000000
+++ /dev/null
@@ -1,58 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1024) in;
-
-layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
-layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
-layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int nb00;
-    int nb01;
-    int nb02;
-    int nb03;
-    int ne10;
-    int ne11;
-    int ne12;
-    int ne13;
-    int nb10;
-    int nb11;
-    int nb12;
-    int nb13;
-    int ne0;
-    int nb0;
-    int nb1;
-    int nb2;
-    int nb3;
-  //int offs; // TODO: needed for GGML_OP_ACC, see metal code
-} pcs;
-
-// general-purpose kernel for addition of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
-// cons: not very efficient
-void main() {
-    const uint i03 = gl_WorkGroupID.z;
-    const uint i02 = gl_WorkGroupID.y;
-    const uint i01 = gl_WorkGroupID.x;
-
-    const uint i13 = i03 % pcs.ne13;
-    const uint i12 = i02 % pcs.ne12;
-    const uint i11 = i01 % pcs.ne11;
-
-    int offs = 0; // TMP (see above)
-
-    uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + offs) / 4);
-    uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11       ) / 4);
-    uint dst_off  = uint((i03*pcs.nb3  + i02*pcs.nb2  + i01*pcs.nb1  + offs) / 4);
-
-    for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
-        const uint i10 = i0 % pcs.ne10;
-        out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] + inB[pcs.inBOff + src1_off + i10];
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_addrow.comp b/ggml/src/kompute-shaders/op_addrow.comp
deleted file mode 100644 (file)
index 2376a6b..0000000
+++ /dev/null
@@ -1,25 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
-layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
-layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    uint row;
-} pcs;
-
-void main() {
-    const uint baseIndex = gl_WorkGroupID.x * 4;
-
-    for (uint x = 0; x < 4; x++) {
-        const uint i = baseIndex + x;
-        out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff];
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_cpy_f16_f16.comp b/ggml/src/kompute-shaders/op_cpy_f16_f16.comp
deleted file mode 100644 (file)
index d57247d..0000000
+++ /dev/null
@@ -1,52 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define IN_TYPE float16_t
-#define IN_TYPE_SIZE 2
-#define OUT_TYPE float16_t
-#define OUT_TYPE_SIZE 2
-
-layout(local_size_x = 1024) in;
-
-layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
-layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inOff;
-    uint outOff;
-    int ne00;
-    int ne01;
-    int ne02;
-    uint nb00;
-    uint nb01;
-    uint nb02;
-    uint nb03;
-    int ne0;
-    int ne1;
-    int ne2;
-    uint nb0;
-    uint nb1;
-    uint nb2;
-    uint nb3;
-} pcs;
-
-void main() {
-    const uint i03 = gl_WorkGroupID.z;
-    const uint i02 = gl_WorkGroupID.y;
-    const uint i01 = gl_WorkGroupID.x;
-
-    const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
-    const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
-    const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
-    const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
-    const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
-    const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
-        out_[dst_data+i00] = OUT_TYPE(in_[src]);
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_cpy_f16_f32.comp b/ggml/src/kompute-shaders/op_cpy_f16_f32.comp
deleted file mode 100644 (file)
index b568bcd..0000000
+++ /dev/null
@@ -1,52 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define IN_TYPE float16_t
-#define IN_TYPE_SIZE 2
-#define OUT_TYPE float
-#define OUT_TYPE_SIZE 4
-
-layout(local_size_x = 1024) in;
-
-layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
-layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inOff;
-    uint outOff;
-    int ne00;
-    int ne01;
-    int ne02;
-    uint nb00;
-    uint nb01;
-    uint nb02;
-    uint nb03;
-    int ne0;
-    int ne1;
-    int ne2;
-    uint nb0;
-    uint nb1;
-    uint nb2;
-    uint nb3;
-} pcs;
-
-void main() {
-    const uint i03 = gl_WorkGroupID.z;
-    const uint i02 = gl_WorkGroupID.y;
-    const uint i01 = gl_WorkGroupID.x;
-
-    const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
-    const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
-    const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
-    const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
-    const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
-    const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
-        out_[dst_data+i00] = OUT_TYPE(in_[src]);
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_cpy_f32_f16.comp b/ggml/src/kompute-shaders/op_cpy_f32_f16.comp
deleted file mode 100644 (file)
index 99b2283..0000000
+++ /dev/null
@@ -1,52 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define IN_TYPE float
-#define IN_TYPE_SIZE 4
-#define OUT_TYPE float16_t
-#define OUT_TYPE_SIZE 2
-
-layout(local_size_x = 1024) in;
-
-layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
-layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inOff;
-    uint outOff;
-    int ne00;
-    int ne01;
-    int ne02;
-    uint nb00;
-    uint nb01;
-    uint nb02;
-    uint nb03;
-    int ne0;
-    int ne1;
-    int ne2;
-    uint nb0;
-    uint nb1;
-    uint nb2;
-    uint nb3;
-} pcs;
-
-void main() {
-    const uint i03 = gl_WorkGroupID.z;
-    const uint i02 = gl_WorkGroupID.y;
-    const uint i01 = gl_WorkGroupID.x;
-
-    const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
-    const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
-    const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
-    const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
-    const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
-    const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
-        out_[dst_data+i00] = OUT_TYPE(in_[src]);
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_cpy_f32_f32.comp b/ggml/src/kompute-shaders/op_cpy_f32_f32.comp
deleted file mode 100644 (file)
index 2fc9984..0000000
+++ /dev/null
@@ -1,52 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define IN_TYPE float
-#define IN_TYPE_SIZE 4
-#define OUT_TYPE float
-#define OUT_TYPE_SIZE 4
-
-layout(local_size_x = 1024) in;
-
-layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
-layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inOff;
-    uint outOff;
-    int ne00;
-    int ne01;
-    int ne02;
-    uint nb00;
-    uint nb01;
-    uint nb02;
-    uint nb03;
-    int ne0;
-    int ne1;
-    int ne2;
-    uint nb0;
-    uint nb1;
-    uint nb2;
-    uint nb3;
-} pcs;
-
-void main() {
-    const uint i03 = gl_WorkGroupID.z;
-    const uint i02 = gl_WorkGroupID.y;
-    const uint i01 = gl_WorkGroupID.x;
-
-    const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
-    const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
-    const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
-    const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
-    const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
-    const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
-        out_[dst_data+i00] = OUT_TYPE(in_[src]);
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_diagmask.comp b/ggml/src/kompute-shaders/op_diagmask.comp
deleted file mode 100644 (file)
index 291c3fc..0000000
+++ /dev/null
@@ -1,30 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-    uint n_past;
-    int ne00;
-    int ne01;
-} pcs;
-
-void main() {
-    const uint i02 = gl_WorkGroupID.z;
-    const uint i01 = gl_WorkGroupID.y;
-    const uint i00 = gl_WorkGroupID.x;
-
-    const uint index = i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00;
-
-    if (i00 > pcs.n_past + i01) {
-        out_[index + pcs.outOff] = uintBitsToFloat(0xFF800000);
-    } else {
-        out_[index + pcs.outOff] = in_[index + pcs.inOff];
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_gelu.comp b/ggml/src/kompute-shaders/op_gelu.comp
deleted file mode 100644 (file)
index 9d8c537..0000000
+++ /dev/null
@@ -1,22 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-} pcs;
-
-void main() {
-    const uint baseIndex = gl_WorkGroupID.x * 8;
-
-    for (uint x = 0; x < 8; x++) {
-        const uint i = baseIndex + x;
-        const float y = in_[i + pcs.inOff];
-        out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(clamp(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y), -15.0, 15.0)));
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_getrows.comp b/ggml/src/kompute-shaders/op_getrows.comp
deleted file mode 100644 (file)
index 1a5581b..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-void main() {
-    const uint i = gl_WorkGroupID.x;
-    const int r = inB[i + pcs.inBOff];
-
-    int z = 0;
-    for (uint ind = gl_LocalInvocationID.x; ind < pcs.ne00/16; ind += gl_WorkGroupSize.x) {
-        const uint inIndex = (r * pcs.nb01 + pcs.inAOff) + ind/NL * SIZE_OF_BLOCK;
-        const mat4 result = dequantize_block(inIndex, ind%NL);
-        for (uint j = 0; j < 4; ++j) {
-            for (uint k = 0; k < 4; ++k) {
-                const uint outIndex = i * pcs.nb1/BYTES_FOR_TYPE + pcs.outOff + z;
-                out_[outIndex] = result[j][k];
-                ++z;
-            }
-        }
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_getrows_f16.comp b/ggml/src/kompute-shaders/op_getrows_f16.comp
deleted file mode 100644 (file)
index 48c9361..0000000
+++ /dev/null
@@ -1,31 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int nb01;
-    int nb1;
-} pcs;
-
-void dequantize_row_f16(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
-    for (int j = 0; j < k; j++) {
-        out_[y + j] = inA[x + j];
-    }
-}
-
-void main() {
-    const uint i = gl_WorkGroupID.x;
-    const int r = inB[i + pcs.inBOff];
-
-    dequantize_row_f16(r*pcs.nb01/2/*bytes for float16*/ + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
-}
diff --git a/ggml/src/kompute-shaders/op_getrows_f32.comp b/ggml/src/kompute-shaders/op_getrows_f32.comp
deleted file mode 100644 (file)
index 9d7acda..0000000
+++ /dev/null
@@ -1,31 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { float inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int nb01;
-    int nb1;
-} pcs;
-
-void dequantize_row_f32(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
-    for (int j = 0; j < k; j++) {
-        out_[y + j] = inA[x + j];
-    }
-}
-
-void main() {
-    const uint i = gl_WorkGroupID.x;
-    const int r = inB[i + pcs.inBOff];
-
-    dequantize_row_f32(r*pcs.nb01/4 + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
-}
diff --git a/ggml/src/kompute-shaders/op_getrows_q4_0.comp b/ggml/src/kompute-shaders/op_getrows_q4_0.comp
deleted file mode 100644 (file)
index 32b2e89..0000000
+++ /dev/null
@@ -1,38 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define NL 2
-#define BYTES_FOR_TYPE 4 /*bytes for float*/
-#define SIZE_OF_BLOCK sizeof_block_q4_0
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int nb01;
-    int nb1;
-} pcs;
-
-block_q4_0 get_unaligned_block_q4_0(uint index) {
-    block_q4_0 fres;
-    fres.d = u8BufToFloat16(inA, index);
-    [[unroll]] for (uint it = 0; it != QK4_0 / 2; it++) {
-        fres.qs[it] = inA[index+2+it];
-    }
-    return fres;
-}
-
-mat4 dequantize_block(uint index, uint il) {
-    const block_q4_0 block = get_unaligned_block_q4_0(index);
-    return dequantize_q4_0(block, il);
-}
-
-#include "op_getrows.comp"
diff --git a/ggml/src/kompute-shaders/op_getrows_q4_1.comp b/ggml/src/kompute-shaders/op_getrows_q4_1.comp
deleted file mode 100644 (file)
index 87f2fbe..0000000
+++ /dev/null
@@ -1,39 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define NL 2
-#define BYTES_FOR_TYPE 4 /*bytes for float*/
-#define SIZE_OF_BLOCK sizeof_block_q4_1
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int nb01;
-    int nb1;
-} pcs;
-
-block_q4_1 get_unaligned_block_q4_1(uint index) {
-    block_q4_1 fres;
-    fres.d = u8BufToFloat16(inA, index);
-    fres.m = u8BufToFloat16(inA, index+2);
-    [[unroll]] for (uint it = 0; it != QK4_1 / 2; it++) {
-        fres.qs[it] = inA[index+4+it];
-    }
-    return fres;
-}
-
-mat4 dequantize_block(uint index, uint il) {
-    const block_q4_1 block = get_unaligned_block_q4_1(index);
-    return dequantize_q4_1(block, il);
-}
-
-#include "op_getrows.comp"
diff --git a/ggml/src/kompute-shaders/op_getrows_q6_k.comp b/ggml/src/kompute-shaders/op_getrows_q6_k.comp
deleted file mode 100644 (file)
index 9ce3545..0000000
+++ /dev/null
@@ -1,44 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define NL 16
-#define BYTES_FOR_TYPE 4 /*bytes for float*/
-#define SIZE_OF_BLOCK sizeof_block_q6_k
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int nb01;
-    int nb1;
-} pcs;
-
-block_q6_k get_unaligned_block_q6_k(uint index) {
-    block_q6_k fres;
-    [[unroll]] for (uint it = 0; it != QK_K / 2; it++) {
-        fres.ql[it] = inA[index + it];
-    }
-    [[unroll]] for (uint it = 0; it != QK_K / 4; it++) {
-        fres.qh[it] = inA[index + QK_K/2 + it];
-    }
-    [[unroll]] for (uint it = 0; it != QK_K / 16; it++) {
-        fres.scales[it] = int8_t(inA[index + QK_K/2 + QK_K/4 + it]);
-    }
-    fres.d = u8BufToFloat16(inA, index + QK_K/2 + QK_K/4 + QK_K/16);
-    return fres;
-}
-
-mat4 dequantize_block(uint index, uint il) {
-    const block_q6_k block = get_unaligned_block_q6_k(index);
-    return dequantize_q6_k(block, il);
-}
-
-#include "op_getrows.comp"
diff --git a/ggml/src/kompute-shaders/op_mul.comp b/ggml/src/kompute-shaders/op_mul.comp
deleted file mode 100644 (file)
index c92647c..0000000
+++ /dev/null
@@ -1,52 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1024) in;
-
-layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
-layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
-layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int nb00;
-    int nb01;
-    int nb02;
-    int nb03;
-    int ne10;
-    int ne11;
-    int ne12;
-    int ne13;
-    int nb10;
-    int nb11;
-    int nb12;
-    int nb13;
-    int ne0;
-    int nb0;
-    int nb1;
-    int nb2;
-    int nb3;
-} pcs;
-
-void main() {
-    const uint i03 = gl_WorkGroupID.z;
-    const uint i02 = gl_WorkGroupID.y;
-    const uint i01 = gl_WorkGroupID.x;
-
-    const uint i13 = i03 % pcs.ne13;
-    const uint i12 = i02 % pcs.ne12;
-    const uint i11 = i01 % pcs.ne11;
-
-    uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01) / 4);
-    uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11) / 4);
-    uint dst_off  = uint((i03*pcs.nb3  + i02*pcs.nb2  + i01*pcs.nb1)  / 4);
-
-    for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
-        const uint i10 = i0 % pcs.ne10;
-        out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] * inB[pcs.inBOff + src1_off + i10];
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_mul_mat_f16.comp b/ggml/src/kompute-shaders/op_mul_mat_f16.comp
deleted file mode 100644 (file)
index 8f0a903..0000000
+++ /dev/null
@@ -1,67 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#extension GL_KHR_shader_subgroup_arithmetic : require
-
-layout(local_size_x_id = 0) in;
-
-layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { float inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int ne01;
-    int ne02;
-    uint nb00;
-    uint nb01;
-    uint nb02;
-    int ne10;
-    int ne11;
-    int ne12;
-    uint nb10;
-    uint nb11;
-    uint nb12;
-    int ne0;
-    int ne1;
-    uint r2;
-    uint r3;
-} pcs;
-
-#define N_F16_F32 4
-
-void main() {
-    const uint r0 = gl_WorkGroupID.x;
-    const uint rb = gl_WorkGroupID.y*N_F16_F32;
-    const uint im = gl_WorkGroupID.z;
-
-    const uint i12 = im%pcs.ne12;
-    const uint i13 = im/pcs.ne12;
-
-    const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb02*pcs.ne02;
-
-    const uint x = offset0 / 2 + pcs.inAOff; // Based from inA
-
-    for (uint row = 0; row < N_F16_F32; ++row) {
-        uint r1 = rb + row;
-        if (r1 >= pcs.ne11) {
-            break;
-        }
-
-        const uint y = (r1*pcs.nb11 + im*pcs.nb12) / 4 + pcs.inBOff; // Based from inB
-
-        float sumf = 0;
-        for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
-            sumf += float(inA[x+i]) * float(inB[y+i]);
-        }
-
-        const float all_sum = subgroupAdd(sumf);
-        if (subgroupElect()) {
-            out_[im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0 + pcs.outOff] = all_sum;
-        }
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_mul_mat_mat_f32.comp b/ggml/src/kompute-shaders/op_mul_mat_mat_f32.comp
deleted file mode 100644 (file)
index d1ca4ad..0000000
+++ /dev/null
@@ -1,51 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#extension GL_KHR_shader_subgroup_arithmetic : require
-#extension GL_EXT_debug_printf : enable
-
-// device subgroup size
-layout (local_size_x_id = 0) in;
-
-layout(binding = 0) readonly buffer tensorInA { float inA[]; };
-layout(binding = 1) readonly buffer tensorInB { float inB[]; };
-layout(binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout(push_constant) uniform parameter {
-  uint inAOff;
-  uint inBOff;
-  uint outOff;
-  int ne00;
-  int ne01;
-  int ne02;
-  int ne11;
-  int ne12;
-  uint nb01;
-  uint nb02;
-  uint nb11;
-  uint nb12;
-  uint nb1;
-  uint nb2;
-}
-pcs;
-
-
-void main() {
-  uvec3 gid = gl_WorkGroupID;
-
-  uint bc_ab = pcs.ne12 > pcs.ne02 ? gid.z / (pcs.ne12 / pcs.ne02) : gid.z;
-  uint bc_ba = pcs.ne02 > pcs.ne12 ? gid.z / (pcs.ne02 / pcs.ne12) : gid.z;
-
-  const uint x = (gid.x*pcs.nb01 + bc_ab*pcs.nb02) / 4 + pcs.inAOff; // Based from inA
-  const uint y = (gid.y*pcs.nb11 + bc_ba*pcs.nb12) / 4 + pcs.inBOff; // based from inB
-  float sum = 0.0f;
-  for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
-      sum += float(inA[x+i]) * float(inB[y+i]);
-  }
-
-  const float all_sum = subgroupAdd(sum);
-  if (subgroupElect()) {
-    out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = all_sum;
-  }
-}
diff --git a/ggml/src/kompute-shaders/op_mul_mat_q4_0.comp b/ggml/src/kompute-shaders/op_mul_mat_q4_0.comp
deleted file mode 100644 (file)
index b0cea8b..0000000
+++ /dev/null
@@ -1,33 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define BLOCKS_IN_QUANT QK4_0
-#define SIZE_OF_BLOCK sizeof_block_q4_0
-#define N_ROWS 4
-
-#include "op_mul_mv_q_n_pre.comp"
-
-// The q4_0 version of this function
-float block_q_n_dot_y(uint block_index, uint yb, uint il) {
-    vec2 acc = vec2(0.0, 0.0);
-    const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
-    float d = float(u8BufToFloat16(inA, index));
-    float sumy = 0.0f;
-    for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
-        const uint16_t b = u8BufToU16(inA, index + 2 + il + i);
-
-        const float yl0 = inB[yb + i];
-        const float yl1 = inB[yb + i + 1];
-        const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
-        const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
-
-        sumy += yl0 + yl1 + yl8 + yl9;
-
-        acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
-        acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
-    }
-    return d * (sumy * -8.f + acc[0] + acc[1]);
-}
-
-#include "op_mul_mv_q_n.comp"
diff --git a/ggml/src/kompute-shaders/op_mul_mat_q4_1.comp b/ggml/src/kompute-shaders/op_mul_mat_q4_1.comp
deleted file mode 100644 (file)
index 8582c61..0000000
+++ /dev/null
@@ -1,35 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define BLOCKS_IN_QUANT QK4_1
-#define SIZE_OF_BLOCK sizeof_block_q4_1
-#define N_ROWS 4
-
-#include "op_mul_mv_q_n_pre.comp"
-
-// The q4_1 version of this function
-float block_q_n_dot_y(uint block_index, uint yb, uint il) {
-    vec2 acc = vec2(0.0, 0.0);
-    const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
-    float d = float(u8BufToFloat16(inA, index));
-    float m = float(u8BufToFloat16(inA, index+2));
-
-    float sumy = 0.0f;
-    for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
-        const uint16_t b = u8BufToU16(inA, index + 4 + il + i);
-
-        const float yl0 = inB[yb + i];
-        const float yl1 = inB[yb + i + 1];
-        const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
-        const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
-
-        sumy += yl0 + yl1 + yl8 + yl9;
-
-        acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
-        acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
-    }
-    return d * (acc[0] + acc[1]) + sumy * m;
-}
-
-#include "op_mul_mv_q_n.comp"
diff --git a/ggml/src/kompute-shaders/op_mul_mat_q6_k.comp b/ggml/src/kompute-shaders/op_mul_mat_q6_k.comp
deleted file mode 100644 (file)
index c9baebd..0000000
+++ /dev/null
@@ -1,94 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define SIZE_OF_BLOCK sizeof_block_q6_k
-
-layout(local_size_x_id = 0) in;
-layout(local_size_y_id = 1) in;
-layout(local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { float inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int ne10;
-    int ne0;
-    int ne1;
-    int ne01;
-    int gqa;
-} pcs;
-
-void main() {
-    const uint8_t kmask1 = uint8_t(0x03);
-    const uint8_t kmask2 = uint8_t(0x0C);
-    const uint8_t kmask3 = uint8_t(0x30);
-    const uint8_t kmask4 = uint8_t(0xC0);
-
-    const uint nb = pcs.ne00/QK_K;
-
-    const uint r0 = gl_WorkGroupID.x;
-    const uint r1 = gl_WorkGroupID.y;
-    const uint r2 = gl_WorkGroupID.z;
-
-    const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
-    const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);
-    const uint x = row * nb + offset0; // Based from inA without base offset
-    const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
-
-    float sumf = 0;
-
-    // bits of invocation ID for gl_SubgroupSize=32:
-    //  x   x   x   x   x
-    //  4   3   2   1   0
-    // (     tid     ) ix
-    //  ip (   il    )
-
-    const uint block_stride = gl_SubgroupSize / 16;         // number of blocks each subgroup processes
-    const uint tid  = gl_SubgroupInvocationID/block_stride; // first block_stride groups have tid=0
-    const uint ix   = gl_SubgroupInvocationID%block_stride; // first block is 0..block_stride-1
-    const uint ip   = tid/8;        // first or second half of block (0 or 1)
-    const uint il   = tid%8;        // each half has 8 parts, one per scale
-    const uint n    = 4;            // 4 scales at a time (and 4 sums)
-    const uint l0   = n*il;         // offset into half-block, 0..28
-    const uint is   = 8*ip + l0/16; // 0, 1, 8, 9
-
-    const uint y_offset = 128*ip + l0;
-    const uint q_offset_l = 64*ip + l0;
-    const uint q_offset_h = 32*ip + l0;
-
-    for (uint i = ix; i < nb; i += block_stride) {
-
-        const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff;
-
-        const uint qlIndex = q_offset_l;
-        const uint q2Index = qlIndex + QK_K/8;
-        const uint qhIndex = q_offset_h;
-        const uint y = yy + i * QK_K + y_offset;
-
-        float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f};
-        for (uint l = 0; l < n; ++l) {
-            const uint8_t currentQ1 = inA[baseIndex + qlIndex + l];
-            const uint8_t currentQ2 = inA[baseIndex + q2Index + l];
-            const uint8_t currentQh = inA[baseIndex + QK_K/2 + qhIndex + l];
-
-            sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32);
-            sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32);
-            sums[2] += inB[y+l+64] * (int8_t((currentQ1  >> 4) | ((currentQh & kmask3) << 0)) - 32);
-            sums[3] += inB[y+l+96] * (int8_t((currentQ2  >> 4) | ((currentQh & kmask4) >> 2)) - 32);
-        }
-
-        float d = u8BufToFloat16(inA, baseIndex + QK_K/2 + QK_K/4 + QK_K/16);
-        sumf += d * (sums[0] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + is]) + sums[1] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 2 + is]) + sums[2] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 4 + is]) + sums[3] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 6 + is]));
-    }
-
-    const float tot = subgroupAdd(sumf);
-    if (subgroupElect()) {
-        out_[r1*pcs.ne0 + r2*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_mul_mat_q8_0.comp b/ggml/src/kompute-shaders/op_mul_mat_q8_0.comp
deleted file mode 100644 (file)
index 34d015e..0000000
+++ /dev/null
@@ -1,73 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#include "op_mul_mv_q_n_pre.comp"
-
-#define SIZE_OF_D 2
-
-#define N_DST 4 // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
-#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-
-#define NB_Q8_0 8
-
-void main() {
-    // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
-    if (gl_SubgroupInvocationID > 31)
-        return;
-
-    const int nr  = N_DST;
-    const int nsg = N_SIMDGROUP;
-    const int nw  = N_SIMDWIDTH;
-
-    const int nb = pcs.ne00/QK8_0;
-    const uint r0 = gl_WorkGroupID.x;
-    const uint r1 = gl_WorkGroupID.y;
-    const uint im = gl_WorkGroupID.z;
-
-    const uint first_row = (r0 * nsg + gl_SubgroupID) * nr;
-
-    const uint i12 = im%pcs.ne12;
-    const uint i13 = im/pcs.ne12;
-
-    const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
-
-    const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // Based from inA
-    const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // based from inB
-
-    float yl[NB_Q8_0];
-    float sumf[N_DST]={0.f, 0.f, 0.f, 0.f};
-
-    const uint ix = gl_SubgroupInvocationID.x/4;
-    const uint il = gl_SubgroupInvocationID.x%4;
-
-    uint yb = y + ix * QK8_0 + NB_Q8_0*il;
-
-    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
-    for (uint ib = ix; ib < nb; ib += nw/4) {
-        for (int i = 0; i < NB_Q8_0; ++i) {
-            yl[i] = inB[yb + i];
-        }
-
-        for (int row = 0; row < nr; row++) {
-            const uint block_offset = (ib+row*nb) * sizeof_block_q8_0;
-            float sumq = 0.f;
-            for (int iq = 0; iq < NB_Q8_0; ++iq) {
-                const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]);
-                sumq += qs_iq * yl[iq];
-            }
-            const float16_t d = u8BufToFloat16(inA, x + block_offset);
-            sumf[row] += sumq*d;
-        }
-
-        yb += NB_Q8_0 * nw;
-    }
-
-    for (int row = 0; row < nr; ++row) {
-        const float tot = subgroupAdd(sumf[row]);
-        if (subgroupElect() && first_row + row < pcs.ne01) {
-            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row] = tot;
-        }
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_mul_mv_q_n.comp b/ggml/src/kompute-shaders/op_mul_mv_q_n.comp
deleted file mode 100644 (file)
index 440b5ab..0000000
+++ /dev/null
@@ -1,48 +0,0 @@
-void main() {
-    // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
-    if (gl_SubgroupInvocationID > 31)
-        return;
-
-    const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT);
-
-    const uint r0 = gl_WorkGroupID.x;
-    const uint r1 = gl_WorkGroupID.y;
-    const uint im = gl_WorkGroupID.z;
-
-    const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS;
-
-    const uint i12 = im%pcs.ne12;
-    const uint i13 = im/pcs.ne12;
-
-    const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
-
-    const uint x = offset0; // Based from inA without base offset
-    const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
-
-    float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};
-
-    const uint ix = gl_SubgroupInvocationID/2;
-    const uint il = (BLOCKS_IN_QUANT/4)*(gl_SubgroupInvocationID%2);
-
-    uint yb = y + ix * BLOCKS_IN_QUANT + il;
-
-    //debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n",
-    //    gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize,
-    //    gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z);
-
-    for (uint ib = ix; ib < nb; ib += 16) {
-        for (int row = 0; row < N_ROWS; row++) {
-            const uint block_index = x + ib + row * nb;
-            sumf[row] += block_q_n_dot_y(block_index, yb, il);
-        }
-
-        yb += BLOCKS_IN_QUANT * 16;
-    }
-
-    for (int row = 0; row < N_ROWS; ++row) {
-        const float tot = subgroupAdd(sumf[row]);
-        if (first_row + row < pcs.ne01 && subgroupElect()) {
-            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot;
-        }
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_mul_mv_q_n_pre.comp b/ggml/src/kompute-shaders/op_mul_mv_q_n_pre.comp
deleted file mode 100644 (file)
index 7912b09..0000000
+++ /dev/null
@@ -1,22 +0,0 @@
-layout(local_size_x_id = 0) in;
-layout(local_size_y = 1) in;
-layout(local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { float inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int  ne00;
-    int  ne01;
-    int  ne02;
-    int  ne10;
-    int  ne12;
-    int  ne0;
-    int  ne1;
-    uint r2;
-    uint r3;
-} pcs;
diff --git a/ggml/src/kompute-shaders/op_norm.comp b/ggml/src/kompute-shaders/op_norm.comp
deleted file mode 100644 (file)
index ad0c3c0..0000000
+++ /dev/null
@@ -1,84 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 256) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-    uint ne00;
-    uint nb01;
-    float eps;
-} pcs;
-
-shared float sum[gl_WorkGroupSize.x];
-
-void main() {
-    const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
-    // MEAN
-    // parallel sum
-    sum[gl_LocalInvocationID.x] = 0.0;
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        sum[gl_LocalInvocationID.x] += in_[x+i00];
-    }
-
-    // reduce
-    barrier();
-    memoryBarrierShared();
-    [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
-        if (gl_LocalInvocationID.x < i) {
-            sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
-        }
-        barrier();
-        memoryBarrierShared();
-    }
-
-    // broadcast
-    if (gl_LocalInvocationID.x == 0) {
-        sum[0] /= float(pcs.ne00);
-    }
-    barrier();
-    memoryBarrierShared();
-    const float mean = sum[0];
-
-    // recenter
-    const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        out_[y+i00] = in_[x+i00] - mean;
-    }
-
-    // VARIANCE
-    // parallel sum
-    sum[gl_LocalInvocationID.x] = 0.0;
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        sum[gl_LocalInvocationID.x] += out_[y+i00] * out_[y+i00];
-    }
-
-    // reduce
-    barrier();
-    memoryBarrierShared();
-    [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
-        if (gl_LocalInvocationID.x < i) {
-            sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
-        }
-        barrier();
-        memoryBarrierShared();
-    }
-
-    // broadcast
-    if (gl_LocalInvocationID.x == 0) {
-        sum[0] /= float(pcs.ne00);
-    }
-    barrier();
-    memoryBarrierShared();
-    const float variance = sum[0];
-
-    const float scale = 1.0f/sqrt(variance + pcs.eps);
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        out_[y+i00] *= scale;
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_relu.comp b/ggml/src/kompute-shaders/op_relu.comp
deleted file mode 100644 (file)
index 52a601f..0000000
+++ /dev/null
@@ -1,21 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-} pcs;
-
-void main() {
-    const uint baseIndex = gl_WorkGroupID.x * 4;
-
-    for (uint x = 0; x < 4; x++) {
-        const uint i = baseIndex + x;
-        out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_rmsnorm.comp b/ggml/src/kompute-shaders/op_rmsnorm.comp
deleted file mode 100644 (file)
index da658c1..0000000
+++ /dev/null
@@ -1,53 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 512) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-    uint ne00;
-    uint nb01;
-    float eps;
-} pcs;
-
-shared float sum[gl_WorkGroupSize.x];
-
-void main() {
-    const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
-
-    // parallel sum
-    sum[gl_LocalInvocationID.x] = 0.0;
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        sum[gl_LocalInvocationID.x] += in_[x+i00] * in_[x+i00];
-    }
-
-    // reduce
-    barrier();
-    memoryBarrierShared();
-    [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
-        if (gl_LocalInvocationID.x < i) {
-            sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
-        }
-        barrier();
-        memoryBarrierShared();
-    }
-
-    // broadcast
-    if (gl_LocalInvocationID.x == 0) {
-        sum[0] /= float(pcs.ne00);
-    }
-    barrier();
-    memoryBarrierShared();
-
-    const float scale = 1.0f/sqrt(sum[0] + pcs.eps);
-
-    const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        out_[y+i00] = in_[x+i00] * scale;
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_rope_f16.comp b/ggml/src/kompute-shaders/op_rope_f16.comp
deleted file mode 100644 (file)
index 1a4058b..0000000
+++ /dev/null
@@ -1,73 +0,0 @@
-#version 450
-
-#include "rope_common.comp"
-
-layout(binding = 0) buffer restrict readonly  tensorInA { float16_t inA[]; };
-layout(binding = 1) buffer restrict readonly  tensorInB { int       inB[]; };
-layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; };
-
-void main() {
-    const uint i3 = gl_WorkGroupID.z;
-    const uint i2 = gl_WorkGroupID.y;
-    const uint i1 = gl_WorkGroupID.x;
-
-    const bool is_neox = (pcs.mode & 2) != 0;
-
-    float corr_dims[2];
-    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
-
-    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
-
-    const int p = inB[pcs.inBOff + i2];
-
-    float theta = float(p);
-
-    if (!is_neox) {
-        for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
-            float cos_theta, sin_theta;
-            rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-            theta *= theta_scale;
-
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
-
-            const float x0 = float(inA[src]);
-            const float x1 = float(inA[src+1]);
-
-            out_[dst_data]   = float16_t(x0*cos_theta - x1*sin_theta);
-            out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
-        }
-    } else {
-        const float inv_ndims = -1.f/pcs.n_dims;
-        for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
-            const uint cur_rot = ic;
-
-            float cos_theta, sin_theta;
-            rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-            theta *= theta_scale;
-
-            const uint i0 = ic/2;
-
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
-
-            const float x0 = float(inA[src]);
-            const float x1 = float(inA[src+pcs.n_dims/2]);
-
-            out_[dst_data]              = float16_t(x0*cos_theta - x1*sin_theta);
-            out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
-        }
-
-        for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
-            const uint i0 = ic;
-
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
-
-            out_[dst_data + 0] = inA[src + 0];
-            out_[dst_data + 1] = inA[src + 1];
-        }
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_rope_f32.comp b/ggml/src/kompute-shaders/op_rope_f32.comp
deleted file mode 100644 (file)
index 65e0382..0000000
+++ /dev/null
@@ -1,73 +0,0 @@
-#version 450
-
-#include "rope_common.comp"
-
-layout(binding = 0) buffer restrict readonly  tensorInA { float inA[]; };
-layout(binding = 1) buffer restrict readonly  tensorInB { int   inB[]; };
-layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
-void main() {
-    const uint i3 = gl_WorkGroupID.z;
-    const uint i2 = gl_WorkGroupID.y;
-    const uint i1 = gl_WorkGroupID.x;
-
-    const bool is_neox = (pcs.mode & 2) != 0;
-
-    float corr_dims[2];
-    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
-
-    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
-
-    const int p = inB[pcs.inBOff + i2];
-
-    float theta = float(p);
-
-    if (!is_neox) {
-        for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
-            float cos_theta, sin_theta;
-            rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-            theta *= theta_scale;
-
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
-
-            const float x0 = inA[src];
-            const float x1 = inA[src+1];
-
-            out_[dst_data]   = x0*cos_theta - x1*sin_theta;
-            out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
-        }
-    } else {
-        const float inv_ndims = -1.f/pcs.n_dims;
-        for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
-            const uint cur_rot = ic;
-
-            float cos_theta, sin_theta;
-            rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-            theta *= theta_scale;
-
-            const uint i0 = ic/2;
-
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
-
-            const float x0 = inA[src];
-            const float x1 = inA[src+pcs.n_dims/2];
-
-            out_[dst_data] = x0*cos_theta - x1*sin_theta;
-            out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
-        }
-
-        for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
-            const uint i0 = ic;
-
-            const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
-
-            out_[dst_data + 0] = inA[src + 0];
-            out_[dst_data + 1] = inA[src + 1];
-        }
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_scale.comp b/ggml/src/kompute-shaders/op_scale.comp
deleted file mode 100644 (file)
index bdae267..0000000
+++ /dev/null
@@ -1,19 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-    float scale;
-} pcs;
-
-void main() {
-    const uint i = gl_WorkGroupID.x;
-    out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
-}
diff --git a/ggml/src/kompute-shaders/op_scale_8.comp b/ggml/src/kompute-shaders/op_scale_8.comp
deleted file mode 100644 (file)
index ada6975..0000000
+++ /dev/null
@@ -1,23 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-    float scale;
-} pcs;
-
-void main() {
-    const uint baseIndex = gl_WorkGroupID.x * 8;
-
-    for (uint x = 0; x < 8; x++) {
-        const uint i = baseIndex + x;
-        out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_silu.comp b/ggml/src/kompute-shaders/op_silu.comp
deleted file mode 100644 (file)
index 0fb8e4b..0000000
+++ /dev/null
@@ -1,22 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-} pcs;
-
-void main() {
-    const uint baseIndex = gl_WorkGroupID.x * 4;
-
-    for (uint x = 0; x < 4; x++) {
-        const uint i = baseIndex + x;
-        const float y = in_[i + pcs.inOff];
-        out_[i + pcs.outOff] = y / (1.0 + exp(-y));
-    }
-}
diff --git a/ggml/src/kompute-shaders/op_softmax.comp b/ggml/src/kompute-shaders/op_softmax.comp
deleted file mode 100644 (file)
index 7bc9176..0000000
+++ /dev/null
@@ -1,56 +0,0 @@
-// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4)
-
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x_id = 0) in;
-
-layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
-layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
-layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int ne01;
-    int ne02;
-    float scale;
-    int mask;
-} pcs;
-
-void main() {
-    if (gl_SubgroupInvocationID > 31)
-        return;
-
-    const uint i03 = gl_WorkGroupID.z;
-    const uint i02 = gl_WorkGroupID.y;
-    const uint i01 = gl_WorkGroupID.x;
-
-    const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
-    const uint psrc0 = extra_off + pcs.inAOff; // Based from inA
-    const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
-    const uint pdst = extra_off + pcs.outOff; // Based from out_
-
-    // parallel max
-    float localMax = uintBitsToFloat(0xFF800000);
-    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
-        localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f));
-    }
-    float max_ = subgroupMax(localMax);
-
-    // parallel sum
-    float localSum = 0.0f;
-    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
-        const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_);
-        localSum += exp_psrc0;
-        out_[pdst + i00] = exp_psrc0;
-    }
-
-    const float sum = subgroupAdd(localSum);
-    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
-        out_[pdst + i00] /= sum;
-    }
-}
diff --git a/ggml/src/kompute-shaders/rope_common.comp b/ggml/src/kompute-shaders/rope_common.comp
deleted file mode 100644 (file)
index 7b9394c..0000000
+++ /dev/null
@@ -1,67 +0,0 @@
-#include "common.comp"
-
-// TODO: use a local size of 32 or more (Metal uses 1024)
-layout(local_size_x = 1) in;
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int n_dims;
-    int mode;
-    int n_ctx_orig;
-    float freq_base;
-    float freq_scale;
-    float ext_factor;
-    float attn_factor;
-    float beta_fast;
-    float beta_slow;
-    uint nb00;
-    uint nb01;
-    uint nb02;
-    uint nb03;
-    int ne0;
-    uint nb0;
-    uint nb1;
-    uint nb2;
-    uint nb3;
-} pcs;
-
-float rope_yarn_ramp(const float low, const float high, const float i0) {
-    const float y = (i0 / 2 - low) / max(0.001f, high - low);
-    return 1.0f - min(1.0f, max(0.0f, y));
-}
-
-// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
-// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-void rope_yarn(
-    float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
-    out float cos_theta, out float sin_theta
-) {
-    // Get n-d rotational scaling corrected for extrapolation
-    float theta_interp = freq_scale * theta_extrap;
-    float theta = theta_interp;
-    if (ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
-        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
-        // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
-    }
-    cos_theta = cos(theta) * mscale;
-    sin_theta = sin(theta) * mscale;
-}
-
-// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
-// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
-    return n_dims * log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * log(base));
-}
-
-void rope_yarn_corr_dims(
-    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
-) {
-    // start and end correction dims
-    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
-    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
-}
diff --git a/ggml/src/sgemm.cpp b/ggml/src/sgemm.cpp
deleted file mode 100644 (file)
index 6626ceb..0000000
+++ /dev/null
@@ -1,1027 +0,0 @@
-// Copyright 2024 Mozilla Foundation
-//
-// Permission is hereby granted, free of charge, to any person obtaining
-// a copy of this software and associated documentation files (the
-// "Software"), to deal in the Software without restriction, including
-// without limitation the rights to use, copy, modify, merge, publish,
-// distribute, sublicense, and/or sell copies of the Software, and to
-// permit persons to whom the Software is furnished to do so, subject to
-// the following conditions:
-//
-// The above copyright notice and this permission notice shall be
-// included in all copies or substantial portions of the Software.
-//
-// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
-// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
-// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
-// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-// SOFTWARE.
-
-//
-//                   _   _          ___ _      _   ___
-//                  | |_(_)_ _ _  _| _ ) |    /_\ / __|
-//                  |  _| | ' \ || | _ \ |__ / _ \\__ \.
-//                   \__|_|_||_\_, |___/____/_/ \_\___/
-//                             |__/
-//
-//                    BASIC LINEAR ALGEBRA SUBPROGRAMS
-//
-//
-// This file implements multithreaded CPU matrix multiplication for the
-// common contiguous use case C = Aᵀ * B. These kernels are designed to
-// have excellent performance[1] for matrices that fit in the CPU cache
-// without imposing any overhead such as cache filling or malloc calls.
-//
-// This implementation does not guarantee any upper bound with rounding
-// errors, which grow along with k. Our goal's to maximally exploit the
-// hardware for performance, and then use whatever resources remain for
-// improving numerical accuracy.
-//
-// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
-//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
-
-#if defined(__GNUC__)
-#pragma GCC diagnostic ignored "-Wpedantic"
-#pragma GCC diagnostic ignored "-Wignored-attributes"
-#endif
-
-#include "sgemm.h"
-#include "ggml-impl.h"
-#include "ggml-quants.h"
-
-#ifdef _MSC_VER
-#define NOINLINE __declspec(noinline)
-#else
-#define NOINLINE __attribute__((__noinline__))
-#endif
-
-#if defined(__ARM_NEON) || defined(__AVX512F__)
-#define VECTOR_REGISTERS 32
-#else
-#define VECTOR_REGISTERS 16
-#endif
-
-#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
-
-namespace {
-
-inline float unhalf(ggml_fp16_t d) {
-    return GGML_FP16_TO_FP32(d);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// VECTORIZED ARITHMETIC OPERATIONS
-
-#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
-inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
-inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
-#endif  // __SSE__
-
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
-inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
-inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
-#endif // __AVX__
-
-#if defined(__AVX512F__)
-inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
-inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
-inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
-#endif // __AVX512F__
-
-#if defined(__ARM_NEON)
-inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
-inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
-inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
-#endif // __ARM_NEON
-
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
-inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
-inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
-inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// VECTORIZED FUSED MULTIPLY ADD
-
-/**
- * Computes a * b + c.
- */
-template <typename T, typename U>
-inline U madd(T a, T b, U c) {
-    return add(mul(a, b), c);
-}
-
-#if defined(__FMA__)
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-template <>
-inline __m256 madd(__m256 a, __m256 b, __m256 c) {
-    return _mm256_fmadd_ps(a, b, c);
-}
-#endif
-#if defined(__AVX512F__)
-template <>
-inline __m512 madd(__m512 a, __m512 b, __m512 c) {
-    return _mm512_fmadd_ps(a, b, c);
-}
-#endif
-#endif
-
-#if defined(__ARM_FEATURE_FMA)
-template <>
-inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
-    return vfmaq_f32(c, b, a);
-}
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
-template <>
-inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
-    return vfmaq_f16(c, b, a);
-}
-#endif
-#endif
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// VECTORIZED HORIZONTAL SUM
-
-#if defined(__ARM_NEON)
-inline float hsum(float32x4_t x) {
-    return vaddvq_f32(x);
-}
-#endif // __ARM_NEON
-
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
-inline float hsum(float16x8_t x) {
-    return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
-                                vcvt_f32_f16(vget_high_f16(x))));
-}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
-#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-inline float hsum(__m128 x) {
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-    x = _mm_add_ps(x, _mm_movehl_ps(x, x));
-    x = _mm_add_ss(x, _mm_movehdup_ps(x));
-#else
-    __m128 t;
-    t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
-    x = _mm_add_ps(x, t);
-    t = _mm_movehl_ps(t, x);
-    x = _mm_add_ss(x, t);
-#endif
-    return _mm_cvtss_f32(x);
-}
-#endif
-
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-inline float hsum(__m256 x) {
-    return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
-                           _mm256_castps256_ps128(x)));
-}
-#endif // __AVX__
-
-#if defined(__AVX512F__)
-inline float hsum(__m512 x) {
-    return _mm512_reduce_add_ps(x);
-}
-#endif // __AVX512F__
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// VECTORIZED MEMORY LOADING
-
-template <typename T, typename U> T load(const U *);
-
-#if defined(__ARM_NEON)
-template <> inline float32x4_t load(const float *p) {
-    return vld1q_f32(p);
-}
-#if !defined(_MSC_VER)
-template <> inline float16x8_t load(const ggml_fp16_t *p) {
-    return vld1q_f16((const float16_t *)p);
-}
-template <> inline float32x4_t load(const ggml_fp16_t *p) {
-    return vcvt_f32_f16(vld1_f16((const float16_t *)p));
-}
-#endif // _MSC_VER
-#endif // __ARM_NEON
-
-#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-template <> inline __m128 load(const float *p) {
-    return _mm_loadu_ps(p);
-}
-#endif  // __SSE__
-
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-template <> inline __m256 load(const float *p) {
-    return _mm256_loadu_ps(p);
-}
-#endif // __AVX__
-
-#if defined(__F16C__)
-template <> inline __m256 load(const ggml_fp16_t *p) {
-    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
-}
-#endif // __F16C__
-
-#if defined(__AVX512F__)
-template <> inline __m512 load(const float *p) {
-    return _mm512_loadu_ps(p);
-}
-template <> inline __m512 load(const ggml_fp16_t *p) {
-    return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
-}
-#endif // __AVX512F__
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// FLOATING POINT MATRIX MULTIPLICATION
-
-template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
-class tinyBLAS {
-  public:
-    tinyBLAS(int64_t k,
-             const TA *A, int64_t lda,
-             const TB *B, int64_t ldb,
-             TC *C, int64_t ldc,
-             int ith, int nth)
-        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
-    }
-
-    void matmul(int64_t m, int64_t n) {
-        mnpack(0, m, 0, n);
-    }
-
-  private:
-    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t mc, nc, mp, np;
-        switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
-#if VECTOR_REGISTERS == 32
-        case 0x55:
-            mc = 5;
-            nc = 5;
-            gemm<5, 5>(m0, m, n0, n);
-            break;
-        case 0x45:
-            mc = 4;
-            nc = 5;
-            gemm<4, 5>(m0, m, n0, n);
-            break;
-        case 0x54:
-            mc = 5;
-            nc = 4;
-            gemm<5, 4>(m0, m, n0, n);
-            break;
-        case 0x44:
-            mc = 4;
-            nc = 4;
-            gemm<4, 4>(m0, m, n0, n);
-            break;
-        case 0x53:
-            mc = 5;
-            nc = 3;
-            gemm<5, 3>(m0, m, n0, n);
-            break;
-        case 0x35:
-            mc = 3;
-            nc = 5;
-            gemm<3, 5>(m0, m, n0, n);
-            break;
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-#else
-        case 0x55:
-        case 0x54:
-        case 0x53:
-        case 0x45:
-        case 0x44:
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-        case 0x35:
-#endif
-        case 0x34:
-            mc = 3;
-            nc = 4;
-            gemm<3, 4>(m0, m, n0, n);
-            break;
-        case 0x52:
-            mc = 5;
-            nc = 2;
-            gemm<5, 2>(m0, m, n0, n);
-            break;
-        case 0x33:
-            mc = 3;
-            nc = 3;
-            gemm<3, 3>(m0, m, n0, n);
-            break;
-        case 0x25:
-            mc = 2;
-            nc = 5;
-            gemm<2, 5>(m0, m, n0, n);
-            break;
-        case 0x42:
-            mc = 4;
-            nc = 2;
-            gemm<4, 2>(m0, m, n0, n);
-            break;
-        case 0x24:
-            mc = 2;
-            nc = 4;
-            gemm<2, 4>(m0, m, n0, n);
-            break;
-        case 0x32:
-            mc = 3;
-            nc = 2;
-            gemm<3, 2>(m0, m, n0, n);
-            break;
-        case 0x23:
-            mc = 2;
-            nc = 3;
-            gemm<2, 3>(m0, m, n0, n);
-            break;
-        case 0x51:
-            mc = 5;
-            nc = 1;
-            gemm<5, 1>(m0, m, n0, n);
-            break;
-        case 0x41:
-            mc = 4;
-            nc = 1;
-            gemm<4, 1>(m0, m, n0, n);
-            break;
-        case 0x22:
-            mc = 2;
-            nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
-            break;
-        case 0x15:
-            mc = 1;
-            nc = 5;
-            gemm<1, 5>(m0, m, n0, n);
-            break;
-        case 0x14:
-            mc = 1;
-            nc = 4;
-            gemm<1, 4>(m0, m, n0, n);
-            break;
-        case 0x31:
-            mc = 3;
-            nc = 1;
-            gemm<3, 1>(m0, m, n0, n);
-            break;
-        case 0x13:
-            mc = 1;
-            nc = 3;
-            gemm<1, 3>(m0, m, n0, n);
-            break;
-        case 0x21:
-            mc = 2;
-            nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
-            break;
-        case 0x12:
-            mc = 1;
-            nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
-            break;
-        case 0x11:
-            mc = 1;
-            nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
-            break;
-        default:
-            return;
-        }
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
-        mnpack(mp, m, n0, np);
-        mnpack(m0, m, np, n);
-    }
-
-    template <int RM, int RN>
-    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t ytiles = (m - m0) / RM;
-        int64_t xtiles = (n - n0) / RN;
-        int64_t tiles = xtiles * ytiles;
-        int64_t duty = (tiles + nth - 1) / nth;
-        int64_t start = duty * ith;
-        int64_t end = start + duty;
-        if (end > tiles)
-            end = tiles;
-        for (int64_t job = start; job < end; ++job) {
-            int64_t ii = m0 + job / xtiles * RM;
-            int64_t jj = n0 + job % xtiles * RN;
-            D Cv[RN][RM] = {};
-            for (int64_t l = 0; l < k; l += KN)
-                for (int64_t j = 0; j < RN; ++j)
-                    for (int64_t i = 0; i < RM; ++i)
-                        Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
-                                        load<V>(B + ldb * (jj + j) + l),
-                                        Cv[j][i]);
-            for (int64_t j = 0; j < RN; ++j)
-                for (int64_t i = 0; i < RM; ++i)
-                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
-        }
-    }
-
-    const TA *const A;
-    const TB *const B;
-    TC *const C;
-    const int64_t k;
-    const int64_t lda;
-    const int64_t ldb;
-    const int64_t ldc;
-    const int ith;
-    const int nth;
-};
-
-//////////////////////////////////////////////////////////////////////////////////////////
-// QUANT ZERO MATRIX MULTIPLICATION
-
-#if defined(__ARM_FEATURE_DOTPROD)
-template <typename TA>
-class tinyBLAS_Q0_ARM {
-  public:
-    tinyBLAS_Q0_ARM(int64_t k,
-                    const TA *A, int64_t lda,
-                    const block_q8_0 *B, int64_t ldb,
-                    float *C, int64_t ldc,
-                    int ith, int nth)
-        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
-    }
-
-    void matmul(int64_t m, int64_t n) {
-        mnpack(0, m, 0, n);
-    }
-
-  private:
-    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t mc, nc, mp, np;
-        switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
-        case 0x33:
-            mc = 3;
-            nc = 3;
-            gemm<3, 3>(m0, m, n0, n);
-            break;
-        case 0x32:
-            mc = 3;
-            nc = 2;
-            gemm<3, 2>(m0, m, n0, n);
-            break;
-        case 0x23:
-            mc = 2;
-            nc = 3;
-            gemm<2, 3>(m0, m, n0, n);
-            break;
-        case 0x22:
-            mc = 2;
-            nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
-            break;
-        case 0x31:
-            mc = 3;
-            nc = 1;
-            gemm<3, 1>(m0, m, n0, n);
-            break;
-        case 0x13:
-            mc = 1;
-            nc = 3;
-            gemm<1, 3>(m0, m, n0, n);
-            break;
-        case 0x21:
-            mc = 2;
-            nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
-            break;
-        case 0x12:
-            mc = 1;
-            nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
-            break;
-        case 0x11:
-            mc = 1;
-            nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
-            break;
-        default:
-            return;
-        }
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
-        mnpack(mp, m, n0, np);
-        mnpack(m0, m, np, n);
-    }
-
-    template <int RM, int RN>
-    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t ytiles = (m - m0) / RM;
-        int64_t xtiles = (n - n0) / RN;
-        int64_t tiles = xtiles * ytiles;
-        int64_t duty = (tiles + nth - 1) / nth;
-        int64_t start = duty * ith;
-        int64_t end = start + duty;
-        if (end > tiles)
-            end = tiles;
-        for (int64_t job = start; job < end; ++job) {
-            int64_t ii = m0 + job / xtiles * RM;
-            int64_t jj = n0 + job % xtiles * RN;
-            float32x4_t Cv[RN][RM] = {};
-            for (int64_t l = 0; l < k; ++l)
-                for (int64_t j = 0; j < RN; ++j)
-                    for (int64_t i = 0; i < RM; ++i)
-                        Cv[j][i] = vmlaq_n_f32(Cv[j][i],
-                                               vcvtq_f32_s32(vdotq_s32(
-                                                   vdotq_s32(vdupq_n_s32(0),
-                                                             load_lo(A + lda * (ii + i) + l),
-                                                             load_lo(B + ldb * (jj + j) + l)),
-                                                   load_hi(A + lda * (ii + i) + l),
-                                                   load_hi(B + ldb * (jj + j) + l))),
-                                               unhalf(A[lda * (ii + i) + l].d) *
-                                               unhalf(B[ldb * (jj + j) + l].d));
-            for (int64_t j = 0; j < RN; ++j)
-                for (int64_t i = 0; i < RM; ++i)
-                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
-        }
-    }
-
-    inline int8x16_t load_lo(const block_q8_0 *b) {
-        return vld1q_s8(b->qs);
-    }
-
-    inline int8x16_t load_hi(const block_q8_0 *b) {
-        return vld1q_s8(b->qs + 16);
-    }
-
-    inline int8x16_t load_lo(const block_q4_0 *b) {
-        return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
-                                                     vdupq_n_u8(0x0f))),
-                        vdupq_n_s8(0x8));
-    }
-
-    inline int8x16_t load_hi(const block_q4_0 *b) {
-        return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
-                        vdupq_n_s8(0x8));
-    }
-
-    const TA *const A;
-    const block_q8_0 *const B;
-    float *const C;
-    const int64_t k;
-    const int64_t lda;
-    const int64_t ldb;
-    const int64_t ldc;
-    const int ith;
-    const int nth;
-};
-#endif // __ARM_FEATURE_DOTPROD
-
-#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
-template <typename TA, typename TB, typename TC>
-class tinyBLAS_Q0_AVX {
-  public:
-    tinyBLAS_Q0_AVX(int64_t k,
-                    const TA *A, int64_t lda,
-                    const TB *B, int64_t ldb,
-                    TC *C, int64_t ldc,
-                    int ith, int nth)
-        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
-    }
-
-    void matmul(int64_t m, int64_t n) {
-        mnpack(0, m, 0, n);
-    }
-
-  private:
-    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t mc, nc, mp, np;
-        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
-#if VECTOR_REGISTERS == 32
-        case 0x44:
-            mc = 4;
-            nc = 4;
-            gemm<4, 4>(m0, m, n0, n);
-            break;
-        case 0x43:
-            mc = 4;
-            nc = 3;
-            gemm<4, 3>(m0, m, n0, n);
-            break;
-        case 0x34:
-            mc = 3;
-            nc = 4;
-            gemm<3, 4>(m0, m, n0, n);
-            break;
-        case 0x33:
-            mc = 3;
-            nc = 3;
-            gemm<3, 3>(m0, m, n0, n);
-            break;
-        case 0x42:
-            mc = 4;
-            nc = 2;
-            gemm<4, 2>(m0, m, n0, n);
-            break;
-        case 0x24:
-            mc = 2;
-            nc = 4;
-            gemm<2, 4>(m0, m, n0, n);
-            break;
-#else
-        case 0x44:
-        case 0x43:
-        case 0x42:
-            mc = 4;
-            nc = 2;
-            gemm<4, 2>(m0, m, n0, n);
-            break;
-        case 0x34:
-        case 0x24:
-            mc = 2;
-            nc = 4;
-            gemm<2, 4>(m0, m, n0, n);
-            break;
-        case 0x33:
-#endif
-        case 0x32:
-            mc = 3;
-            nc = 2;
-            gemm<3, 2>(m0, m, n0, n);
-            break;
-        case 0x23:
-            mc = 2;
-            nc = 3;
-            gemm<2, 3>(m0, m, n0, n);
-            break;
-        case 0x41:
-            mc = 4;
-            nc = 1;
-            gemm<4, 1>(m0, m, n0, n);
-            break;
-        case 0x22:
-            mc = 2;
-            nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
-            break;
-        case 0x14:
-            mc = 1;
-            nc = 4;
-            gemm<1, 4>(m0, m, n0, n);
-            break;
-        case 0x31:
-            mc = 3;
-            nc = 1;
-            gemm<3, 1>(m0, m, n0, n);
-            break;
-        case 0x13:
-            mc = 1;
-            nc = 3;
-            gemm<1, 3>(m0, m, n0, n);
-            break;
-        case 0x21:
-            mc = 2;
-            nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
-            break;
-        case 0x12:
-            mc = 1;
-            nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
-            break;
-        case 0x11:
-            mc = 1;
-            nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
-            break;
-        default:
-            return;
-        }
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
-        mnpack(mp, m, n0, np);
-        mnpack(m0, m, np, n);
-    }
-
-    template <int RM, int RN>
-    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t ytiles = (m - m0) / RM;
-        int64_t xtiles = (n - n0) / RN;
-        int64_t tiles = xtiles * ytiles;
-        int64_t duty = (tiles + nth - 1) / nth;
-        int64_t start = duty * ith;
-        int64_t end = start + duty;
-        if (end > tiles)
-            end = tiles;
-        for (int64_t job = start; job < end; ++job) {
-            int64_t ii = m0 + job / xtiles * RM;
-            int64_t jj = n0 + job % xtiles * RN;
-            __m256 Cv[RN][RM] = {};
-            for (int64_t l = 0; l < k; ++l)
-                for (int64_t j = 0; j < RN; ++j)
-                    for (int64_t i = 0; i < RM; ++i) {
-#if defined(__AVX2__)
-                        __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
-                                                              load(A + lda * (ii + i) + l)),
-                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
-                                                              load(A + lda * (ii + i) + l)));
-#else
-                        __m128i ali0 = load0(A + lda * (ii + i) + l);
-                        __m128i ali1 = load1(A + lda * (ii + i) + l);
-                        __m128i blj0 = load0(B + ldb * (jj + j) + l);
-                        __m128i blj1 = load1(B + ldb * (jj + j) + l);
-
-                        __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
-                        __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
-                        __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
-                        __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
-
-                        // updot
-                        const __m128i oneFill = _mm_set1_epi16(1);
-                        __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
-                        __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
-                        __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
-#endif
-                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
-                                                       unhalf(B[ldb * (jj + j) + l].d)),
-                                                       udTmp,
-                                                       Cv[j][i]);
-                    }
-            for (int64_t j = 0; j < RN; ++j)
-                for (int64_t i = 0; i < RM; ++i)
-                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
-        }
-    }
-
-    inline __m256i load(const block_q8_0 *b) {
-        return _mm256_loadu_si256((const __m256i *)b->qs);
-    }
-
-    inline __m128i load0(const block_q8_0 *b) {
-        return _mm_loadu_si128((const __m128i *)b->qs);
-    }
-
-    inline __m128i load1(const block_q8_0 *b) {
-        return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
-    }
-
-    inline __m256i load(const block_q4_0 *b) {
-        return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
-    }
-
-    inline __m128i load0(const block_q4_0 *b) {
-        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
-        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
-    }
-
-    inline __m128i load1(const block_q4_0 *b) {
-        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
-        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
-    }
-
-    inline __m256 updot(__m256i u, __m256i s) {
-        __m256i res;
-#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
-        res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
-#else
-        res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
-#endif
-        return _mm256_cvtepi32_ps(res);
-    }
-
-    static inline __m256i denibble(const uint8_t *p) {
-        __m128i x = _mm_loadu_si128((const __m128i *)p);
-        return _mm256_and_si256(_mm256_set1_epi8(15),
-                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),
-                                                        _mm_srli_epi16(x, 4), 1));
-    }
-
-    const TA *const A;
-    const TB *const B;
-    TC *const C;
-    const int64_t k;
-    const int64_t lda;
-    const int64_t ldb;
-    const int64_t ldc;
-    const int ith;
-    const int nth;
-};
-#endif // __AVX__
-
-} // namespace
-
-/**
- * Performs optimized matrix multiplication on CPU.
- *
- * This subroutine may compute C = Aᵀ * B with column major ordering.
- * Despite its name, this isn't a generalized implementation. Work is
- * only performed when a handwritten kernel is written and available.
- * Otherwise the caller should fall back to a general matmul routine.
- *
- * For example, for single-threaded single-precision GEMM you can say
- *
- *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
- *                     0, 1,
- *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
- *
- * @param m is rows in `A` and `C`
- * @param n is cols in `B` and `C`
- * @param k is cols in `A` and rows in `B`
- * @param A is first input matrix (always transposed)
- * @param lda is row stride of `A`
- * @param B is second input matrix (never transposed)
- * @param ldb is row stride of `B`
- * @param C is input/output array of output matrices
- * @param ldc is row stride of `C`
- * @param ith is thread id (must be less than `nth`)
- * @param nth is number of threads (must be greater than zero)
- * @param Atype is GGML data type of `A`
- * @param Btype is GGML data type of `B`
- * @param Ctype is GGML data type of `C`
- * @return true if this function was able to service the matmul request
- */
-bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
-                     int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
-
-    assert(m >= 0);
-    assert(n >= 0);
-    assert(k >= 0);
-    assert(lda >= k);
-    assert(ldb >= k);
-    assert(ldc >= m);
-    assert(nth > 0);
-    assert(ith < nth);
-
-    if (Ctype != GGML_TYPE_F32)
-        return false;
-
-    switch (Atype) {
-
-    case GGML_TYPE_F32: {
-        if (Btype != GGML_TYPE_F32)
-            return false;
-#if defined(__AVX512F__)
-        if (k % 16)
-            return false;
-        tinyBLAS<16, __m512, __m512, float, float, float> tb{
-            k, (const float *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#elif defined(__AVX__) || defined(__AVX2__)
-        if (k % 8)
-            return false;
-        tinyBLAS<8, __m256, __m256, float, float, float> tb{
-            k, (const float *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#elif defined(__ARM_NEON)
-        if (n < 4)
-            return false;
-        if (k % 4)
-            return false;
-        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
-            k, (const float *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#else
-        return false;
-#endif
-    }
-
-    case GGML_TYPE_F16: {
-#if defined(__AVX512F__)
-        if (k % 16)
-            return false;
-        if (Btype != GGML_TYPE_F32)
-            return false;
-        tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
-            k, (const ggml_fp16_t *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
-        if (k % 8)
-            return false;
-        if (Btype != GGML_TYPE_F32)
-            return false;
-        tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
-            k, (const ggml_fp16_t *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
-        if (n < 8)
-            return false;
-        if (k % 8)
-            return false;
-        if (Btype != GGML_TYPE_F16)
-            return false;
-        tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
-            k, (const ggml_fp16_t *)A, lda,
-            (const ggml_fp16_t *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#elif defined(__ARM_NEON) && !defined(_MSC_VER)
-        if (k % 4)
-            return false;
-        if (Btype != GGML_TYPE_F32)
-            return false;
-        tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
-            k, (const ggml_fp16_t *)A, lda,
-            (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#else
-        return false;
-#endif
-    }
-
-    case GGML_TYPE_Q8_0: {
-        if (Btype != GGML_TYPE_Q8_0)
-           return false;
-#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
-        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
-            k, (const block_q8_0 *)A, lda,
-            (const block_q8_0 *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#elif defined(__ARM_FEATURE_DOTPROD)
-        tinyBLAS_Q0_ARM<block_q8_0> tb{
-            k, (const block_q8_0 *)A, lda,
-            (const block_q8_0 *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#else
-        return false;
-#endif
-    }
-
-    case GGML_TYPE_Q4_0: {
-        if (Btype != GGML_TYPE_Q8_0)
-            return false;
-#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
-        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
-            k, (const block_q4_0 *)A, lda,
-            (const block_q8_0 *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#elif defined(__ARM_FEATURE_DOTPROD)
-        tinyBLAS_Q0_ARM<block_q4_0> tb{
-            k, (const block_q4_0 *)A, lda,
-            (const block_q8_0 *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
-        tb.matmul(m, n);
-        return true;
-#else
-        return false;
-#endif
-    }
-
-    default:
-        return false;
-    }
-
-    (void)m;
-    (void)n;
-    (void)k;
-    (void)A;
-    (void)lda;
-    (void)B;
-    (void)ldb;
-    (void)C;
-    (void)ldc;
-    (void)ith;
-    (void)nth;
-    (void)Atype;
-    (void)Btype;
-    (void)Ctype;
-}
diff --git a/ggml/src/sgemm.h b/ggml/src/sgemm.h
deleted file mode 100644 (file)
index caf6dd5..0000000
+++ /dev/null
@@ -1,14 +0,0 @@
-#pragma once
-#include <stdint.h>
-#include <stdbool.h>
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
-                     const void *, int64_t, void *, int64_t, int, int,
-                     int, int, int);
-
-#ifdef __cplusplus
-}
-#endif
diff --git a/ggml/src/vulkan-shaders/CMakeLists.txt b/ggml/src/vulkan-shaders/CMakeLists.txt
deleted file mode 100644 (file)
index 10075db..0000000
+++ /dev/null
@@ -1,7 +0,0 @@
-find_package (Threads REQUIRED)
-
-set(TARGET vulkan-shaders-gen)
-add_executable(${TARGET} vulkan-shaders-gen.cpp)
-install(TARGETS ${TARGET} RUNTIME)
-target_compile_features(${TARGET} PRIVATE cxx_std_11)
-target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
diff --git a/ggml/src/vulkan-shaders/acc.comp b/ggml/src/vulkan-shaders/acc.comp
deleted file mode 100644 (file)
index 4c8739e..0000000
+++ /dev/null
@@ -1,24 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_binary_head.comp"
-
-void main() {
-    const uint idx = gl_GlobalInvocationID.x;
-    if (idx >= p.ne) {
-        return;
-    }
-
-    const uint offset = p.param3;
-    const uint src1_i = idx - offset;
-    const uint oz = src1_i / p.nb02;
-    const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
-    const uint ox = src1_i % p.nb01;
-
-    if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
-        data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
-    } else {
-        data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]));
-    }
-}
-
diff --git a/ggml/src/vulkan-shaders/add.comp b/ggml/src/vulkan-shaders/add.comp
deleted file mode 100644 (file)
index 3974845..0000000
+++ /dev/null
@@ -1,14 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_binary_head.comp"
-
-void main() {
-    const uint idx = get_idx();
-
-    if (idx >= p.ne) {
-        return;
-    }
-
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[src1_idx(idx)]));
-}
diff --git a/ggml/src/vulkan-shaders/argsort.comp b/ggml/src/vulkan-shaders/argsort.comp
deleted file mode 100644 (file)
index d4fa45b..0000000
+++ /dev/null
@@ -1,69 +0,0 @@
-#version 450
-
-#include "types.comp"
-
-#define BLOCK_SIZE 1024
-#define ASC 0
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1)          buffer D {int data_d[];};
-
-layout (push_constant) uniform parameter {
-    uint ncols;
-    uint ncols_pad;
-    uint order;
-} p;
-
-shared int dst_row[BLOCK_SIZE];
-
-void swap(uint idx0, uint idx1) {
-    int tmp = dst_row[idx0];
-    dst_row[idx0] = dst_row[idx1];
-    dst_row[idx1] = tmp;
-}
-
-void main() {
-    // bitonic sort
-    const int col = int(gl_LocalInvocationID.x);
-    const uint row = gl_WorkGroupID.y;
-
-    const uint row_offset = row * p.ncols;
-
-    // initialize indices
-    if (col < p.ncols_pad) {
-        dst_row[col] = col;
-    }
-    barrier();
-
-    for (uint k = 2; k <= p.ncols_pad; k *= 2) {
-        for (uint j = k / 2; j > 0; j /= 2) {
-            const uint ixj = col ^ j;
-            if (col < p.ncols_pad && ixj > col) {
-                if ((col & k) == 0) {
-                    if (dst_row[col] >= p.ncols ||
-                        (dst_row[ixj] < p.ncols && (p.order == ASC ?
-                            data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] :
-                            data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]]))
-                    ) {
-                        swap(col, ixj);
-                    }
-                } else {
-                    if (dst_row[ixj] >= p.ncols ||
-                        (dst_row[col] < p.ncols && (p.order == ASC ?
-                            data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] :
-                            data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]]))
-                    ) {
-                        swap(col, ixj);
-                    }
-                }
-            }
-            barrier();
-        }
-    }
-
-    if (col < p.ncols) {
-        data_d[row_offset + col] = dst_row[col];
-    }
-}
diff --git a/ggml/src/vulkan-shaders/clamp.comp b/ggml/src/vulkan-shaders/clamp.comp
deleted file mode 100644 (file)
index ae8fa87..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
-    const uint idx = get_idx();
-
-    if (idx >= p.ne) {
-        return;
-    }
-
-    const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
-}
diff --git a/ggml/src/vulkan-shaders/concat.comp b/ggml/src/vulkan-shaders/concat.comp
deleted file mode 100644 (file)
index c23b6eb..0000000
+++ /dev/null
@@ -1,39 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_binary_head.comp"
-
-void main() {
-    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-    const int dim = p.param3;
-
-    if (idx >= p.ne) {
-        return;
-    }
-
-    const uint i3 = idx / (p.ne22*p.ne21*p.ne20);
-    const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20;
-    const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20);
-    const uint i2_offset = i2*p.ne21*p.ne20;
-    const uint i1 = (idx - i3_offset - i2_offset) / p.ne20;
-    const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20;
-
-    uint o[4] = {0, 0, 0, 0};
-    o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03));
-
-    const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
-    const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10;
-    const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20;
-
-    const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
-
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
-    data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]);
-#else
-    if (is_src0) {
-        data_d[p.d_offset + dst_idx] = data_a[src0_idx];
-    } else {
-        data_d[p.d_offset + dst_idx] = data_b[src1_idx];
-    }
-#endif
-}
diff --git a/ggml/src/vulkan-shaders/contig_copy.comp b/ggml/src/vulkan-shaders/contig_copy.comp
deleted file mode 100644 (file)
index 9acbdd3..0000000
+++ /dev/null
@@ -1,42 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-#extension GL_EXT_control_flow_attributes : require
-
-const uint num_threads = 128;
-
-layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
-    uint idx = get_idx();
-
-    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
-    const uint num_iter = 4;
-
-    // fast path for when all four iterations are in-bounds
-    if (idx + (num_iter-1)*num_threads < p.ne) {
-        [[unroll]] for (uint i = 0; i < num_iter; ++i) {
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
-            data_d[p.d_offset + idx] = D_TYPE(data_a[idx]);
-#else
-            data_d[p.d_offset + idx] = data_a[idx];
-#endif
-            idx += num_threads;
-        }
-    } else {
-        [[unroll]] for (uint i = 0; i < num_iter; ++i) {
-            if (idx >= p.ne) {
-                continue;
-            }
-
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
-            data_d[p.d_offset + idx] = D_TYPE(data_a[idx]);
-#else
-            data_d[p.d_offset + idx] = data_a[idx];
-#endif
-            idx += num_threads;
-        }
-    }
-}
diff --git a/ggml/src/vulkan-shaders/copy.comp b/ggml/src/vulkan-shaders/copy.comp
deleted file mode 100644 (file)
index 2775068..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
-    const uint idx = get_idx();
-
-    if (idx >= p.ne) {
-        return;
-    }
-
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx(idx)]);
-#else
-    data_d[p.d_offset + dst_idx(idx)] = data_a[src0_idx(idx)];
-#endif
-}
diff --git a/ggml/src/vulkan-shaders/cos.comp b/ggml/src/vulkan-shaders/cos.comp
deleted file mode 100644 (file)
index fbd9d27..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
-    const uint idx = get_idx();
-
-    if (idx >= p.ne) {
-        return;
-    }
-
-    const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(cos(val));
-}
diff --git a/ggml/src/vulkan-shaders/dequant_f32.comp b/ggml/src/vulkan-shaders/dequant_f32.comp
deleted file mode 100644 (file)
index a4d3fca..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {float data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    const uint i = gl_GlobalInvocationID.x * 16;
-
-    if (i >= p.nel) {
-        return;
-    }
-
-    [[unroll]] for (uint l = 0; l < 16; l++) {
-        data_b[i + l] = D_TYPE(data_a[i + l]);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/dequant_funcs.comp b/ggml/src/vulkan-shaders/dequant_funcs.comp
deleted file mode 100644 (file)
index d5b9897..0000000
+++ /dev/null
@@ -1,68 +0,0 @@
-#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
-#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
-#endif
-
-#if defined(DATA_A_F32)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
-}
-#endif
-
-#if defined(DATA_A_F16)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
-}
-#endif
-
-#if defined(DATA_A_Q4_0)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    const float d = float(data_a[a_offset + ib].d);
-    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
-    return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
-}
-#endif
-
-#if defined(DATA_A_Q4_1)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    const float d = float(data_a[a_offset + ib].d);
-    const float m = float(data_a[a_offset + ib].m);
-    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
-    return vec2(vui & 0xF, vui >> 4) * d + m;
-}
-#endif
-
-#if defined(DATA_A_Q5_0)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    const float d = float(data_a[a_offset + ib].d);
-    const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0];
-    const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
-    return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
-}
-#endif
-
-#if defined(DATA_A_Q5_1)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    const float d = float(data_a[a_offset + ib].d);
-    const float m = float(data_a[a_offset + ib].m);
-    const uint uint_qh = data_a[a_offset + ib].qh;
-    const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
-    return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
-}
-#endif
-
-#if defined(DATA_A_Q8_0)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    const float d = float(data_a[a_offset + ib].d);
-    return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
-}
-#endif
-
-#if defined(DATA_A_IQ4_NL)
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    const float d = float(data_a[a_offset + ib].d);
-    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
-    return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
-}
-#endif
diff --git a/ggml/src/vulkan-shaders/dequant_head.comp b/ggml/src/vulkan-shaders/dequant_head.comp
deleted file mode 100644 (file)
index 8d80643..0000000
+++ /dev/null
@@ -1,13 +0,0 @@
-#extension GL_EXT_control_flow_attributes : require
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
-    uint M;
-    uint K;
-    uint stride_a;
-    uint stride_b;
-    uint nel;
-} p;
-
-#include "types.comp"
diff --git a/ggml/src/vulkan-shaders/dequant_iq4_nl.comp b/ggml/src/vulkan-shaders/dequant_iq4_nl.comp
deleted file mode 100644 (file)
index 34ef3da..0000000
+++ /dev/null
@@ -1,30 +0,0 @@
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_iq4_nl data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
-    const uint tid = gl_LocalInvocationID.x % 64;
-    const uint il  = tid/32;
-    const uint ir  = tid%32;
-    const uint ib = 32*i + ir;
-    if (ib >= p.nel / 32) {
-        return;
-    }
-
-    const uint q_idx = 8*il;
-    const uint b_idx = 1024*i + 32*ir + q_idx;
-
-    const float d = float(data_a[ib].d);
-
-    [[unroll]] for (uint l = 0; l < 8; ++l) {
-        data_b[b_idx + l +  0] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);
-        data_b[b_idx + l + 16] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >>  4]);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/dequant_q2_k.comp b/ggml/src/vulkan-shaders/dequant_q2_k.comp
deleted file mode 100644 (file)
index 157154a..0000000
+++ /dev/null
@@ -1,34 +0,0 @@
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
-        const uint i = gl_WorkGroupID.x * 256 + wgy;
-        if (i >= p.M * p.K / QUANT_K) {
-            return;
-        }
-
-        const uint tid = gl_LocalInvocationID.x;
-        const uint ip = tid / 32;
-        const uint il = tid - 32 * ip;
-        const uint is = 8 * ip + il / 16;
-
-        const uint y_idx = i * QUANT_K + 128 * ip + il;
-
-        const uint ql_idx = 32 * ip + il;
-        const uint8_t qs = data_a[i].qs[32 * ip + il];
-
-        FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
-        FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
-        data_b[y_idx +  0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
-        data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
-        data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
-        data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4));
-    }
-}
diff --git a/ggml/src/vulkan-shaders/dequant_q3_k.comp b/ggml/src/vulkan-shaders/dequant_q3_k.comp
deleted file mode 100644 (file)
index c17dd0d..0000000
+++ /dev/null
@@ -1,42 +0,0 @@
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
-        const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
-        if (i >= p.M * p.K / QUANT_K) {
-            return;
-        }
-
-        const uint r = gl_LocalInvocationID.x / 4;
-        const uint tid = r / 2;
-        const uint is0 = r % 2;
-        const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4);
-        const uint n = tid / 4;
-        const uint j = tid - 4*n;
-
-        const uint8_t m = uint8_t(1 << (4*n + j));
-        const uint is = 8*n + 2*j + is0;
-        const uint shift = 2*j;
-
-        const int8_t us = int8_t(is <  4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) :
-                                 is <  8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) :
-                                 is < 12 ? (data_a[i].scales[is-8] >>  4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) :
-                                           (data_a[i].scales[is-8] >>  4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4));
-        const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d);
-        const FLOAT_TYPE dl    = d_all * FLOAT_TYPE(us - 32);
-
-        const uint y_idx = i * QUANT_K + 128 * n + 32 * j;
-        const uint qs_idx = 32*n;
-
-        for (uint l = l0; l < l0 + 4; ++l) {
-            data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4)));
-        }
-    }
-}
diff --git a/ggml/src/vulkan-shaders/dequant_q4_0.comp b/ggml/src/vulkan-shaders/dequant_q4_0.comp
deleted file mode 100644 (file)
index 4081853..0000000
+++ /dev/null
@@ -1,30 +0,0 @@
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q4_0 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
-    const uint tid = gl_LocalInvocationID.x % 64;
-    const uint il  = tid/32;
-    const uint ir  = tid%32;
-    const uint ib = 32*i + ir;
-    if (ib >= p.nel / 32) {
-        return;
-    }
-
-    const uint q_idx = 8*il;
-    const uint b_idx = 1024*i + 32*ir + q_idx;
-
-    const float d = float(data_a[ib].d);
-
-    [[unroll]] for (uint l = 0; l < 8; ++l) {
-        data_b[b_idx + l +  0] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] & 0xF) - 8.0f));
-        data_b[b_idx + l + 16] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] >>  4) - 8.0f));
-    }
-}
diff --git a/ggml/src/vulkan-shaders/dequant_q4_1.comp b/ggml/src/vulkan-shaders/dequant_q4_1.comp
deleted file mode 100644 (file)
index 2f27eee..0000000
+++ /dev/null
@@ -1,32 +0,0 @@
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q4_1 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
-    const uint tid = gl_LocalInvocationID.x % 64;
-    const uint il  = tid/32;
-    const uint ir  = tid%32;
-    const uint ib = 32*i + ir;
-    if (ib >= p.nel / 32) {
-        return;
-    }
-
-    const uint b_idx = 1024*i + 32*ir + 8*il;
-
-    const float d = float(data_a[ib].d);
-    const float m = float(data_a[ib].m);
-
-    const uint q_idx = 8*il;
-
-    [[unroll]] for (uint l = 0; l < 8; ++l) {
-        data_b[b_idx + l +  0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m);
-        data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >>  4) + m);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/dequant_q4_k.comp b/ggml/src/vulkan-shaders/dequant_q4_k.comp
deleted file mode 100644 (file)
index 92acb75..0000000
+++ /dev/null
@@ -1,56 +0,0 @@
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
-        const uint i = gl_WorkGroupID.x * 256 + wgy;
-        if (i >= p.M * p.K / QUANT_K) {
-            return;
-        }
-
-        const uint tid = gl_LocalInvocationID.x;
-        const uint il = tid / 8;
-        const uint ir = tid % 8;
-        const uint is = 2 * il;
-        const uint n = 4;
-
-        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
-
-        const uint y_idx = i * QUANT_K + 64 * il + n * ir;
-        const uint qs_idx = 32*il + n * ir;
-
-        uint8_t sc;
-        uint8_t m;
-        if (is < 4) {
-            sc = uint8_t(data_a[i].scales[is] & 63);
-            m  = uint8_t(data_a[i].scales[is + 4] & 63);
-        } else {
-            sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
-            m  = uint8_t((data_a[i].scales[is + 4] >>  4) | ((data_a[i].scales[is    ] >> 6) << 4));
-        }
-        const FLOAT_TYPE d1 = dall * sc;
-        const FLOAT_TYPE m1 = dmin * m;
-
-        if (is < 4) {
-            sc = uint8_t(data_a[i].scales[is + 1] & 63);
-            m  = uint8_t(data_a[i].scales[is + 5] & 63);
-        } else {
-            sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
-            m  = uint8_t((data_a[i].scales[is + 5] >>  4) | ((data_a[i].scales[is + 1] >> 6) << 4));
-        }
-        const FLOAT_TYPE d2 = dall * sc;
-        const FLOAT_TYPE m2 = dmin * m;
-
-        [[unroll]] for (uint l = 0; l < n; ++l) {
-            data_b[y_idx + l     ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1);
-            data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >>  4) - m2);
-        }
-    }
-}
diff --git a/ggml/src/vulkan-shaders/dequant_q5_0.comp b/ggml/src/vulkan-shaders/dequant_q5_0.comp
deleted file mode 100644 (file)
index b20b805..0000000
+++ /dev/null
@@ -1,34 +0,0 @@
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q5_0 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
-    const uint tid = gl_LocalInvocationID.x % 64;
-    const uint il  = tid/32;
-    const uint ir  = tid%32;
-    const uint ib = 32*i + ir;
-    if (ib >= p.nel / 32) {
-        return;
-    }
-
-    const uint b_idx = 1024*i + 32*ir + 8*il;
-
-    const float d = float(data_a[ib].d);
-    const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
-
-    const uint q_idx = 8*il;
-
-    [[unroll]] for (uint l = 0; l < 8; ++l) {
-        const uint iqs = q_idx + l;
-        const uint vui = uint(data_a[ib].qs[iqs]);
-        data_b[b_idx + l +  0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f));
-        data_b[b_idx + l + 16] = D_TYPE(d * (((vui >>  4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f));
-    }
-}
diff --git a/ggml/src/vulkan-shaders/dequant_q5_1.comp b/ggml/src/vulkan-shaders/dequant_q5_1.comp
deleted file mode 100644 (file)
index dc59fe3..0000000
+++ /dev/null
@@ -1,35 +0,0 @@
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q5_1 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
-    const uint tid = gl_LocalInvocationID.x % 64;
-    const uint il  = tid/32;
-    const uint ir  = tid%32;
-    const uint ib = 32*i + ir;
-    if (ib >= p.nel / 32) {
-        return;
-    }
-
-    const uint b_idx = 1024*i + 32*ir + 8*il;
-
-    const float d = float(data_a[ib].d);
-    const float m = float(data_a[ib].m);
-    const uint qh = data_a[ib].qh;
-
-    const uint q_idx = 8*il;
-
-    [[unroll]] for (uint l = 0; l < 8; ++l) {
-        const uint iqs = q_idx + l;
-        const uint vui = uint(data_a[ib].qs[iqs]);
-        data_b[b_idx + l +  0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m);
-        data_b[b_idx + l + 16] = D_TYPE(d * (((vui >>  4) | ((qh >> (iqs + 12)) & 0x10))) + m);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/dequant_q5_k.comp b/ggml/src/vulkan-shaders/dequant_q5_k.comp
deleted file mode 100644 (file)
index f314a76..0000000
+++ /dev/null
@@ -1,58 +0,0 @@
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
-        const uint i = gl_WorkGroupID.x * 256 + wgy;
-        if (i >= p.M * p.K / QUANT_K) {
-            return;
-        }
-
-        const uint tid = gl_LocalInvocationID.x;
-        const uint il = tid / 16;
-        const uint ir = tid % 16;
-        const uint is = 2 * il;
-
-        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
-
-        const uint y_idx = i * QUANT_K + 64 * il + 2 * ir;
-        const uint qs_idx = 32*il + 2 * ir;
-        const uint qh_idx = 2 * ir;
-
-        uint8_t sc;
-        uint8_t m;
-        if (is < 4) {
-            sc = uint8_t(data_a[i].scales[is] & 63);
-            m  = uint8_t(data_a[i].scales[is + 4] & 63);
-        } else {
-            sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
-            m  = uint8_t((data_a[i].scales[is + 4] >>  4) | ((data_a[i].scales[is    ] >> 6) << 4));
-        }
-        const FLOAT_TYPE d1 = dall * sc;
-        const FLOAT_TYPE m1 = dmin * m;
-
-        if (is < 4) {
-            sc = uint8_t(data_a[i].scales[is + 1] & 63);
-            m  = uint8_t(data_a[i].scales[is + 5] & 63);
-        } else {
-            sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
-            m  = uint8_t((data_a[i].scales[is + 5] >>  4) | ((data_a[i].scales[is + 1] >> 6) << 4));
-        }
-        const FLOAT_TYPE d2 = dall * sc;
-        const FLOAT_TYPE m2 = dmin * m;
-
-        const uint8_t hm1 = uint8_t(1 << (2 * il    ));
-        const uint8_t hm2 = uint8_t(1 << (2 * il + 1));
-        data_b[y_idx     ] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx    ] & 0xF) + (((data_a[i].qh[qh_idx    ] & hm1) != 0) ? 16 : 0)) - m1);
-        data_b[y_idx +  1] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] & 0xF) + (((data_a[i].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
-        data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx    ]  >> 4) + (((data_a[i].qh[qh_idx    ] & hm2) != 0) ? 16 : 0)) - m2);
-        data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1]  >> 4) + (((data_a[i].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/dequant_q6_k.comp b/ggml/src/vulkan-shaders/dequant_q6_k.comp
deleted file mode 100644 (file)
index 0b91317..0000000
+++ /dev/null
@@ -1,33 +0,0 @@
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
-        const uint i = gl_WorkGroupID.x * 256 + wgy;
-        if (i >= p.M * p.K / QUANT_K) {
-            return;
-        }
-        const uint tid = gl_LocalInvocationID.x;
-        const uint ip = tid / 32;
-        const uint il = tid - 32 * ip;
-        const uint is = 8 * ip + il / 16;
-
-        const uint y_idx = i * QUANT_K + 128 * ip + il;
-
-        const uint ql_idx = 64 * ip + il;
-        const uint8_t qh = data_a[i].qh[32 * ip + il];
-
-        const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d);
-
-        data_b[y_idx +  0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx +  0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
-        data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
-        data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx +  0] >>  4) | (((qh >> 4) & 3) << 4)) - 32)));
-        data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >>  4) | (((qh >> 6) & 3) << 4)) - 32)));
-    }
-}
diff --git a/ggml/src/vulkan-shaders/dequant_q8_0.comp b/ggml/src/vulkan-shaders/dequant_q8_0.comp
deleted file mode 100644 (file)
index bd1344a..0000000
+++ /dev/null
@@ -1,31 +0,0 @@
-#version 450
-
-#include "dequant_head.comp"
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {block_q8_0 data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
-
-void main() {
-    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
-
-    const uint tid = gl_LocalInvocationID.x % 64;
-    const uint il  = tid/32;
-    const uint ir  = tid%32;
-    const uint ib = 32*i + ir;
-    if (ib >= p.nel / 32) {
-        return;
-    }
-
-    const uint b_idx = 1024*i + 32*ir + 16*il;
-
-    const float d = float(data_a[ib].d);
-
-    const uint q_idx = 16*il;
-
-    [[unroll]] for (uint l = 0; l < 16; l += 2) {
-        data_b[b_idx + l    ] = D_TYPE(d * data_a[ib].qs[q_idx + l    ]);
-        data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/diag_mask_inf.comp b/ggml/src/vulkan-shaders/diag_mask_inf.comp
deleted file mode 100644 (file)
index 4e68742..0000000
+++ /dev/null
@@ -1,34 +0,0 @@
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_control_flow_attributes : enable
-
-layout (push_constant) uniform parameter
-{
-    uint ncols;
-    uint rows_per_channel;
-    uint n_past;
-} p;
-
-#include "types.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const uint col = gl_GlobalInvocationID.y;
-    const uint row = gl_GlobalInvocationID.x;
-
-    if (col >= p.ncols) {
-        return;
-    }
-
-    const uint i = row*p.ncols + col;
-    if (col > p.n_past + row % p.rows_per_channel) {
-        data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000));
-    } else {
-        data_d[i] = D_TYPE(data_a[i]);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/div.comp b/ggml/src/vulkan-shaders/div.comp
deleted file mode 100644 (file)
index 8cfce58..0000000
+++ /dev/null
@@ -1,14 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_binary_head.comp"
-
-void main() {
-    const uint idx = get_idx();
-
-    if (idx >= p.ne) {
-        return;
-    }
-
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) / FLOAT_TYPE(data_b[src1_idx(idx)]));
-}
diff --git a/ggml/src/vulkan-shaders/gelu.comp b/ggml/src/vulkan-shaders/gelu.comp
deleted file mode 100644 (file)
index 4cc7a68..0000000
+++ /dev/null
@@ -1,25 +0,0 @@
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const float GELU_COEF_A    = 0.044715f;
-    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
-    if (i >= p.KX) {
-        return;
-    }
-
-    const float xi = float(data_a[i]);
-    const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
-    data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)));
-}
diff --git a/ggml/src/vulkan-shaders/gelu_quick.comp b/ggml/src/vulkan-shaders/gelu_quick.comp
deleted file mode 100644 (file)
index e6e6fcf..0000000
+++ /dev/null
@@ -1,23 +0,0 @@
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const float GELU_QUICK_COEF = -1.702f;
-    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
-    if (i >= p.KX) {
-        return;
-    }
-
-    const float x = float(data_a[i]);
-    data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x))));
-}
diff --git a/ggml/src/vulkan-shaders/generic_binary_head.comp b/ggml/src/vulkan-shaders/generic_binary_head.comp
deleted file mode 100644 (file)
index b6beaff..0000000
+++ /dev/null
@@ -1,52 +0,0 @@
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
-    uint ne;
-    uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
-    uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
-    uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
-    uint d_offset;
-    float param1; float param2; int param3;
-} p;
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-
-uint get_idx() {
-    return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-}
-
-uint src0_idx(uint idx) {
-    const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
-    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
-    const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
-    const uint i02_offset = i02*p.ne01*p.ne00;
-    const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
-    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
-    return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
-}
-
-uint src1_idx(uint idx) {
-    const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
-    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
-    const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
-    const uint i02_offset = i02*p.ne01*p.ne00;
-    const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
-    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
-
-    return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10;
-}
-
-uint dst_idx(uint idx) {
-    const uint i23 = idx / (p.ne22*p.ne21*p.ne20);
-    const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
-    const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20);
-    const uint i22_offset = i22*p.ne21*p.ne20;
-    const uint i21 = (idx - i23_offset - i22_offset) / p.ne20;
-    const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20;
-    return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20;
-}
diff --git a/ggml/src/vulkan-shaders/generic_head.comp b/ggml/src/vulkan-shaders/generic_head.comp
deleted file mode 100644 (file)
index 66e46ae..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
-    uint KX;
-    uint KY;
-    float param1;
-    float param2;
-} p;
diff --git a/ggml/src/vulkan-shaders/generic_unary_head.comp b/ggml/src/vulkan-shaders/generic_unary_head.comp
deleted file mode 100644 (file)
index 4e1fa3a..0000000
+++ /dev/null
@@ -1,38 +0,0 @@
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_control_flow_attributes : require
-
-layout (push_constant) uniform parameter
-{
-    uint ne;
-    uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
-    uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
-    uint d_offset;
-    float param1; float param2;
-} p;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-uint get_idx() {
-    return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-}
-
-uint src0_idx(uint idx) {
-    const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
-    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
-    const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
-    const uint i02_offset = i02*p.ne01*p.ne00;
-    const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
-    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
-    return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
-}
-
-uint dst_idx(uint idx) {
-    const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
-    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
-    const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
-    const uint i12_offset = i12*p.ne11*p.ne10;
-    const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
-    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
-    return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
-}
diff --git a/ggml/src/vulkan-shaders/get_rows.comp b/ggml/src/vulkan-shaders/get_rows.comp
deleted file mode 100644 (file)
index e9ff22e..0000000
+++ /dev/null
@@ -1,26 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_binary_head.comp"
-
-void main() {
-    const uint i00 = gl_GlobalInvocationID.x;
-    const uint i10 = gl_GlobalInvocationID.y;
-    const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
-    const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
-
-    if (i00 >= p.ne00) {
-        return;
-    }
-
-    const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
-
-    const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
-    const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
-
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
-    data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
-#else
-    data_d[d_offset + i00] = data_a[a_offset + i00];
-#endif
-}
diff --git a/ggml/src/vulkan-shaders/get_rows_quant.comp b/ggml/src/vulkan-shaders/get_rows_quant.comp
deleted file mode 100644 (file)
index 53a9a96..0000000
+++ /dev/null
@@ -1,31 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_binary_head.comp"
-#include "dequant_funcs.comp"
-
-void main() {
-    const uint i00 = (gl_GlobalInvocationID.x)*2;
-    const uint i10 = gl_GlobalInvocationID.y;
-    const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
-    const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
-
-    if (i00 >= p.ne00) {
-        return;
-    }
-
-    const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
-
-    const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
-    const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
-
-    const uint ib = a_offset + i00/QUANT_K; // block index
-    const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
-    const uint iybs = i00 - i00%QUANT_K; // dst block start index
-    const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
-
-    vec2 v = dequantize(ib, iqs, 0);
-
-    data_d[d_offset + iybs + iqs           ] = D_TYPE(v.x);
-    data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
-}
diff --git a/ggml/src/vulkan-shaders/group_norm.comp b/ggml/src/vulkan-shaders/group_norm.comp
deleted file mode 100644 (file)
index 5ad9b28..0000000
+++ /dev/null
@@ -1,66 +0,0 @@
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-shared float tmp[BLOCK_SIZE];
-
-void main() {
-    const uint group_size = p.KX;
-    const float eps = p.param1;
-
-    const uint tid = gl_LocalInvocationID.x;
-    const uint start = gl_WorkGroupID.x * group_size + tid;
-    const uint end = start + group_size;
-
-    tmp[tid] = 0.0f;
-
-    // Calculate mean
-    [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
-        tmp[tid] += float(data_a[col]);
-    }
-
-    // tmp up partial tmps and write back result
-    barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-
-    const float mean = tmp[0] / group_size;
-    barrier();
-    tmp[tid] = 0.0f;
-
-    // Calculate variance
-    [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
-        const float xi = float(data_a[col]) - mean;
-        data_d[col] = D_TYPE(xi);
-        tmp[tid] += xi * xi;
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-
-    const float variance = tmp[0] / group_size;
-    const float scale = inversesqrt(variance + eps);
-
-    [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
-        data_d[col] *= D_TYPE(scale);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/im2col.comp b/ggml/src/vulkan-shaders/im2col.comp
deleted file mode 100644 (file)
index 4d48610..0000000
+++ /dev/null
@@ -1,57 +0,0 @@
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
-    uint batch_offset; uint offset_delta;
-    uint IC;
-    uint IW; uint IH;
-    uint OW; uint OH;
-    uint KW; uint KH;
-    uint pelements;
-    uint CHW;
-    int s0; int s1;
-    int p0; int p1;
-    int d0; int d1;
-} p;
-
-#include "types.comp"
-
-#define BLOCK_SIZE 256
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const uint i = gl_GlobalInvocationID.x;
-    if (i >= p.pelements) {
-        return;
-    }
-
-    const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
-    const uint kx = i / ksize;
-    const uint kd = kx * ksize;
-    const uint ky = (i - kd) / p.OW;
-    const uint ix = i % p.OW;
-
-    const uint oh = gl_GlobalInvocationID.y;
-    const uint batch = gl_GlobalInvocationID.z / p.IC;
-    const uint ic = gl_GlobalInvocationID.z % p.IC;
-
-    const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
-    const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
-
-    const uint offset_dst =
-        ((batch * p.OH + oh) * p.OW + ix) * p.CHW +
-        (ic * (p.KW * p.KH) + ky * p.KW + kx);
-
-    if (iih < 0 || iih >= p.IH || iiw < 0 || iiw >= p.IW) {
-        data_d[offset_dst] = D_TYPE(0.0f);
-    } else {
-        const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
-        data_d[offset_dst] = D_TYPE(data_a[offset_src + iih * p.IW + iiw]);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/leaky_relu.comp b/ggml/src/vulkan-shaders/leaky_relu.comp
deleted file mode 100644 (file)
index d90a99a..0000000
+++ /dev/null
@@ -1,22 +0,0 @@
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
-    if (i >= p.KX) {
-        return;
-    }
-
-    const float val = float(data_a[i]);
-    data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
-}
diff --git a/ggml/src/vulkan-shaders/mul.comp b/ggml/src/vulkan-shaders/mul.comp
deleted file mode 100644 (file)
index bfb61c9..0000000
+++ /dev/null
@@ -1,14 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_binary_head.comp"
-
-void main() {
-    const uint idx = get_idx();
-
-    if (idx >= p.ne) {
-        return;
-    }
-
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(data_b[src1_idx(idx)]));
-}
diff --git a/ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp b/ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp
deleted file mode 100644 (file)
index 825b910..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {float data_a[];};
-layout (binding = 1) writeonly buffer D {float data_d[];};
-
-layout (push_constant) uniform parameter {
-    uint ne;
-    uint k_num;
-} p;
-
-void main() {
-    const uint idx = gl_GlobalInvocationID.x;
-
-    if (idx >= p.ne) {
-        return;
-    }
-
-    float result = 0.0f;
-
-    [[unroll]] for (uint i = 0; i < p.k_num; i++) {
-        result += data_a[i * p.ne + idx];
-    }
-
-    data_d[idx] = result;
-}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec.comp b/ggml/src/vulkan-shaders/mul_mat_vec.comp
deleted file mode 100644 (file)
index d3ccba7..0000000
+++ /dev/null
@@ -1,56 +0,0 @@
-#version 450
-
-#ifdef FLOAT16
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-#endif
-
-#include "mul_mat_vec_base.comp"
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-    const uint tid = gl_LocalInvocationID.x;
-
-    // There are not enough cols to use all threads
-    if (tid >= p.ncols) {
-        return;
-    }
-
-    const uint block_size = min(p.ncols, BLOCK_SIZE);
-
-    uint a_offset, b_offset, d_offset;
-    get_offsets(a_offset, b_offset, d_offset);
-
-    const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
-
-    tmp[tid] = FLOAT_TYPE(0.0f);
-
-    [[unroll]] for (uint i = 0; i < p.ncols/block_size; i += 2) {
-        const uint col = i*block_size + 2*tid;
-        const uint ib = (row*p.ncols + col)/QUANT_K; // block index
-        const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
-        const uint iybs = col - col%QUANT_K; // y block start index
-
-        vec2 v = dequantize(ib, iqs, a_offset / QUANT_K);
-
-        // matrix multiplication
-        tmp[tid] = fma(FLOAT_TYPE(v.x), FLOAT_TYPE(data_b[b_offset + iybs + iqs]), fma(FLOAT_TYPE(v.y), FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]), tmp[tid]));
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (uint s = block_size/2; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/vulkan-shaders/mul_mat_vec_base.comp
deleted file mode 100644 (file)
index 5920bc9..0000000
+++ /dev/null
@@ -1,81 +0,0 @@
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_shader_8bit_storage : require
-
-#define K_QUANTS_PER_ITERATION 2
-
-#ifdef MUL_MAT_ID
-#define EXPERT_COUNT 8
-#endif
-
-#include "types.comp"
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-#ifdef MUL_MAT_ID
-layout (binding = 3) readonly buffer IDS {int data_ids[];};
-#endif
-
-#include "dequant_funcs.comp"
-
-layout (push_constant) uniform parameter
-{
-    uint ncols;
-    uint stride_a;
-    uint stride_b;
-    uint stride_d;
-
-    uint batch_stride_a;
-    uint batch_stride_b;
-    uint batch_stride_d;
-
-#ifdef MUL_MAT_ID
-    uint nei0;
-    uint ne11;
-#else
-    uint ne02;
-    uint ne12;
-    uint broadcast2;
-    uint broadcast3;
-#endif
-} p;
-
-void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
-#ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.y;
-#else
-    const uint batch_idx = gl_GlobalInvocationID.y;
-#endif
-
-#ifndef MUL_MAT_ID
-    const uint i13 = batch_idx / p.ne12;
-    const uint i12 = batch_idx % p.ne12;
-
-    const uint i03 = i13 / p.broadcast3;
-    const uint i02 = i12 / p.broadcast2;
-
-    const uint batch_idx_a = i03 * p.ne02 + i02;
-#else
-    const uint expert_id = data_ids[expert_idx];
-#endif
-
-    a_offset =
-#ifdef MUL_MAT_ID
-            expert_id * p.batch_stride_a;
-#else
-            batch_idx_a * p.batch_stride_a;
-#endif
-    b_offset =
-#ifdef MUL_MAT_ID
-            (expert_idx % p.ne11) * p.stride_b;
-#else
-            batch_idx * p.batch_stride_b;
-#endif
-    d_offset =
-#ifdef MUL_MAT_ID
-            expert_idx * p.stride_d;
-#else
-            batch_idx * p.batch_stride_d;
-#endif
-}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_nc.comp b/ggml/src/vulkan-shaders/mul_mat_vec_nc.comp
deleted file mode 100644 (file)
index 1cc4996..0000000
+++ /dev/null
@@ -1,71 +0,0 @@
-#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-
-#define BLOCK_SIZE 32
-#define FLOAT_TYPE float
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
-layout (push_constant) uniform parameter
-{
-    uint ncols_x;
-    uint nrows_x;
-    uint row_stride_x;
-    uint channel_stride_x;
-    uint channel_x_divisor;
-    uint b_offset;
-    uint d_offset;
-} p;
-
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint tid       = gl_LocalInvocationID.x;
-    const uint row_x     = gl_GlobalInvocationID.y;
-    const uint channel   = gl_GlobalInvocationID.z;
-    const uint channel_x = channel / p.channel_x_divisor;
-
-    const uint nrows_y   = p.ncols_x;
-    const uint nrows_dst = p.nrows_x;
-    const uint row_dst   = row_x;
-
-    const uint idst = channel*nrows_dst + row_dst;
-
-    tmp[tid] = 0.0f;
-
-    for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
-        const uint col_x = col_x0 + tid;
-
-        if (col_x >= p.ncols_x) {
-            break;
-        }
-
-        const uint row_y = col_x;
-
-        const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
-        const uint iy = channel*nrows_y + row_y;
-
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
-
-        tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-
-    if (tid == 0) {
-        dst[idst] = tmp[0];
-    }
-}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_p021.comp b/ggml/src/vulkan-shaders/mul_mat_vec_p021.comp
deleted file mode 100644 (file)
index 9b44380..0000000
+++ /dev/null
@@ -1,73 +0,0 @@
-#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-
-#define BLOCK_SIZE 32
-#define FLOAT_TYPE float
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
-layout (push_constant) uniform parameter
-{
-    uint ncols_x;
-    uint nrows_x;
-    uint nchannels_x;
-    uint nchannels_y;
-    uint b_offset;
-    uint d_offset;
-} p;
-
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint tid = gl_LocalInvocationID.x;
-    const uint row_x = gl_GlobalInvocationID.y;
-    const uint channel = gl_GlobalInvocationID.z;
-    const uint channel_x = channel / (p.nchannels_y / p.nchannels_x);
-
-    const uint nrows_y = p.ncols_x;
-    const uint nrows_dst = p.nrows_x;
-    const uint row_dst = row_x;
-
-    tmp[tid] = FLOAT_TYPE(0.0f);
-
-    for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
-        const uint col_x = col_x0 + tid;
-
-        if (col_x >= p.ncols_x) {
-            break;
-        }
-
-        // x is transposed and permuted
-        const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
-
-        const uint row_y = col_x;
-
-        // y is not transposed but permuted
-        const uint iy = channel*nrows_y + row_y;
-
-        tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
-    }
-
-    // dst is not transposed and not permuted
-    const uint idst = channel*nrows_dst + row_dst;
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-
-    if (tid == 0) {
-        dst[idst] = tmp[0];
-    }
-}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp
deleted file mode 100644 (file)
index ec8eadc..0000000
+++ /dev/null
@@ -1,74 +0,0 @@
-#version 450
-
-#include "mul_mat_vec_base.comp"
-
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    uint a_offset, b_offset, d_offset;
-    get_offsets(a_offset, b_offset, d_offset);
-
-    const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
-
-    const uint step = 16/K_QUANTS_PER_ITERATION;            // 16 or 8
-
-    const uint v_im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const uint v_in = tid - step*v_im;                      // 0...15 or 0...7
-
-    const uint l0 = K_QUANTS_PER_ITERATION*v_in;            // 0...15
-    const uint q_offset = 32*v_im + l0;
-    const uint s_offset = 8*v_im;
-    const uint y_offset = 128*v_im + l0;
-
-    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-        const uint y_idx = i * QUANT_K + y_offset;
-
-        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
-        FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
-        FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
-            sum1 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3), sum1))))))));
-            sum2 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF), sum2))))))));
-        }
-        const uint tmp_idx = 16 * ix + tid;
-        tmp[tmp_idx] = fma(dall, sum1, fma(-dmin, sum2, tmp[tmp_idx]));
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp
deleted file mode 100644 (file)
index 3ca4ad8..0000000
+++ /dev/null
@@ -1,67 +0,0 @@
-#version 450
-
-#include "mul_mat_vec_base.comp"
-
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    uint a_offset, b_offset, d_offset;
-    get_offsets(a_offset, b_offset, d_offset);
-
-    const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
-
-    const uint step = 16/K_QUANTS_PER_ITERATION;            // 16 or 8
-
-    const uint v_im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const uint v_in = tid - step*v_im;                      // 0...15 or 0...7
-
-    const uint8_t m = uint8_t(1 << (4 * v_im));
-
-    const uint l0 = K_QUANTS_PER_ITERATION*v_in;            // 0...15
-    const uint q_offset = 32*v_im + l0;
-    const uint y_offset = 128*v_im + l0;
-
-    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
-    const uint s_shift = 4 * v_im;
-
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-        const uint y_idx = i * QUANT_K + y_offset;
-
-        const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
-
-        FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
-            sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +  0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
-        }
-        const uint tmp_idx = 16 * ix + tid;
-        tmp[tmp_idx] = fma(d, sum, tmp[tmp_idx]);
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp
deleted file mode 100644 (file)
index d91e00e..0000000
+++ /dev/null
@@ -1,118 +0,0 @@
-#version 450
-
-#include "mul_mat_vec_base.comp"
-
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    uint a_offset, b_offset, d_offset;
-    get_offsets(a_offset, b_offset, d_offset);
-
-    const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
-
-    const uint step = 8/K_QUANTS_PER_ITERATION;             // 8 or 4
-
-    const uint il = tid/step;                               // 0...3
-    const uint ir = tid - step*il;                          // 0...7 or 0...3
-    const uint n =  2 * K_QUANTS_PER_ITERATION;             // 2 or 4
-
-    const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const uint v_in = il % 2;
-
-    const uint l0 = n * (2 * ir + v_in);            // 0...15
-    const uint q_offset = 32*v_im + l0;
-    const uint y_offset = 64*v_im + l0;
-
-    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-        const uint y1_idx = i * QUANT_K + y_offset;
-        const uint y2_idx = y1_idx + 128;
-
-        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
-        const uint8_t sc0 = uint8_t(  data_a[ib0 + i].scales[v_im * 2    ]       & 0x3f);
-        const uint8_t sc1 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 1]       & 0x3f);
-        const uint8_t sc2 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 4]       & 0x3f);
-        const uint8_t sc3 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 5]       & 0x3f);
-        const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8]       & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2    ] & 0xc0) >> 2));
-        const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9]       & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
-        const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
-        const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
-
-#if K_QUANTS_PER_ITERATION == 2
-        const uint8_t q4_0  = uint8_t(data_a[ib0 + i].qs[q_offset     ] & 0xf);
-        const uint8_t q4_1  = uint8_t(data_a[ib0 + i].qs[q_offset +  1] & 0xf);
-        const uint8_t q4_2  = uint8_t(data_a[ib0 + i].qs[q_offset +  2] & 0xf);
-        const uint8_t q4_3  = uint8_t(data_a[ib0 + i].qs[q_offset +  3] & 0xf);
-        const uint8_t q4_4  = uint8_t(data_a[ib0 + i].qs[q_offset     ]  >> 4);
-        const uint8_t q4_5  = uint8_t(data_a[ib0 + i].qs[q_offset +  1]  >> 4);
-        const uint8_t q4_6  = uint8_t(data_a[ib0 + i].qs[q_offset +  2]  >> 4);
-        const uint8_t q4_7  = uint8_t(data_a[ib0 + i].qs[q_offset +  3]  >> 4);
-        const uint8_t q4_8  = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
-        const uint8_t q4_9  = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
-        const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] & 0xf);
-        const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] & 0xf);
-        const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64]  >> 4);
-        const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65]  >> 4);
-        const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66]  >> 4);
-        const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67]  >> 4);
-
-        const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx]),      q4_0,  fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]),  q4_1,  fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]),  q4_2,  FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) *  q4_3)));
-        const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_4,  fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), q4_5,  fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), q4_6,  FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7)));
-        const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx]),      q4_8,  fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]),  q4_9,  fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]),  q4_10, FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) *  q4_11)));
-        const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_12, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), q4_13, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), q4_14, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15)));
-        const FLOAT_TYPE smin =
-            fma(FLOAT_TYPE(data_b[b_offset + y1_idx    ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx    ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
-            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), sc7,
-            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), sc7,
-            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 3]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 35]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 3]), sc6,     FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7)))))))))))))));
-        const uint tmp_idx = 16 * ix + tid;
-        tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
-#else
-        const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset     ] & 0xf);
-        const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset +  1] & 0xf);
-        const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset     ]  >> 4);
-        const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset +  1]  >> 4);
-        const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
-        const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
-        const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64]  >> 4);
-        const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65]  >> 4);
-
-        const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]), q4_0, FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) * q4_1);
-        const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
-        const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]), q4_4, FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) * q4_5);
-        const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
-        const FLOAT_TYPE smin =
-            fma(FLOAT_TYPE(data_b[b_offset + y1_idx    ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx    ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
-          + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7)))))));
-
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) +
-                        sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
-        const uint tmp_idx = 16 * ix + tid;
-        tmp[tmp_idx] = fma(dall, (fma(sx, FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f), fma(sy, FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f),
-                       fma(sz, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)), fma(sw, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))))))), fma(-dmin, smin, tmp[tmp_idx]));
-#endif
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp
deleted file mode 100644 (file)
index 2306785..0000000
+++ /dev/null
@@ -1,109 +0,0 @@
-#version 450
-
-#include "mul_mat_vec_base.comp"
-
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    uint a_offset, b_offset, d_offset;
-    get_offsets(a_offset, b_offset, d_offset);
-
-    const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/2;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%2;  // 0 or 0, 1
-
-    const uint il = tid/4;                           // 0...3
-    const uint ir = tid - 4*il;                      // 0...7 or 0...3
-
-    const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const uint v_in = il % 2;
-
-    const uint l0 = 4*ir + 2*v_in;                   // 0...15
-    const uint q_offset = 32*v_im + l0;
-    const uint y_offset = 64*v_im + l0;
-
-    const uint8_t hm1 = uint8_t(1 << (2*v_im));
-    const uint8_t hm2 = uint8_t(hm1 << 4);
-
-    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
-        const uint y1_idx = i * QUANT_K + y_offset;
-        const uint y2_idx = y1_idx + 128;
-
-        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
-
-        const uint8_t sc0 = uint8_t(  data_a[ib0 + i].scales[v_im * 2    ]       & 0x3f);
-        const uint8_t sc1 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 1]       & 0x3f);
-        const uint8_t sc2 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 4]       & 0x3f);
-        const uint8_t sc3 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 5]       & 0x3f);
-        const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8]       & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2    ] & 0xc0) >> 2));
-        const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9]       & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
-        const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
-        const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
-
-        const uint8_t q4_0  = uint8_t(data_a[ib0 + i].qs[q_offset     ] & 0xf);
-        const uint8_t q4_1  = uint8_t(data_a[ib0 + i].qs[q_offset +  1] & 0xf);
-        const uint8_t q4_2  = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf);
-        const uint8_t q4_3  = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf);
-        const uint8_t q4_4  = uint8_t(data_a[ib0 + i].qs[q_offset     ]  >> 4);
-        const uint8_t q4_5  = uint8_t(data_a[ib0 + i].qs[q_offset +  1]  >> 4);
-        const uint8_t q4_6  = uint8_t(data_a[ib0 + i].qs[q_offset + 16]  >> 4);
-        const uint8_t q4_7  = uint8_t(data_a[ib0 + i].qs[q_offset + 17]  >> 4);
-        const uint8_t q4_8  = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
-        const uint8_t q4_9  = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
-        const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf);
-        const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf);
-        const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64]  >> 4);
-        const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65]  >> 4);
-        const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80]  >> 4);
-        const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81]  >> 4);
-
-        const FLOAT_TYPE sx =
-          fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]), (q4_0 + (((data_a[ib0 + i].qh[l0     ] & hm1) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(data_b[b_offset + y1_idx +  1]), (q4_1 + (((data_a[ib0 + i].qh[l0 +  1] & hm1) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 16]), (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)),
-             FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)))));
-        const FLOAT_TYPE sy =
-          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), (q4_4 + (((data_a[ib0 + i].qh[l0     ] & (hm1 << 1)) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), (q4_5 + (((data_a[ib0 + i].qh[l0 +  1] & (hm1 << 1)) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 48]), (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)),
-             FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)))));
-        const FLOAT_TYPE sz =
-          fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]), (q4_8  + (((data_a[ib0 + i].qh[l0     ] & hm2) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(data_b[b_offset + y2_idx +  1]), (q4_9  + (((data_a[ib0 + i].qh[l0 +  1] & hm2) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 16]), (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)),
-             FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)))));
-        const FLOAT_TYPE sw =
-          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), (q4_12 + (((data_a[ib0 + i].qh[l0     ] & (hm2 << 1)) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), (q4_13 + (((data_a[ib0 + i].qh[l0 +  1] & (hm2 << 1)) != 0) ? 16 : 0)),
-          fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 48]), (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)),
-             FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)))));
-        const FLOAT_TYPE smin =
-          fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]), sc2,
-          fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]), sc3,
-          fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]), sc6,
-              (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7)));
-        const uint tmp_idx = 16 * ix + tid;
-        tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier();
-    }
-    if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp
deleted file mode 100644 (file)
index 95c286e..0000000
+++ /dev/null
@@ -1,79 +0,0 @@
-#version 450
-
-#include "mul_mat_vec_base.comp"
-
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
-
-shared FLOAT_TYPE tmp[32];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    uint a_offset, b_offset, d_offset;
-    get_offsets(a_offset, b_offset, d_offset);
-
-    const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
-
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
-
-    const uint step = 16/K_QUANTS_PER_ITERATION;            // 16 or 8
-
-    const uint v_im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const uint v_in = tid - step*v_im;                      // 0...15 or 0...7
-
-#if K_QUANTS_PER_ITERATION == 1
-    const uint l0 = v_in;                                   // 0...15
-    const uint is = 0;
-#else
-    const uint l0 = 4 * v_in;                               // 0, 4, 8, ..., 28
-    const uint is = v_in / 4;
-#endif
-
-    const uint ql_offset = 64*v_im + l0;
-    const uint qh_offset = 32*v_im + l0;
-    const uint s_offset  =  8*v_im + is;
-    const uint y_offset = 128*v_im + l0;
-
-    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
-
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
-        const uint y_idx   = i * QUANT_K + y_offset;
-
-        const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
-
-#if K_QUANTS_PER_ITERATION == 1
-        const uint tmp_idx = 16 * ix + tid;
-        tmp[tmp_idx] = fma(FLOAT_TYPE(data_b[b_offset + y_idx +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x03) << 4)) - 32),
-                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32),
-                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x0c) << 2)) - 32),
-                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32),
-                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x30) >> 0)) - 32),
-                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32),
-                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0xc0) >> 2)) - 32),
-                       fma(FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32), tmp[tmp_idx]))))))));
-#else
-        FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-        [[unroll]] for (int l = 0; l < 4; ++l) {
-            sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32),
-                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32),
-                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32),
-                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), sum))));
-        }
-        tmp[16 * ix + tid] += sum;
-#endif
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-       }
-        barrier();
-    }
-    if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/mul_mm.comp b/ggml/src/vulkan-shaders/mul_mm.comp
deleted file mode 100644 (file)
index fffdd18..0000000
+++ /dev/null
@@ -1,508 +0,0 @@
-#version 450
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-
-#ifdef FLOAT16
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-#endif
-
-#ifdef MUL_MAT_ID
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-#endif
-
-#include "types.comp"
-
-#ifndef LOAD_VEC_A
-#define LOAD_VEC_A 1
-#endif
-#ifndef LOAD_VEC_B
-#define LOAD_VEC_B 1
-#endif
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-
-#ifdef MUL_MAT_ID
-layout (binding = 3) readonly buffer IDS {int data_ids[];};
-#endif
-
-layout (push_constant) uniform parameter
-{
-    uint M;
-    uint N;
-    uint K;
-    uint stride_a;
-    uint stride_b;
-    uint stride_d;
-
-    uint batch_stride_a;
-    uint batch_stride_b;
-    uint batch_stride_d;
-
-#ifdef MUL_MAT_ID
-    uint nei0;
-    uint nei1;
-    uint nbi1;
-    uint ne11;
-#else
-    uint k_split;
-    uint ne02;
-    uint ne12;
-    uint broadcast2;
-    uint broadcast3;
-#endif
-} p;
-
-layout (constant_id = 1) const uint BM = 64;
-layout (constant_id = 2) const uint BN = 64;
-layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant
-layout (constant_id = 4) const uint WM = 32;
-layout (constant_id = 5) const uint WN = 32;
-layout (constant_id = 6) const uint WMITER = 2;
-layout (constant_id = 7) const uint TM = 4;
-layout (constant_id = 8) const uint TN = 2;
-layout (constant_id = 9) const uint WARP = 32;
-
-shared FLOAT_TYPE buf_a[BM * (BK+1)];
-shared FLOAT_TYPE buf_b[BN * (BK+1)];
-
-#ifdef MUL_MAT_ID
-shared u16vec2 row_ids[3072];
-#endif
-
-void main() {
-#ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.z;
-#else
-    const uint batch_idx = gl_GlobalInvocationID.z;
-
-    const uint i13 = batch_idx / p.ne12;
-    const uint i12 = batch_idx % p.ne12;
-
-    const uint i03 = i13 / p.broadcast3;
-    const uint i02 = i12 / p.broadcast2;
-
-    const uint batch_idx_a = i03 * p.ne02 + i02;
-#endif
-
-    const uint blocks_m = (p.M + BM - 1) / BM;
-    const uint ir = gl_WorkGroupID.x % blocks_m;
-    const uint ik = gl_WorkGroupID.x / blocks_m;
-    const uint ic = gl_WorkGroupID.y;
-
-    const uint warp_i = gl_LocalInvocationID.x / WARP;
-    const uint warp_r = warp_i % (BM / WM);
-    const uint warp_c = warp_i / (BM / WM);
-
-    const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
-    const uint WSUBM = WM / WMITER;
-    const uint WSUBN = WN / WNITER;
-
-    const uint tiw = gl_LocalInvocationID.x % WARP;
-    const uint tiwr = tiw % (WSUBM / TM);
-    const uint tiwc = tiw / (WSUBM / TM);
-
-    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
-    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
-    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
-    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
-
-    const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
-    const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
-
-#ifdef MUL_MAT_ID
-    uint _ne1 = 0;
-    for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
-        for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
-            if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
-                row_ids[_ne1] = u16vec2(ii0, ii1);
-                _ne1++;
-            }
-        }
-    }
-
-    barrier();
-
-    // Workgroup has no work
-    if (ic * BN >= _ne1) return;
-#endif
-
-#ifdef MUL_MAT_ID
-    const uint start_k = 0;
-    const uint end_k = p.K;
-#else
-    const uint start_k = ik * p.k_split;
-    const uint end_k = min(p.K, (ik + 1) * p.k_split);
-#endif
-
-    uint pos_a = (
-#ifdef MUL_MAT_ID
-        expert_idx * p.batch_stride_a +
-#else
-        batch_idx_a * p.batch_stride_a +
-#endif
-        ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
-#ifdef MUL_MAT_ID
-    uint pos_b = 0;
-#else
-    uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
-#endif
-
-    float sums[WMITER * TM * WNITER * TN];
-    FLOAT_TYPE cache_a[WMITER * TM];
-    FLOAT_TYPE cache_b[WNITER * TN];
-
-    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
-    }
-
-    [[unroll]] for (uint block = start_k; block < end_k; block += BK) {
-        [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
-
-#if defined(DATA_A_F32) || defined(DATA_A_F16)
-#if LOAD_VEC_A == 8
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-            buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx][0].x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
-            buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
-            buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
-            buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
-            buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
-            buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
-            buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
-#elif LOAD_VEC_A == 4
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-            buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx].x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
-            buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
-            buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
-#else
-            if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
-                buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
-            } else {
-                buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
-            }
-#endif
-#elif defined(DATA_A_Q4_0)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
-
-            const float d = float(data_a[ib].d);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
-
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
-#elif defined(DATA_A_Q4_1)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
-
-            const float d = float(data_a[ib].d);
-            const float m = float(data_a[ib].m);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
-
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
-#elif defined(DATA_A_Q5_0)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
-
-            const float d = float(data_a[ib].d);
-            const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
-            const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
-
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
-#elif defined(DATA_A_Q5_1)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
-
-            const float d = float(data_a[ib].d);
-            const float m = float(data_a[ib].m);
-            const uint uint_qh = data_a[ib].qh;
-            const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
-
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
-#elif defined(DATA_A_Q8_0)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 16;
-            const uint iqs = (idx & 0xF) * 2;
-
-            const float d = float(data_a[ib].d);
-            const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
-#elif defined(DATA_A_Q2_K)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128;                         // 2 values per idx
-            const uint iqs = idx % 128;                        // 0..127
-
-            const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
-            const uint scalesi = iqs / 8;                      // 0..15
-            const uint qsshift = ((iqs % 64) / 16) * 2;        // 0,2,4,6
-
-            const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
-            const uint scales = data_a[ib].scales[scalesi];
-            const vec2 d = vec2(data_a[ib].d);
-
-            const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
-#elif defined(DATA_A_Q3_K)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128;                   // 2 values per idx
-            const uint iqs = idx % 128;                  // 0..127
-
-            const uint n = iqs / 64;                     // 0,1
-            const uint qsi = n * 32 + (iqs % 16) * 2;    // 0,2,4..62
-            const uint hmi =          (iqs % 16) * 2;    // 0,2,4..30
-            const uint j = (iqs % 64) / 4;               // 0..3
-            const uint is = iqs / 8;                     // 0..15
-            const uint halfsplit = ((iqs % 64) / 16);    // 0,1,2,3
-            const uint qsshift = halfsplit * 2;          // 0,2,4,6
-            const uint m = 1 << (4 * n + halfsplit);     // 1,2,4,8,16,32,64,128
-
-            const int8_t us = int8_t(is <  4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) :
-                                    is <  8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) :
-                                    is < 12 ? (data_a[ib].scales[is-8] >>  4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) :
-                                            (data_a[ib].scales[is-8] >>  4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4));
-            const float dl = float(data_a[ib].d) * float(us - 32);
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi    ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi    ] & m) != 0) ? 0 : 4)));
-            buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
-#elif defined(DATA_A_Q4_K)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128;                 // 2 values per idx
-            const uint iqs = idx % 128;                // 0..127
-
-            const uint n = iqs / 32;                   // 0,1,2,3
-            const uint b = (iqs % 32) / 16;            // 0,1
-            const uint is = 2 * n + b;                 // 0..7
-            const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126
-
-            const vec2 loadd = vec2(data_a[ib].d);
-
-            uint8_t sc;
-            uint8_t mbyte;
-            if (is < 4) {
-                sc    = uint8_t(data_a[ib].scales[is    ] & 63);
-                mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
-            } else {
-                sc    = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
-                mbyte = uint8_t((data_a[ib].scales[is + 4] >>  4) | ((data_a[ib].scales[is    ] >> 6) << 4));
-            }
-            const float d = loadd.x * sc;
-            const float m = -loadd.y * mbyte;
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF), m));
-            buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
-#elif defined(DATA_A_Q5_K)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128;                 // 2 values per idx
-            const uint iqs = idx % 128;                // 0..127
-
-            const uint n = iqs / 32;                   // 0,1,2,3
-            const uint b = (iqs % 32) / 16;            // 0,1
-            const uint is = 2 * n + b;                 // 0..7
-            const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126
-            const uint qhi = (iqs % 16) * 2;           // 0,2,4..30
-
-            const uint8_t hm = uint8_t(1 << (iqs / 16));
-
-            const vec2 loadd = vec2(data_a[ib].d);
-
-            uint8_t sc;
-            uint8_t mbyte;
-            if (is < 4) {
-                sc    = uint8_t(data_a[ib].scales[is    ] & 63);
-                mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
-            } else {
-                sc    = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
-                mbyte = uint8_t((data_a[ib].scales[is + 4] >>  4) | ((data_a[ib].scales[is    ] >> 6) << 4));
-            }
-            const float d = loadd.x * sc;
-            const float m = -loadd.y * mbyte;
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi    ] & hm) != 0 ? 16 : 0), m));
-            buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
-#elif defined(DATA_A_Q6_K)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128;                  // 2 values per idx
-            const uint iqs = idx % 128;                 // 0..127
-
-            const uint n = iqs / 64;                    // 0,1
-            const uint b = (iqs % 64) / 32;             // 0,1
-            const uint is_b = (iqs % 16) / 8;           // 0,1
-            const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
-            const uint is = 8 * n + qhshift + is_b;     // 0..15
-            const uint qsi = n * 64 + (iqs % 32) * 2;   // 0,2,4..126
-            const uint qhi = n * 32 + (iqs % 16) * 2;   // 0,2,4..62
-
-            const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi    ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi    ] >> qhshift) & 3) << 4)) - 32));
-            buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
-#elif defined(DATA_A_IQ4_NL)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
-
-            const uint ib = idx / 16;
-            const uint iqs = idx & 0xF;
-
-            const float d = float(data_a[ib].d);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
-
-            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
-#endif
-        }
-        [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
-#if LOAD_VEC_B == 8
-#ifdef MUL_MAT_ID
-            const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
-            const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
-#else
-            const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
-#endif
-            const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
-            buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
-            buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
-            buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
-            buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w);
-            buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x);
-            buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y);
-            buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z);
-            buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
-#elif LOAD_VEC_B == 4
-#ifdef MUL_MAT_ID
-            const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
-            const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
-#else
-            const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
-#endif
-            const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
-            buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
-            buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
-            buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
-            buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
-#elif !MUL_MAT_ID
-            if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
-                buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
-            } else {
-                buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
-            }
-#else
-            const uint row_i = ic * BN + loadc_b + l;
-            if (row_i < _ne1) {
-                const u16vec2 row_idx = row_ids[row_i];
-                buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
-            } else {
-                buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
-            }
-#endif
-        }
-
-        barrier();
-
-        pos_a += BK / LOAD_VEC_A;
-        pos_b += BK / LOAD_VEC_B;
-
-        for (uint i = 0; i < BK; i++) {
-            // Load from shared into cache
-            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (uint j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (uint j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-                            sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]);
-                        }
-                    }
-                }
-            }
-        }
-
-        barrier();
-    }
-
-    const uint dr = ir * BM + warp_r * WM;
-    const uint dc = ic * BN + warp_c * WN;
-
-#ifndef MUL_MAT_ID
-    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
-#endif
-
-    [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-
-            const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
-            const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
-            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-#ifdef MUL_MAT_ID
-                const uint row_i = dc_warp + cc;
-                if (row_i >= _ne1) break;
-
-                const u16vec2 row_idx = row_ids[row_i];
-#endif
-                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-#ifdef MUL_MAT_ID
-                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
-#else
-                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
-                    }
-#endif
-                }
-            }
-        }
-    }
-}
diff --git a/ggml/src/vulkan-shaders/norm.comp b/ggml/src/vulkan-shaders/norm.comp
deleted file mode 100644 (file)
index 6627a50..0000000
+++ /dev/null
@@ -1,44 +0,0 @@
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-shared vec2 sum[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint tid = gl_LocalInvocationID.x;
-
-    sum[tid] = vec2(0.0f, 0.0f);
-
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        const float xi = float(data_a[row*p.KX + col]);
-        sum[tid].x += xi;
-        sum[tid].y += xi * xi;
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            sum[tid] += sum[tid + s];
-        }
-        barrier();
-    }
-
-    const float mean = sum[0].x / p.KX;
-    const float var = sum[0].y / p.KX - mean * mean;
-    const float inv_std = inversesqrt(var + p.param1);
-
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/pad.comp b/ggml/src/vulkan-shaders/pad.comp
deleted file mode 100644 (file)
index e87d8b1..0000000
+++ /dev/null
@@ -1,28 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
-    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
-    if (idx >= p.ne) {
-        return;
-    }
-
-    const uint i3 = idx / (p.ne12*p.ne11*p.ne10);
-    const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
-    const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10);
-    const uint i2_offset = i2*p.ne11*p.ne10;
-    const uint i1 = (idx - i3_offset - i2_offset) / p.ne10;
-    const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
-
-    const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
-    const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;
-
-    const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
-
-    data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : 0.0f);
-}
diff --git a/ggml/src/vulkan-shaders/pool2d.comp b/ggml/src/vulkan-shaders/pool2d.comp
deleted file mode 100644 (file)
index b612441..0000000
+++ /dev/null
@@ -1,74 +0,0 @@
-#version 450
-
-#include "types.comp"
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout(push_constant) uniform parameter {
-    uint IW; uint IH;
-    uint OW; uint OH;
-    uint OC;
-    uint pelements;
-    uint op;
-    int k0; int k1;
-    int s0; int s1;
-    int p0; int p1;
-} p;
-
-#define BLOCK_SIZE 512
-#define FLT_MAX 3.402823466e+38F
-#define OP_POOL_MAX 0u
-#define OP_POOL_AVG 1u
-
-layout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout(binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout(binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const uint idx = gl_GlobalInvocationID.x;
-    if (idx >= p.pelements) {
-        return;
-    }
-
-    const uint O_HW = p.OW * p.OH;
-
-    const uint nc = idx / O_HW;
-    const uint cur_oh = (idx % O_HW) / p.OW;
-    const uint cur_ow = (idx % O_HW) % p.OW;
-
-    const int start_h = int(cur_oh) * p.s0 - p.p0;
-    const uint bh = max(start_h, 0);
-    const uint eh = min(start_h + p.k0, p.IH);
-
-    const int start_w = int(cur_ow) * p.s1 - p.p1;
-    const uint bw = max(start_w, 0);
-    const uint ew = min(start_w + p.k1, p.IW);
-
-    const float scale = 1.0 / float(p.k0 * p.k1);
-    float res;
-
-    if (p.op == OP_POOL_AVG) {
-        res = 0.0;
-    } else if (p.op == OP_POOL_MAX) {
-        res = -FLT_MAX;
-    } else {
-        return;
-    }
-
-    #pragma unroll
-    for (uint i = bh; i < eh; i++) {
-        #pragma unroll
-        for (uint j = bw; j < ew; j++) {
-            const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]);
-
-            if (p.op == OP_POOL_AVG) {
-                res += cur * scale;
-            } else if (p.op == OP_POOL_MAX) {
-                res = max(res, cur);
-            }
-        }
-    }
-
-    data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res;
-}
diff --git a/ggml/src/vulkan-shaders/relu.comp b/ggml/src/vulkan-shaders/relu.comp
deleted file mode 100644 (file)
index 52a19b6..0000000
+++ /dev/null
@@ -1,21 +0,0 @@
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
-    if (i >= p.KX) {
-        return;
-    }
-
-    data_d[i] = max(float(data_a[i]), 0);
-}
diff --git a/ggml/src/vulkan-shaders/repeat.comp b/ggml/src/vulkan-shaders/repeat.comp
deleted file mode 100644 (file)
index c03f737..0000000
+++ /dev/null
@@ -1,26 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-uint src0_idx_mod(uint idx) {
-    const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
-    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
-    const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
-    const uint i12_offset = i12*p.ne11*p.ne10;
-    const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
-    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
-    return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00;
-}
-
-void main() {
-    const uint idx = get_idx();
-
-    if (idx >= p.ne) {
-        return;
-    }
-
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx_mod(idx)]);
-}
diff --git a/ggml/src/vulkan-shaders/rms_norm.comp b/ggml/src/vulkan-shaders/rms_norm.comp
deleted file mode 100644 (file)
index b554400..0000000
+++ /dev/null
@@ -1,42 +0,0 @@
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-shared FLOAT_TYPE sum[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint tid = gl_LocalInvocationID.x;
-
-    sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
-
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
-        sum[tid] += xi * xi;
-    }
-
-    // sum up partial sums and write back result
-    barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            sum[tid] += sum[tid + s];
-        }
-        barrier();
-    }
-
-    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
-    const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
-
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
-    }
-}
diff --git a/ggml/src/vulkan-shaders/rope_head.comp b/ggml/src/vulkan-shaders/rope_head.comp
deleted file mode 100644 (file)
index ea89542..0000000
+++ /dev/null
@@ -1,44 +0,0 @@
-#include "types.comp"
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {int data_pos[];};
-layout (binding = 2) readonly buffer Z {float data_ff[];};
-layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
-
-layout (push_constant) uniform parameter {
-    uint ncols;
-    uint n_dims;
-    float freq_scale;
-    uint p_delta_rows;
-    float freq_base;
-    float ext_factor;
-    float attn_factor;
-    float corr_dims[2];
-    float theta_scale;
-    uint has_ff;
-} p;
-
-float rope_yarn_ramp(const float low, const float high, const uint i0) {
-    const float y = (i0 / 2 - low) / max(0.001f, high - low);
-    return 1.0f - min(1.0f, max(0.0f, y));
-}
-
-void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) {
-    float mscale = p.attn_factor;
-    // Get n-d rotational scaling corrected for extrapolation
-    float theta_interp = p.freq_scale * theta_extrap;
-    float theta = theta_interp;
-    if (p.ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
-        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
-        // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
-    }
-    cos_theta = cos(theta) * mscale;
-    sin_theta = sin(theta) * mscale;
-}
diff --git a/ggml/src/vulkan-shaders/rope_neox.comp b/ggml/src/vulkan-shaders/rope_neox.comp
deleted file mode 100644 (file)
index 83b46b6..0000000
+++ /dev/null
@@ -1,37 +0,0 @@
-#version 450
-
-#include "rope_head.comp"
-
-void main() {
-    const uint col = gl_GlobalInvocationID.y * 2;
-    const uint row = gl_GlobalInvocationID.x;
-
-    if (col >= p.ncols) {
-        return;
-    }
-
-    if (col >= p.n_dims) {
-        const uint i = row*p.ncols + col;
-
-        data_d[i + 0] = data_a[i + 0];
-        data_d[i + 1] = data_a[i + 1];
-
-        return;
-    }
-
-    const uint i  = row*p.ncols + col/2;
-    const uint i2 = row/p.p_delta_rows;
-
-    const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
-
-    const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
-
-    const float x0 = float(data_a[i + 0]);
-    const float x1 = float(data_a[i + p.n_dims/2]);
-
-    data_d[i + 0]        = D_TYPE(x0*cos_theta - x1*sin_theta);
-    data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
-}
diff --git a/ggml/src/vulkan-shaders/rope_norm.comp b/ggml/src/vulkan-shaders/rope_norm.comp
deleted file mode 100644 (file)
index e416ad9..0000000
+++ /dev/null
@@ -1,37 +0,0 @@
-#version 450
-
-#include "rope_head.comp"
-
-void main() {
-    const uint col = gl_GlobalInvocationID.y * 2;
-    const uint row = gl_GlobalInvocationID.x;
-
-    if (col >= p.ncols) {
-        return;
-    }
-
-    if (col >= p.n_dims) {
-        const uint i = row*p.ncols + col;
-
-        data_d[i + 0] = data_a[i + 0];
-        data_d[i + 1] = data_a[i + 1];
-
-        return;
-    }
-
-    const uint i = row*p.ncols + col;
-    const uint i2 = row/p.p_delta_rows;
-
-    const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
-
-    const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
-
-    const float x0 = float(data_a[i + 0]);
-    const float x1 = float(data_a[i + 1]);
-
-    data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
-    data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
-}
diff --git a/ggml/src/vulkan-shaders/scale.comp b/ggml/src/vulkan-shaders/scale.comp
deleted file mode 100644 (file)
index 5cfee8c..0000000
+++ /dev/null
@@ -1,24 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-const uint num_threads = 128;
-
-layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
-    uint idx = get_idx();
-
-    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
-    const uint num_iter = 4;
-
-    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
-        if (idx >= p.ne) {
-            continue;
-        }
-
-        data_d[p.d_offset + idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(p.param1));
-        idx += num_threads;
-    }
-}
diff --git a/ggml/src/vulkan-shaders/silu.comp b/ggml/src/vulkan-shaders/silu.comp
deleted file mode 100644 (file)
index 4d36f88..0000000
+++ /dev/null
@@ -1,22 +0,0 @@
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
-    if (i >= p.KX) {
-        return;
-    }
-
-    const float xi = float(data_a[i]);
-    data_d[i] = D_TYPE(xi / (1.0f + exp(-xi)));
-}
diff --git a/ggml/src/vulkan-shaders/sin.comp b/ggml/src/vulkan-shaders/sin.comp
deleted file mode 100644 (file)
index 67c48fb..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
-    const uint idx = get_idx();
-
-    if (idx >= p.ne) {
-        return;
-    }
-
-    const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(sin(val));
-}
diff --git a/ggml/src/vulkan-shaders/soft_max.comp b/ggml/src/vulkan-shaders/soft_max.comp
deleted file mode 100644 (file)
index 0bd51ec..0000000
+++ /dev/null
@@ -1,106 +0,0 @@
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
-    uint KX;
-    uint KY;
-    float scale;
-    float max_bias;
-    float m0;
-    float m1;
-    uint n_head_log2;
-} p;
-
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
-layout (binding = 2) buffer D {D_TYPE data_d[];};
-
-shared FLOAT_TYPE vals[BLOCK_SIZE];
-
-void main() {
-    const uint tid = gl_LocalInvocationID.x;
-    const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint rowy = rowx % p.KY;
-
-    float slope = 1.0f;
-
-    // ALiBi
-    if (p.max_bias > 0.0f) {
-        const uint h = rowx/p.KY; // head index
-
-        const float base = h < p.n_head_log2 ? p.m0 : p.m1;
-        const uint   exp  = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
-
-        slope = pow(base, exp);
-    }
-
-    // Find max
-    FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
-
-    [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
-        const uint col = col0 + tid;
-
-        if (col >= p.KX) {
-            break;
-        }
-
-        max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)));
-    }
-    vals[tid] = max_val;
-
-    barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            vals[tid] = max(vals[tid], vals[tid + s]);
-        }
-        barrier();
-    }
-
-    max_val = vals[0];
-    barrier();
-
-    // Sum up values
-    vals[tid] = FLOAT_TYPE(0.0f);
-
-    [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
-        const uint col = col0 + tid;
-
-        if (col >= p.KX) {
-            break;
-        }
-
-        const uint i = rowx * p.KX + col;
-        const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
-        vals[tid] += val;
-        data_d[i] = D_TYPE(val);
-    }
-
-    barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            vals[tid] += vals[tid + s];
-        }
-        barrier();
-    }
-
-    const D_TYPE divisor = D_TYPE(vals[0]);
-
-    [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
-        const uint col = col0 + tid;
-
-        if (col >= p.KX) {
-            break;
-        }
-
-        data_d[rowx*p.KX + col] /= divisor;
-    }
-}
diff --git a/ggml/src/vulkan-shaders/square.comp b/ggml/src/vulkan-shaders/square.comp
deleted file mode 100644 (file)
index 2ff48dd..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-#version 450
-
-#include "types.comp"
-#include "generic_unary_head.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
-    const uint idx = get_idx();
-
-    if (idx >= p.ne) {
-        return;
-    }
-
-    const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val * val);
-}
diff --git a/ggml/src/vulkan-shaders/sum_rows.comp b/ggml/src/vulkan-shaders/sum_rows.comp
deleted file mode 100644 (file)
index 961e5ff..0000000
+++ /dev/null
@@ -1,37 +0,0 @@
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint col = gl_LocalInvocationID.x;
-
-    tmp[col] = FLOAT_TYPE(0.0f);
-
-    for (uint i = col; i < p.KX; i += BLOCK_SIZE) {
-        tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]);
-    }
-
-    barrier();
-    [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
-        if (col < s) {
-            tmp[col] += tmp[col + s];
-        }
-        barrier();
-    }
-
-    if (col == 0) {
-        data_d[row] = D_TYPE(tmp[0]);
-    }
-}
diff --git a/ggml/src/vulkan-shaders/tanh.comp b/ggml/src/vulkan-shaders/tanh.comp
deleted file mode 100644 (file)
index 74630dc..0000000
+++ /dev/null
@@ -1,21 +0,0 @@
-#version 450
-
-#include "generic_head.comp"
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
-    if (i >= p.KX) {
-        return;
-    }
-
-    data_d[i] = D_TYPE(tanh(data_a[i]));
-}
diff --git a/ggml/src/vulkan-shaders/timestep_embedding.comp b/ggml/src/vulkan-shaders/timestep_embedding.comp
deleted file mode 100644 (file)
index 79e065a..0000000
+++ /dev/null
@@ -1,41 +0,0 @@
-#version 450
-
-#extension GL_EXT_shader_16bit_storage : require
-
-layout (push_constant) uniform parameter
-{
-    uint nb1;
-    uint dim;
-    uint max_period;
-} p;
-
-#include "types.comp"
-
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 256
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const uint i = gl_WorkGroupID.y;
-    const uint j = gl_GlobalInvocationID.x;
-    const uint d_offset = i * p.nb1;
-
-    if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) {
-        data_d[d_offset + p.dim] = 0.f;
-    }
-
-    const uint half_dim = p.dim / 2;
-    if (j >= half_dim) {
-        return;
-    }
-
-    const float timestep = float(data_a[i]);
-    const float freq = float(exp(-log(p.max_period) * j / half_dim));
-    const float arg = timestep * freq;
-    data_d[d_offset + j] = D_TYPE(cos(arg));
-    data_d[d_offset + j + half_dim] = D_TYPE(sin(arg));
-}
diff --git a/ggml/src/vulkan-shaders/types.comp b/ggml/src/vulkan-shaders/types.comp
deleted file mode 100644 (file)
index 21dce72..0000000
+++ /dev/null
@@ -1,200 +0,0 @@
-#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
-#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
-#endif
-
-#if defined(DATA_A_F32)
-#define QUANT_K 1
-#define QUANT_R 1
-
-#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
-#define A_TYPE float
-#elif LOAD_VEC_A == 4
-#define A_TYPE vec4
-#elif LOAD_VEC_A == 8
-#define A_TYPE mat2x4
-#endif
-#endif
-
-#if defined(DATA_A_F16)
-#define QUANT_K 1
-#define QUANT_R 1
-
-#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
-#define A_TYPE float16_t
-#elif LOAD_VEC_A == 4
-#define A_TYPE f16vec4
-#elif LOAD_VEC_A == 8
-#define A_TYPE f16mat2x4
-#endif
-#endif
-
-#if defined(DATA_A_Q4_0)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q4_0
-{
-    float16_t d;
-    uint8_t qs[16];
-};
-
-#define A_TYPE block_q4_0
-#endif
-
-#if defined(DATA_A_Q4_1)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q4_1
-{
-    float16_t d;
-    float16_t m;
-    uint8_t qs[16];
-};
-
-#define A_TYPE block_q4_1
-#endif
-
-#if defined(DATA_A_Q5_0)
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q5_0
-{
-    float16_t d;
-    uint16_t qh[2];
-    uint8_t qs[16];
-};
-
-#define A_TYPE block_q5_0
-#endif
-
-#if defined(DATA_A_Q5_1)
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q5_1
-{
-    float16_t d;
-    float16_t m;
-    uint qh;
-    uint8_t qs[16];
-};
-
-#define A_TYPE block_q5_1
-#endif
-
-#if defined(DATA_A_Q8_0)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 32
-#define QUANT_R 1
-
-struct block_q8_0
-{
-    float16_t d;
-    int8_t qs[32];
-};
-
-#define A_TYPE block_q8_0
-#endif
-
-// K-quants
-#if defined(DATA_A_Q2_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
-
-struct block_q2_K
-{
-    uint8_t scales[QUANT_K/16];
-    uint8_t qs[QUANT_K/4];
-    f16vec2 d;
-};
-
-#define A_TYPE block_q2_K
-#endif
-
-#if defined(DATA_A_Q3_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
-
-struct block_q3_K
-{
-    uint8_t hmask[QUANT_K/8];
-    uint8_t qs[QUANT_K/4];
-    uint8_t scales[12];
-    float16_t d;
-};
-
-#define A_TYPE block_q3_K
-#endif
-
-#if defined(DATA_A_Q4_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
-
-struct block_q4_K
-{
-    f16vec2 d;
-    uint8_t scales[3*QUANT_K/64];
-    uint8_t qs[QUANT_K/2];
-};
-
-#define A_TYPE block_q4_K
-#endif
-
-#if defined(DATA_A_Q5_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
-
-struct block_q5_K
-{
-    f16vec2 d;
-    uint8_t scales[12];
-    uint8_t qh[QUANT_K/8];
-    uint8_t qs[QUANT_K/2];
-};
-
-#define A_TYPE block_q5_K
-#endif
-
-#if defined(DATA_A_Q6_K)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 256
-
-struct block_q6_K
-{
-    uint8_t ql[QUANT_K/2];
-    uint8_t qh[QUANT_K/4];
-    int8_t scales[QUANT_K/16];
-    float16_t d;
-};
-
-#define A_TYPE block_q6_K
-#endif
-
-// IQuants
-
-#if defined(DATA_A_IQ4_NL)
-#extension GL_EXT_shader_16bit_storage : require
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_iq4_nl
-{
-    float16_t d;
-    uint8_t qs[QUANT_K/2];
-};
-
-#define A_TYPE block_iq4_nl
-
-const int8_t kvalues_iq4nl[16] = {
-    int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
-    int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
-};
-#endif
diff --git a/ggml/src/vulkan-shaders/upscale.comp b/ggml/src/vulkan-shaders/upscale.comp
deleted file mode 100644 (file)
index 511a086..0000000
+++ /dev/null
@@ -1,36 +0,0 @@
-#version 450
-
-layout (push_constant) uniform parameter
-{
-    uint ne; uint d_offset;
-    uint nb00; uint nb01; uint nb02; uint nb03;
-    uint ne10; uint ne11; uint ne12; uint ne13;
-    float sf0; float sf1; float sf2; float sf3;
-} p;
-
-#include "types.comp"
-
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
-void main() {
-    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
-
-    if (idx >= p.ne) {
-        return;
-    }
-
-    const uint i10 = idx % p.ne10;
-    const uint i11 = (idx / p.ne10) % p.ne11;
-    const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
-    const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
-
-    const uint i00 = uint(i10 / p.sf0);
-    const uint i01 = uint(i11 / p.sf1);
-    const uint i02 = uint(i12 / p.sf2);
-    const uint i03 = uint(i13 / p.sf3);
-
-    data_d[p.d_offset + idx] = D_TYPE(data_a[i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
-}
diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
deleted file mode 100644 (file)
index 5c84f47..0000000
+++ /dev/null
@@ -1,519 +0,0 @@
-
-
-#include <iostream>
-#include <fstream>
-#include <sstream>
-#include <string>
-#include <stdexcept>
-#include <array>
-#include <vector>
-#include <map>
-#include <thread>
-#include <mutex>
-#include <future>
-#include <queue>
-#include <condition_variable>
-#include <cstdio>
-#include <cstring>
-#include <cstdlib>
-#include <cassert>
-#include <sys/stat.h>
-#include <sys/types.h>
-
-#ifdef _WIN32
-    #include <windows.h>
-    #include <direct.h> // For _mkdir on Windows
-    #include <algorithm> // For std::replace on w64devkit
-#else
-    #include <unistd.h>
-    #include <sys/wait.h>
-    #include <fcntl.h>
-#endif
-
-#define ASYNCIO_CONCURRENCY 64
-
-std::mutex lock;
-std::vector<std::pair<std::string, std::string>> shader_fnames;
-
-std::string GLSLC = "glslc";
-std::string input_dir = "vulkan-shaders";
-std::string output_dir = "/tmp";
-std::string target_hpp = "ggml-vulkan-shaders.hpp";
-std::string target_cpp = "ggml-vulkan-shaders.cpp";
-bool no_clean = false;
-
-const std::vector<std::string> type_names = {
-    "f32",
-    "f16",
-    "q4_0",
-    "q4_1",
-    "q5_0",
-    "q5_1",
-    "q8_0",
-    "q2_k",
-    "q3_k",
-    "q4_k",
-    "q5_k",
-    "q6_k",
-    "iq4_nl"
-};
-
-void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
-#ifdef _WIN32
-    HANDLE stdout_read, stdout_write;
-    HANDLE stderr_read, stderr_write;
-    SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };
-
-    if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) ||
-        !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {
-        throw std::runtime_error("Failed to create stdout pipe");
-    }
-
-    if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) ||
-        !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {
-        throw std::runtime_error("Failed to create stderr pipe");
-    }
-
-    PROCESS_INFORMATION pi;
-    STARTUPINFOA si = { sizeof(STARTUPINFOA) };
-    si.dwFlags = STARTF_USESTDHANDLES;
-    si.hStdOutput = stdout_write;
-    si.hStdError = stderr_write;
-
-    std::vector<char> cmd(command.begin(), command.end());
-    cmd.push_back('\0');
-
-    if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
-        throw std::runtime_error("Failed to create process");
-    }
-
-    CloseHandle(stdout_write);
-    CloseHandle(stderr_write);
-
-    std::array<char, 128> buffer;
-    DWORD bytes_read;
-
-    while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
-        stdout_str.append(buffer.data(), bytes_read);
-    }
-
-    while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
-        stderr_str.append(buffer.data(), bytes_read);
-    }
-
-    CloseHandle(stdout_read);
-    CloseHandle(stderr_read);
-    WaitForSingleObject(pi.hProcess, INFINITE);
-    CloseHandle(pi.hProcess);
-    CloseHandle(pi.hThread);
-#else
-int stdout_pipe[2];
-    int stderr_pipe[2];
-
-    if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
-        throw std::runtime_error("Failed to create pipes");
-    }
-
-    pid_t pid = fork();
-    if (pid < 0) {
-        throw std::runtime_error("Failed to fork process");
-    }
-
-    if (pid == 0) {
-        close(stdout_pipe[0]);
-        close(stderr_pipe[0]);
-        dup2(stdout_pipe[1], STDOUT_FILENO);
-        dup2(stderr_pipe[1], STDERR_FILENO);
-        close(stdout_pipe[1]);
-        close(stderr_pipe[1]);
-        execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr);
-        _exit(EXIT_FAILURE);
-    } else {
-        close(stdout_pipe[1]);
-        close(stderr_pipe[1]);
-
-        std::array<char, 128> buffer;
-        ssize_t bytes_read;
-
-        while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) {
-            stdout_str.append(buffer.data(), bytes_read);
-        }
-
-        while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) {
-            stderr_str.append(buffer.data(), bytes_read);
-        }
-
-        close(stdout_pipe[0]);
-        close(stderr_pipe[0]);
-        waitpid(pid, nullptr, 0);
-    }
-#endif
-}
-
-bool directory_exists(const std::string& path) {
-    struct stat info;
-    if (stat(path.c_str(), &info) != 0) {
-        return false; // Path doesn't exist or can't be accessed
-    }
-    return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory
-}
-
-bool create_directory(const std::string& path) {
-#ifdef _WIN32
-    return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists
-#else
-    return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions
-#endif
-}
-
-std::string to_uppercase(const std::string& input) {
-    std::string result = input;
-    for (char& c : result) {
-        c = std::toupper(c);
-    }
-    return result;
-}
-
-bool string_ends_with(const std::string& str, const std::string& suffix) {
-    if (suffix.size() > str.size()) {
-        return false;
-    }
-    return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
-}
-
-static const char path_separator = '/';
-
-std::string join_paths(const std::string& path1, const std::string& path2) {
-    return path1 + path_separator + path2;
-}
-
-std::string basename(const std::string &path) {
-    return path.substr(path.find_last_of("/\\") + 1);
-}
-
-// variables to track number of compiles in progress
-static uint32_t compile_count = 0;
-static std::mutex compile_count_mutex;
-static std::condition_variable compile_count_cond;
-
-void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
-    std::string name = _name + (fp16 ? "" : "_fp32");
-    std::string out_fname = join_paths(output_dir, name + ".spv");
-    std::string in_path = join_paths(input_dir, in_fname);
-
-    #ifdef _WIN32
-        std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
-    #else
-        std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o",  out_fname};
-    #endif
-
-    #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
-        cmd.push_back("-g");
-    #endif
-
-    for (const auto& define : defines) {
-        cmd.push_back("-D" + define.first + "=" + define.second);
-    }
-
-    std::string command;
-    for (const auto& part : cmd) {
-        command += part + " ";
-    }
-
-    std::string stdout_str, stderr_str;
-    try {
-        // std::cout << "Executing command: ";
-        // for (const auto& part : cmd) {
-        //     std::cout << part << " ";
-        // }
-        // std::cout << std::endl;
-
-        execute_command(command, stdout_str, stderr_str);
-        if (!stderr_str.empty()) {
-            std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
-            return;
-        }
-
-        std::lock_guard<std::mutex> guard(lock);
-        shader_fnames.push_back(std::make_pair(name, out_fname));
-    } catch (const std::exception& e) {
-        std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
-    }
-    {
-        std::lock_guard<std::mutex> guard(compile_count_mutex);
-        assert(compile_count > 0);
-        compile_count--;
-    }
-    compile_count_cond.notify_all();
-}
-
-std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
-    std::map<std::string, std::string> result = a;
-    result.insert(b.begin(), b.end());
-    return result;
-}
-
-static std::vector<std::future<void>> compiles;
-void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
-    {
-        // wait until fewer than N compiles are in progress.
-        // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
-        uint32_t N = 16;
-        std::unique_lock<std::mutex> guard(compile_count_mutex);
-        while (compile_count >= N) {
-            compile_count_cond.wait(guard);
-        }
-        compile_count++;
-    }
-    compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16));
-}
-
-void matmul_shaders(bool fp16, bool matmul_id) {
-    std::string load_vec = fp16 ? "8" : "4";
-    std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4";
-    std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4";
-
-    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}};
-    std::string shader_name = "matmul";
-
-    if (matmul_id) {
-        base_dict["MUL_MAT_ID"] = "1";
-        shader_name = "matmul_id";
-    }
-
-    if (fp16) {
-        base_dict["FLOAT16"] = "1";
-    }
-
-    // Shaders with f16 B_TYPE
-    string_to_spv(shader_name + "_f32_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
-    string_to_spv(shader_name + "_f32_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
-
-    string_to_spv(shader_name + "_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
-    string_to_spv(shader_name + "_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
-
-    for (const auto& tname : type_names) {
-        std::string data_a_key = "DATA_A_" + to_uppercase(tname);
-        // For unaligned, load one at a time for f32/f16, or two at a time for quants
-        std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2";
-        // For aligned matmul loads
-        std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
-        string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
-        string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
-    }
-}
-
-void process_shaders() {
-    std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
-    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
-
-    for (const auto& fp16 : {false, true}) {
-        matmul_shaders(fp16, false);
-        matmul_shaders(fp16, true);
-    }
-
-    for (const auto& tname : type_names) {
-        // mul mat vec
-        std::string data_a_key = "DATA_A_" + to_uppercase(tname);
-        std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
-
-        string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
-        string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
-
-        string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
-
-        // Dequant shaders
-        if (tname != "f16") {
-            string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
-        }
-
-        if (!string_ends_with(tname, "_k")) {
-            shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
-
-            if (tname == "f16") {
-                string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
-            } else {
-                string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}});
-            }
-            string_to_spv("get_rows_" + tname + "_f32", shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
-        }
-    }
-
-    string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
-
-    // Norms
-    string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-    string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-    string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-
-    string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
-    string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
-    string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
-    string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
-
-    string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-    string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
-
-    string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
-    string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
-
-    string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
-    string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
-    string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-
-    string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
-    string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
-    string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
-    string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
-    string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-
-    string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-
-    string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
-    string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
-
-    string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
-
-    string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-
-    string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-
-    string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
-    string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
-
-    string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
-
-    string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
-
-    string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
-
-    string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-
-    string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-    string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
-
-    string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-
-    string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-
-    for (auto &c : compiles) {
-        c.wait();
-    }
-}
-
-void write_output_files() {
-    FILE* hdr = fopen(target_hpp.c_str(), "w");
-    FILE* src = fopen(target_cpp.c_str(), "w");
-
-    fprintf(hdr, "#include <cstdint>\n\n");
-    fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
-
-    for (const auto& pair : shader_fnames) {
-        const std::string& name = pair.first;
-        #ifdef _WIN32
-            std::string path = pair.second;
-            std::replace(path.begin(), path.end(), '/', '\\' );
-        #else
-            const std::string& path = pair.second;
-        #endif
-
-        FILE* spv = fopen(path.c_str(), "rb");
-        if (!spv) {
-            std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
-            continue;
-        }
-
-        fseek(spv, 0, SEEK_END);
-        size_t size = ftell(spv);
-        fseek(spv, 0, SEEK_SET);
-
-        std::vector<unsigned char> data(size);
-        size_t read_size = fread(data.data(), 1, size, spv);
-        fclose(spv);
-        if (read_size != size) {
-            std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
-            continue;
-        }
-
-        fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size);
-        fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size);
-
-        fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size);
-        for (size_t i = 0; i < size; ++i) {
-            fprintf(src, "0x%02x,", data[i]);
-            if ((i + 1) % 12 == 0) fprintf(src, "\n");
-        }
-        fprintf(src, "\n};\n\n");
-
-        if (!no_clean) {
-            std::remove(path.c_str());
-        }
-    }
-
-    fclose(hdr);
-    fclose(src);
-}
-
-int main(int argc, char** argv) {
-    std::map<std::string, std::string> args;
-    for (int i = 1; i < argc; i += 2) {
-        if (i + 1 < argc) {
-            args[argv[i]] = argv[i + 1];
-        }
-    }
-
-    if (args.find("--glslc") != args.end()) {
-        GLSLC = args["--glslc"]; // Path to glslc
-    }
-    if (args.find("--input-dir") != args.end()) {
-        input_dir = args["--input-dir"]; // Directory containing shader sources
-    }
-    if (args.find("--output-dir") != args.end()) {
-        output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
-    }
-    if (args.find("--target-hpp") != args.end()) {
-        target_hpp = args["--target-hpp"]; // Path to generated header file
-    }
-    if (args.find("--target-cpp") != args.end()) {
-        target_cpp = args["--target-cpp"]; // Path to generated cpp file
-    }
-    if (args.find("--no-clean") != args.end()) {
-        no_clean = true; // Keep temporary SPIR-V files in output-dir after build
-    }
-
-    if (!directory_exists(input_dir)) {
-        std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl;
-        return EXIT_FAILURE;
-    }
-
-    if (!directory_exists(output_dir)) {
-        if (!create_directory(output_dir)) {
-            std::cerr << "Error creating output directory: " << output_dir << "\n";
-            return EXIT_FAILURE;
-        }
-    }
-
-    process_shaders();
-
-    write_output_files();
-
-    return EXIT_SUCCESS;
-}