]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync latest llama.cpp
authorGeorgi Gerganov <redacted>
Mon, 19 Jun 2023 17:35:08 +0000 (20:35 +0300)
committerGeorgi Gerganov <redacted>
Mon, 19 Jun 2023 17:35:08 +0000 (20:35 +0300)
examples/common-ggml.cpp
include/ggml/ggml.h
scripts/sync-llama.sh
src/ggml-cuda.cu
src/ggml-cuda.h
src/ggml-metal.h [new file with mode: 0644]
src/ggml-metal.m [new file with mode: 0644]
src/ggml-metal.metal [new file with mode: 0644]
src/ggml-opencl.cpp [new file with mode: 0644]
src/ggml-opencl.h
src/ggml.c

index 9215dbeabd195339408f5ccb538f1a2dbcde5f3e..33ae03ae10ff419dd18b336e5e7d9406afc94d36 100644 (file)
@@ -52,6 +52,11 @@ bool ggml_common_quantize_0(
         case GGML_FTYPE_ALL_F32:
         case GGML_FTYPE_MOSTLY_F16:
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
+        case GGML_FTYPE_MOSTLY_Q2_K:
+        case GGML_FTYPE_MOSTLY_Q3_K:
+        case GGML_FTYPE_MOSTLY_Q4_K:
+        case GGML_FTYPE_MOSTLY_Q5_K:
+        case GGML_FTYPE_MOSTLY_Q6_K:
                 {
                     fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
                     return false;
@@ -187,6 +192,12 @@ bool ggml_common_quantize_0(
                 case GGML_TYPE_I16:
                 case GGML_TYPE_I32:
                 case GGML_TYPE_Q8_1:
+                case GGML_TYPE_Q2_K:
+                case GGML_TYPE_Q3_K:
+                case GGML_TYPE_Q4_K:
+                case GGML_TYPE_Q5_K:
+                case GGML_TYPE_Q6_K:
+                case GGML_TYPE_Q8_K:
                 case GGML_TYPE_COUNT:
                     {
                         fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
index 3b83fe6eaf87991ca864bc5f9a6271c59ca3cbf9..18c78551f3dcd1283c0aaad586eb7ac254216bf2 100644 (file)
@@ -241,6 +241,13 @@ extern "C" {
         GGML_TYPE_Q5_1 = 7,
         GGML_TYPE_Q8_0 = 8,
         GGML_TYPE_Q8_1 = 9,
+        // k-quantizations
+        GGML_TYPE_Q2_K = 10,
+        GGML_TYPE_Q3_K = 11,
+        GGML_TYPE_Q4_K = 12,
+        GGML_TYPE_Q5_K = 13,
+        GGML_TYPE_Q6_K = 14,
+        GGML_TYPE_Q8_K = 15,
         GGML_TYPE_I8,
         GGML_TYPE_I16,
         GGML_TYPE_I32,
@@ -249,8 +256,8 @@ extern "C" {
 
     enum ggml_backend {
         GGML_BACKEND_CPU = 0,
-        GGML_BACKEND_CUDA = 1,
-        GGML_BACKEND_CL = 2,
+        GGML_BACKEND_GPU = 10,
+        GGML_BACKEND_GPU_SPLIT = 20,
     };
 
     // model file types
@@ -264,6 +271,11 @@ extern "C" {
         GGML_FTYPE_MOSTLY_Q8_0 = 7,  // except 1d tensors
         GGML_FTYPE_MOSTLY_Q5_0 = 8,  // except 1d tensors
         GGML_FTYPE_MOSTLY_Q5_1 = 9,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
     };
 
     // available tensor operations:
@@ -284,6 +296,7 @@ extern "C" {
         GGML_OP_SUM_ROWS,
         GGML_OP_MEAN,
         GGML_OP_REPEAT,
+        GGML_OP_REPEAT_BACK,
         GGML_OP_ABS,
         GGML_OP_SGN,
         GGML_OP_NEG,
@@ -298,6 +311,7 @@ extern "C" {
         GGML_OP_RMS_NORM_BACK,
 
         GGML_OP_MUL_MAT,
+        GGML_OP_OUT_PROD,
 
         GGML_OP_SCALE,
         GGML_OP_SET,
@@ -313,6 +327,7 @@ extern "C" {
         GGML_OP_DIAG_MASK_INF,
         GGML_OP_DIAG_MASK_ZERO,
         GGML_OP_SOFT_MAX,
+        GGML_OP_SOFT_MAX_BACK,
         GGML_OP_ROPE,
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
@@ -323,12 +338,16 @@ extern "C" {
 
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
+        GGML_OP_FLASH_ATTN_BACK,
         GGML_OP_WIN_PART,
         GGML_OP_WIN_UNPART,
 
         GGML_OP_MAP_UNARY,
         GGML_OP_MAP_BINARY,
 
+        GGML_OP_CROSS_ENTROPY_LOSS,
+        GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+
         GGML_OP_COUNT,
     };
 
@@ -379,7 +398,9 @@ extern "C" {
 
         char name[GGML_MAX_NAME];
 
-        char padding[16];
+        void * extra; // extra things e.g. for ggml-cuda.cu
+
+        char padding[4];
     };
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@@ -417,6 +438,25 @@ extern "C" {
         bool   no_alloc;   // don't allocate memory for the tensor data
     };
 
+
+    // compute types
+    enum ggml_task_type {
+        GGML_TASK_INIT = 0,
+        GGML_TASK_COMPUTE,
+        GGML_TASK_FINALIZE,
+    };
+
+    struct ggml_compute_params {
+        enum ggml_task_type type;
+
+        // ith = thread index, nth = number of threads
+        int ith, nth;
+
+        // work buffer for all threads
+        size_t wsize;
+        void * wdata;
+    };
+
     // misc
 
     GGML_API void    ggml_time_init(void); // call this once at the beginning of the program
@@ -428,8 +468,10 @@ extern "C" {
     GGML_API void    ggml_print_object (const struct ggml_object * obj);
     GGML_API void    ggml_print_objects(const struct ggml_context * ctx);
 
-    GGML_API int64_t ggml_nelements(const struct ggml_tensor * tensor);
-    GGML_API size_t  ggml_nbytes   (const struct ggml_tensor * tensor);
+    GGML_API int64_t ggml_nelements   (const struct ggml_tensor * tensor);
+    GGML_API int64_t ggml_nrows       (const struct ggml_tensor * tensor);
+    GGML_API size_t  ggml_nbytes      (const struct ggml_tensor * tensor);
+    GGML_API size_t  ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split);
 
     GGML_API int     ggml_blck_size (enum ggml_type type);
     GGML_API size_t  ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
@@ -445,21 +487,26 @@ extern "C" {
     // TODO: temporary until model loading of ggml examples is refactored
     GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
 
+    GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_permuted  (const struct ggml_tensor * tensor);
+
     // use this to compute the memory overhead of a tensor
     GGML_API size_t ggml_tensor_overhead(void);
 
     // main
 
     GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
-    GGML_API void    ggml_free(struct ggml_context * ctx);
+    GGML_API void                  ggml_free(struct ggml_context * ctx);
 
     GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);
 
     GGML_API size_t  ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
     GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
 
-    GGML_API void *  ggml_get_mem_buffer(struct ggml_context * ctx);
-    GGML_API size_t  ggml_get_mem_size  (struct ggml_context * ctx);
+    GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);
+    GGML_API size_t  ggml_get_mem_size       (const struct ggml_context * ctx);
+    GGML_API size_t  ggml_get_max_tensor_size(const struct ggml_context * ctx);
 
     GGML_API struct ggml_tensor * ggml_new_tensor(
             struct ggml_context * ctx,
@@ -639,6 +686,11 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_repeat_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     GGML_API struct ggml_tensor * ggml_abs(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
@@ -736,14 +788,22 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
-    // A: m rows, n columns
-    // B: p rows, n columns (i.e. we transpose it internally)
+    // A: n columns, m rows
+    // B: n columns, p rows  (i.e. we transpose it internally)
     // result is m columns, p rows
     GGML_API struct ggml_tensor * ggml_mul_mat(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // A: m columns, n rows,
+    // B: p columns, n rows,
+    // result is m columns, p rows
+    GGML_API struct ggml_tensor * ggml_out_prod(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     //
     // operations on tensors without backpropagation
     //
@@ -954,6 +1014,17 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_soft_max_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // rotary position embedding
     // if mode & 1 == 1, skip n_past elements
     // if mode & 2 == 1, GPT-NeoX style
@@ -1059,6 +1130,14 @@ extern "C" {
             struct ggml_tensor  * v,
             bool                  masked);
 
+    GGML_API struct ggml_tensor * ggml_flash_attn_back(
+           struct ggml_context * ctx,
+           struct ggml_tensor  * q,
+           struct ggml_tensor  * k,
+           struct ggml_tensor  * v,
+           struct ggml_tensor  * d,
+           bool                  masked);
+
     GGML_API struct ggml_tensor * ggml_flash_ff(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1102,6 +1181,19 @@ extern "C" {
             struct ggml_tensor          * b,
                    ggml_binary_op_f32_t   fun);
 
+    // loss function
+
+    GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
+            struct ggml_context         * ctx,
+            struct ggml_tensor          * a,
+            struct ggml_tensor          * b);
+
+    GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
+            struct ggml_context         * ctx,
+            struct ggml_tensor          * a,
+            struct ggml_tensor          * b,
+            struct ggml_tensor          * c);
+
     //
     // automatic differentiation
     //
@@ -1196,6 +1288,8 @@ extern "C" {
         struct {
             int n_iter;
 
+            float sched; // schedule multiplier (fixed, decay or warmup)
+            float decay; // weight decay for AdamW, use 0.0f to disable
             float alpha; // learning rate
             float beta1;
             float beta2;
@@ -1220,6 +1314,49 @@ extern "C" {
         } lbfgs;
     };
 
+    struct ggml_opt_context {
+        struct ggml_context * ctx;
+        struct ggml_opt_params params;
+
+        int iter;
+        int64_t nx; // number of parameter elements
+
+        bool just_initialized;
+
+        struct {
+            struct ggml_tensor * x;  // view of the parameters
+            struct ggml_tensor * g1; // gradient
+            struct ggml_tensor * g2; // gradient squared
+            struct ggml_tensor * m;  // first moment
+            struct ggml_tensor * v;  // second moment
+            struct ggml_tensor * mh; // first moment hat
+            struct ggml_tensor * vh; // second moment hat
+            struct ggml_tensor * pf; // past function values
+            float fx_best;
+            float fx_prev;
+            int n_no_improvement;
+        } adam;
+
+        struct {
+            struct ggml_tensor * x;    // current parameters
+            struct ggml_tensor * xp;   // previous parameters
+            struct ggml_tensor * g;    // current gradient
+            struct ggml_tensor * gp;   // previous gradient
+            struct ggml_tensor * d;    // search direction
+            struct ggml_tensor * pf;   // past function values
+            struct ggml_tensor * lmal; // the L-BFGS memory alpha
+            struct ggml_tensor * lmys; // the L-BFGS memory ys
+            struct ggml_tensor * lms;  // the L-BFGS memory s
+            struct ggml_tensor * lmy;  // the L-BFGS memory y
+            float fx_best;
+            float step;
+            int j;
+            int k;
+            int end;
+            int n_no_improvement;
+        } lbfgs;
+    };
+
     GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
 
     // optimize the function defined by the tensor f
@@ -1228,6 +1365,27 @@ extern "C" {
             struct ggml_opt_params params,
             struct ggml_tensor * f);
 
+    // initialize optimizer context
+    GGML_API void ggml_opt_init(
+            struct ggml_context * ctx,
+            struct ggml_opt_context * opt,
+            struct ggml_opt_params params,
+            int64_t nx);
+
+    // continue optimizing the function defined by the tensor f
+    GGML_API enum ggml_opt_result ggml_opt_resume(
+            struct ggml_context * ctx,
+            struct ggml_opt_context * opt,
+            struct ggml_tensor * f);
+
+    // continue optimizing the function defined by the tensor f
+    GGML_API enum ggml_opt_result ggml_opt_resume_g(
+            struct ggml_context * ctx,
+            struct ggml_opt_context * opt,
+            struct ggml_tensor * f,
+            struct ggml_cgraph * gf,
+            struct ggml_cgraph * gb);
+
     //
     // quantization
     //
index 85c45baab2ea610f0531cb3dcad62e573e201d6d..9bccd91d5d900253e0acd69b8d3694784a0dab4b 100755 (executable)
@@ -1,8 +1,11 @@
 #!/bin/bash
 
-cp -rpv ../llama.cpp/ggml.c        src/ggml.c
-cp -rpv ../llama.cpp/ggml-cuda.h   src/ggml-cuda.h
-cp -rpv ../llama.cpp/ggml-cuda.cu  src/ggml-cuda.cu
-cp -rpv ../llama.cpp/ggml-opencl.h src/ggml-opencl.h
-cp -rpv ../llama.cpp/ggml-opencl.c src/ggml-opencl.c
-cp -rpv ../llama.cpp/ggml.h        include/ggml/ggml.h
+cp -rpv ../llama.cpp/ggml.c           src/ggml.c
+cp -rpv ../llama.cpp/ggml-cuda.h      src/ggml-cuda.h
+cp -rpv ../llama.cpp/ggml-cuda.cu     src/ggml-cuda.cu
+cp -rpv ../llama.cpp/ggml-opencl.h    src/ggml-opencl.h
+cp -rpv ../llama.cpp/ggml-opencl.cpp  src/ggml-opencl.cpp
+cp -rpv ../llama.cpp/ggml-metal.h     src/ggml-metal.h
+cp -rpv ../llama.cpp/ggml-metal.m     src/ggml-metal.m
+cp -rpv ../llama.cpp/ggml-metal.metal src/ggml-metal.metal
+cp -rpv ../llama.cpp/ggml.h           include/ggml/ggml.h
index 98170a3ae17de41c54e36a20e37024ea8f38f51c..36a251ecce973bddee7df398cd61931fbed22cf9 100644 (file)
@@ -1,8 +1,10 @@
 #include <cstddef>
 #include <cstdint>
+#include <limits>
 #include <stdint.h>
 #include <stdio.h>
 #include <atomic>
+#include <assert.h>
 
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include "ggml-cuda.h"
 #include "ggml.h"
 
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 
 #define CUDA_CHECK(err)                                                                 \
@@ -23,18 +29,44 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
         }                                                                               \
     } while (0)
 
+#if CUDART_VERSION >= 12000
+#define CUBLAS_CHECK(err)                                                               \
+    do {                                                                                \
+        cublasStatus_t err_ = (err);                                                    \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
+            fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n",                         \
+                    err_, __FILE__, __LINE__, cublasGetStatusString(err_));             \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+#else
 #define CUBLAS_CHECK(err)                                                               \
     do {                                                                                \
         cublasStatus_t err_ = (err);                                                    \
         if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
-            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
+            fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);  \
             exit(1);                                                                    \
         }                                                                               \
     } while (0)
+#endif // CUDART_VERSION >= 11
+
+#ifdef GGML_CUDA_DMMV_F16
+typedef half dfloat; // dequantize float
+typedef half2 dfloat2;
+#else
+typedef float dfloat; // dequantize float
+typedef float2 dfloat2;
+#endif //GGML_CUDA_DMMV_F16
 
-typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
+typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
-typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
+typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
+typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
+typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+typedef void (*ggml_cuda_op_t)(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i,
+    float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main);
 
 // QK = number of values after dequantization
 // QR = QK / number of values before dequantization
@@ -83,10 +115,60 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
+//================================= k-quants
+
+#define QK_K 256
+
+typedef struct {
+    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+    uint8_t qs[QK_K/4];      // quants
+    half d;                  // super-block scale for quantized scales
+    half dmin;               // super-block scale for quantized mins
+} block_q2_K;
+static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
+
+typedef struct {
+    uint8_t hmask[QK_K/8];
+    uint8_t qs[QK_K/4]; // nibbles / quants
+    uint8_t scales[3*QK_K/64];
+    half d;
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
+
+typedef struct {
+    half d;                    // super-block scale for quantized scales
+    half dmin;                 // super-block scale for quantized mins
+    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+    uint8_t qs[QK_K/2];        // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+
+typedef struct {
+    half    d;                   // super-block scale for quantized scales
+    half    dmin;                // super-block scale for quantized mins
+    uint8_t scales[3*QK_K/64];   // scales, quantized with 6 bits
+    uint8_t qh[QK_K/8];          // quants, high bit
+    uint8_t qs[QK_K/2];          // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+
+typedef struct {
+    uint8_t ql[QK_K/2];   // quants, lower 4 bits
+    uint8_t qh[QK_K/4];   // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales
+    half    d;         // delta
+} block_q6_K;
+static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
+
 #define WARP_SIZE 32
 
+#define CUDA_ADD_BLOCK_SIZE 256
 #define CUDA_MUL_BLOCK_SIZE 256
-
+#define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_CPY_BLOCK_SIZE 32
+#define CUDA_SCALE_BLOCK_SIZE 256
+#define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 
 // dmmv = dequantize_mul_mat_vec
@@ -97,6 +179,21 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
 #define GGML_CUDA_DMMV_Y 1
 #endif
 
+#ifndef K_QUANTS_PER_ITERATION
+#define K_QUANTS_PER_ITERATION 2
+#else
+static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
+#endif
+
+static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] + y[i];
+}
+
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -106,144 +203,367 @@ static __global__ void mul_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] * y[i%ky];
 }
 
-static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __global__ void silu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] / (1.0f + expf(-x[i]));
+}
+
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    const float eps = 1e-6;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int i = 0; i < ncols; i += WARP_SIZE) {
+        const int col = i + tid;
+        const float xi = x[row*ncols + col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    const float mean = tmp / ncols;
+    const float scale = 1.0f / sqrtf(mean + eps);
+
+    for (int i = 0; i < ncols; i += WARP_SIZE) {
+        const int col = i + tid;
+        dst[row*ncols + col] = scale * x[row*ncols + col];
+    }
+}
+
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
-    const float d = x[ib].d;
+    const dfloat d = x[ib].d;
 
-    const uint8_t vui = x[ib].qs[iqs];
+    const int vui = x[ib].qs[iqs];
 
-    const int8_t vi0 = vui & 0xF;
-    const int8_t vi1 = vui >> 4;
+    v.x = vui & 0xF;
+    v.y = vui >> 4;
 
-    v0 = (vi0 - 8)*d;
-    v1 = (vi1 - 8)*d;
+#ifdef GGML_CUDA_DMMV_F16
+    v = __hsub2(v, {8.0f, 8.0f});
+    v = __hmul2(v, {d, d});
+#else
+    v.x = (v.x - 8.0f) * d;
+    v.y = (v.y - 8.0f) * d;
+#endif // GGML_CUDA_DMMV_F16
 }
 
-static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
-    const float d = x[ib].d;
-    const float m = x[ib].m;
+    const dfloat d = x[ib].d;
+    const dfloat m = x[ib].m;
 
-    const uint8_t vui = x[ib].qs[iqs];
+    const int vui = x[ib].qs[iqs];
 
-    const int8_t vi0 = vui & 0xF;
-    const int8_t vi1 = vui >> 4;
+    v.x = vui & 0xF;
+    v.y = vui >> 4;
 
-    v0 = vi0*d + m;
-    v1 = vi1*d + m;
+#ifdef GGML_CUDA_DMMV_F16
+    v = __hmul2(v, {d, d});
+    v = __hadd2(v, {m, m});
+#else
+    v.x = (v.x * d) + m;
+    v.y = (v.y * d) + m;
+#endif // GGML_CUDA_DMMV_F16
 }
 
-static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q5_0 * x = (const block_q5_0 *) vx;
 
-    const float d = x[ib].d;
+    const dfloat d = x[ib].d;
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
 
-    const uint8_t xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
-    const uint8_t xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
 
-    const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
-    const int32_t x1 = ((x[ib].qs[iqs] >>  4) | xh_1) - 16;
+    v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y = ((x[ib].qs[iqs] >>  4) | xh_1);
 
-    v0 = x0*d;
-    v1 = x1*d;
+#ifdef GGML_CUDA_DMMV_F16
+    v = __hsub2(v, {16.0f, 16.0f});
+    v = __hmul2(v, {d, d});
+#else
+    v.x = (v.x - 16.0f) * d;
+    v.y = (v.y - 16.0f) * d;
+#endif // GGML_CUDA_DMMV_F16
 }
 
-static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
-    const float d = x[ib].d;
-    const float m = x[ib].m;
+    const dfloat d = x[ib].d;
+    const dfloat m = x[ib].m;
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
 
-    const uint8_t xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
-    const uint8_t xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
 
-    const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
-    const int32_t x1 = ((x[ib].qs[iqs] >>  4) | xh_1);
+    v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y = ((x[ib].qs[iqs] >>  4) | xh_1);
 
-    v0 = x0*d + m;
-    v1 = x1*d + m;
+#ifdef GGML_CUDA_DMMV_F16
+    v = __hmul2(v, {d, d});
+    v = __hadd2(v, {m, m});
+#else
+    v.x = (v.x * d) + m;
+    v.y = (v.y * d) + m;
+#endif // GGML_CUDA_DMMV_F16
 }
 
-static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q8_0 * x = (const block_q8_0 *) vx;
 
-    const float d = x[ib].d;
+    const dfloat d = x[ib].d;
+
+    v.x = x[ib].qs[iqs + 0];
+    v.y = x[ib].qs[iqs + 1];
 
-    const int8_t vi0 = x[ib].qs[iqs + 0];
-    const int8_t vi1 = x[ib].qs[iqs + 1];
+#ifdef GGML_CUDA_DMMV_F16
+    v = __hmul2(v, {d, d});
+#else
+    v.x *= d;
+    v.y *= d;
+#endif // GGML_CUDA_DMMV_F16
+}
+
+//================================== k-quants
+
+static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
+
+    const int i   = blockIdx.x;
+    const int tid = threadIdx.x;
+    const int n   = tid/32;
+    const int l   = tid - 32*n;
+    const int is  = 8*n + l/16;
+
+    const block_q2_K * x = (const block_q2_K *) vx;
+
+    const uint8_t q = x[i].qs[32*n + l];
+    float * y = yy + i*QK_K + 128*n;
+
+    float dall = x[i].d;
+    float dmin = x[i].dmin;
+    y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
+    y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
+    y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
 
-    v0 = vi0*d;
-    v1 = vi1*d;
 }
 
-static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
-    const half * x = (const half *) vx;
+static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
+
+    int r = threadIdx.x/4;
+    int i = blockIdx.x;
+    int tid = r/2;
+    int is0 = r%2;
+    int l0 = 16*is0 + 4*(threadIdx.x%4);
+    int n = tid / 4;
+    int j = tid - 4*n;
+
+    const block_q3_K * x = (const block_q3_K *) vx;
+
+    uint8_t m = 1 << (4*n + j);
+    int is = 8*n + 2*j + is0;
+    int shift = 2*j;
+
+    int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
+                is <  8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
+                is < 12 ? (x[i].scales[is-8] >>  4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
+                          (x[i].scales[is-8] >>  4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
+    float d_all = x[i].d;
+    float dl = d_all * (us - 32);
+
+    float * y = yy + i*QK_K + 128*n + 32*j;
+    const uint8_t * q = x[i].qs + 32*n;
+    const uint8_t * hm = x[i].hmask;
+
+    for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
 
-    v0 = __half2float(x[ib + 0]);
-    v1 = __half2float(x[ib + 1]);
 }
 
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_block(const void * vx, float * y, const int k) {
-    const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
+static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
+    if (j < 4) {
+        d = q[j] & 63; m = q[j + 4] & 63;
+    } else {
+        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    }
+}
 
-    if (i >= k) {
-        return;
+static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
+    const block_q4_K * x = (const block_q4_K *) vx;
+
+    const int i = blockIdx.x;
+
+    //// assume 64 threads - this is very slightly better than the one below
+    //const int tid = threadIdx.x;
+    //const int il  = tid/16;
+    //const int ir  = tid%16;
+    //const int is  = 2*il;
+    //const int n   = 2;
+
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    const int il  = tid/8;
+    const int ir  = tid%8;
+    const int is  = 2*il;
+    const int n   = 4;
+
+    float * y = yy + i*QK_K + 64*il + n*ir;
+
+    const float dall = x[i].d;
+    const float dmin = x[i].dmin;
+
+    const uint8_t * q = x[i].qs + 32*il + n*ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, sc, m);
+    const float d1 = dall * sc; const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, sc, m);
+    const float d2 = dall * sc; const float m2 = dmin * m;
+    for (int l = 0; l < n; ++l) {
+        y[l + 0] = d1 * (q[l] & 0xF) - m1;
+        y[l +32] = d2 * (q[l] >>  4) - m2;
     }
+}
 
-    const int ib = i/qk; // block index
-    const int iqs = (i%qk)/qr; // quant index
-    const int iybs = i - i%qk; // y block start index
-    const int y_offset = qr == 1 ? 1 : qk/2;
+static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
+    const block_q5_K * x = (const block_q5_K *) vx;
 
-    // dequantize
-    float & v0 = y[iybs + iqs + 0];
-    float & v1 = y[iybs + iqs + y_offset];
-    dequantize_kernel(vx, ib, iqs, v0, v1);
+    const int i = blockIdx.x;
+
+    // assume 64 threads - this is very slightly better than the one below
+    const int tid = threadIdx.x;
+    const int il  = tid/16;   // il is in 0...3
+    const int ir  = tid%16;   // ir is in 0...15
+    const int is  = 2*il;     // is is in 0...6
+
+    float * y = yy + i*QK_K + 64*il + 2*ir;
+
+    const float dall = x[i].d;
+    const float dmin = x[i].dmin;
+
+    const uint8_t * ql = x[i].qs + 32*il + 2*ir;
+    const uint8_t * qh = x[i].qh + 2*ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, sc, m);
+    const float d1 = dall * sc; const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, sc, m);
+    const float d2 = dall * sc; const float m2 = dmin * m;
+
+    uint8_t   hm  = 1 << (2*il);
+    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
+    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
+    hm <<= 1;
+    y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
+    y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
 }
 
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
-    // qk = quantized weights per x block
-    // qr = number of quantized weights per data value in x block
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
+    const block_q6_K * x = (const block_q6_K *) vx;
+
+    const int i = blockIdx.x;
+
+    // assume 64 threads - this is very slightly better than the one below
     const int tid = threadIdx.x;
+    const int ip  = tid/32;   // ip is 0 or 1
+    const int il  = tid - 32*ip; // 0...32
+    const int is  = 8*ip + il/16;
 
-    const int iter_stride = 2*GGML_CUDA_DMMV_X;
-    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
-    const int y_offset = qr == 1 ? 1 : qk/2;
+    float * y = yy + i*QK_K + 128*ip + il;
 
-    float tmp = 0; // partial sum for thread in warp
+    const float d = x[i].d;
 
-    for (int i = 0; i < ncols; i += iter_stride) {
-        const int col = i + vals_per_iter*tid;
-        const int ib = (row*ncols + col)/qk; // x block index
-        const int iqs = (col%qk)/qr; // x quant index
-        const int iybs = col - col%qk; // y block start index
+    const uint8_t * ql = x[i].ql + 64*ip + il;
+    const uint8_t   qh = x[i].qh[32*ip + il];
+    const int8_t  * sc = x[i].scales + is;
 
-// processing >2 values per i iter is faster for fast GPUs
-#pragma unroll
-        for (int j = 0; j < vals_per_iter; j += 2) {
-            // process 2 vals per j iter
+    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
+    y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+    y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+}
 
-            // dequantize
-            float v0, v1;
-            dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
-            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+
+    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q2_K * x = (const block_q2_K *)vx + ib0;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+
+    const int step = 16/K_QUANTS_PER_ITERATION;
+
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
+    const int q_offset = 32*im + l0;
+    const int s_offset = 8*im;
+    const int y_offset = 128*im + l0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    uint32_t aux[4];
+    const uint8_t * d = (const uint8_t *)aux;
+    const uint8_t * m = (const uint8_t *)(aux + 2);
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + y_offset;
+        const uint8_t * q = x[i].qs + q_offset;
+
+        const float dall = x[i].d;
+        const float dmin = x[i].dmin;
+
+        const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
+        aux[0] = a[0] & 0x0f0f0f0f;
+        aux[1] = a[1] & 0x0f0f0f0f;
+        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+                  +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
+            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
 
-            // matrix multiplication
-            tmp += v0 * y[iybs + iqs + j/qr + 0];
-            tmp += v1 * y[iybs + iqs + j/qr + y_offset];
-            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
         }
+        tmp += dall * sum1 - dmin * sum2;
+
     }
 
     // sum up partial sums and write back result
@@ -258,166 +578,936 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
     }
 }
 
-static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
-    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
-    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
-}
+static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
 
-static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
 
-static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
+    const block_q3_K * x = (const block_q3_K *)vx + ib0;
 
-static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
-static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
-    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
-        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
-}
+    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
+    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0....15 or 0...7
 
-static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
-    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
-        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
-}
+    const uint8_t m = 1 << (4*im);
 
-static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
-    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
-        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
-}
+    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
+    const int q_offset =  32*im + l0;
+    const int y_offset = 128*im + l0;
 
-static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
-    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
-        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
-}
+    uint16_t utmp[4];
+    const int8_t * s = (const int8_t *)utmp;
 
-static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
-    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
-        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
-}
+    const uint16_t s_shift = 4*im;
 
-static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<32, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
+    float tmp = 0; // partial sum for thread in warp
 
-static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
-    dequantize_mul_mat_vec<1, 1, convert_f16>
-        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
-}
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y  = yy + i * QK_K + y_offset;
+        const uint8_t * q = x[i].qs + q_offset;
+        const uint8_t * h = x[i].hmask + l0;
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+        const float d = x[i].d;
+
+        float sum = 0;
+        for (int l = 0; l < n; ++l) {
+            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+                + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+        }
+        tmp += d * sum;
 
-static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:
-            return dequantize_row_q4_0_cuda;
-        case GGML_TYPE_Q4_1:
-            return dequantize_row_q4_1_cuda;
-        case GGML_TYPE_Q5_0:
-            return dequantize_row_q5_0_cuda;
-        case GGML_TYPE_Q5_1:
-            return dequantize_row_q5_1_cuda;
-        case GGML_TYPE_Q8_0:
-            return dequantize_row_q8_0_cuda;
-        case GGML_TYPE_F16:
-            return convert_fp16_to_fp32_cuda;
-        default:
-            return nullptr;
     }
-}
 
-static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:
-            return dequantize_mul_mat_vec_q4_0_cuda;
-        case GGML_TYPE_Q4_1:
-            return dequantize_mul_mat_vec_q4_1_cuda;
-        case GGML_TYPE_Q5_0:
-            return dequantize_mul_mat_vec_q5_0_cuda;
-        case GGML_TYPE_Q5_1:
-            return dequantize_mul_mat_vec_q5_1_cuda;
-        case GGML_TYPE_Q8_0:
-            return dequantize_mul_mat_vec_q8_0_cuda;
-        case GGML_TYPE_F16:
-            return convert_mul_mat_vec_f16_cuda;
-        default:
-            return nullptr;
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
     }
 }
 
-// buffer pool for cuda
-#define MAX_CUDA_BUFFERS 256
+static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-struct scoped_spin_lock {
-    std::atomic_flag& lock;
-    scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
-        while (lock.test_and_set(std::memory_order_acquire)) {
-            ; // spin
-        }
-    }
-    ~scoped_spin_lock() {
-        lock.clear(std::memory_order_release);
-    }
-    scoped_spin_lock(const scoped_spin_lock&) = delete;
-    scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
-};
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
 
-struct cuda_buffer {
-    void * ptr = nullptr;
-    size_t size = 0;
-};
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
-static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
-static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
-    scoped_spin_lock lock(g_cuda_pool_lock);
+    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4
 
-    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        cuda_buffer& b = g_cuda_buffer_pool[i];
-        if (b.size >= size && b.ptr != nullptr) {
-            void * ptr = b.ptr;
-            *actual_size = b.size;
-            b.ptr = nullptr;
-            b.size = 0;
-            return ptr;
-        }
-    }
-    void * ptr;
+    const int il  = tid/step;                            // 0...3
+    const int ir  = tid - step*il;                       // 0...7 or 0...3
+    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const uint8_t * q1 = x[i].qs + q_offset;
+        const uint8_t * q2 = q1 + 64;
+        const float   * y1 = yy + i*QK_K + y_offset;
+        const float   * y2 = y1 + 128;
+
+        const float dall = x[i].d;
+        const float dmin = x[i].dmin;
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
+            s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
+
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    //const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int row = blockIdx.x;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const int tid = threadIdx.x/2;  // 0...15
+    const int ix  = threadIdx.x%2;
+
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 2;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1  = 1 << (2*im);
+    const uint8_t hm2  = hm1 << 4;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+        const uint8_t * ql1 = x[i].qs + q_offset;
+        const uint8_t * ql2 = ql1 + 64;
+        const uint8_t * qh  = x[i].qh + l0;
+        const float   * y1  = yy + i*QK_K + y_offset;
+        const float   * y2  = y1 + 128;
+
+        const float dall = x[i].d;
+        const float dmin = x[i].dmin;
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 sum = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+                   + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+            sum.y += y1[l+32] * ((ql1[l+ 0] >>  4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+                   + y1[l+48] * ((ql1[l+16] >>  4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+            sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+                   + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+            sum.w += y2[l+32] * ((ql2[l+ 0] >>  4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+                   + y2[l+48] * ((ql2[l+16] >>  4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
+        }
+        tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+
+    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q6_K * x = (const block_q6_K *)vx + ib0;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+
+    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
+
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
+    const int is = 0;
+#else
+    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
+    const int is = in / 4;
+#endif
+    const int ql_offset = 64*im + l0;
+    const int qh_offset = 32*im + l0;
+    const int s_offset  =  8*im + is;
+    const int y_offset = 128*im + l0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y  = yy + i * QK_K + y_offset;
+        const uint8_t * ql = x[i].ql + ql_offset;
+        const uint8_t * qh = x[i].qh + qh_offset;
+        const int8_t  * s  = x[i].scales + s_offset;
+
+        const float d = x[i].d;
+
+#if K_QUANTS_PER_ITERATION == 1
+        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
+        tmp += sum;
+#else
+        float sum = 0;
+        for (int l = 0; l < 4; ++l) {
+            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+        }
+        tmp += sum;
+#endif
+
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const half * x = (const half *) vx;
+
+    // automatic half -> float type cast if dfloat == float
+    v.x = x[ib + iqs + 0];
+    v.y = x[ib + iqs + 1];
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_block(const void * vx, float * y, const int k) {
+    const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const int ib = i/qk; // block index
+    const int iqs = (i%qk)/qr; // quant index
+    const int iybs = i - i%qk; // y block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel(vx, ib, iqs, v);
+
+    y[iybs + iqs + 0]        = v.x;
+    y[iybs + iqs + y_offset] = v.y;
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
+    const int row = blockIdx.y*blockDim.y + threadIdx.y;
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int tid = threadIdx.x;
+
+    const int iter_stride = 2*GGML_CUDA_DMMV_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+// partial sum for each thread
+#ifdef GGML_CUDA_DMMV_F16
+    half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
+#else
+    float tmp = 0.0f;
+#endif // GGML_CUDA_DMMV_F16
+
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
+        const int iybs = col - col%qk; // y block start index
+
+// processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+            dfloat2 v;
+            dequantize_kernel(vx, ib, iqs + j/qr, v);
+
+            // matrix multiplication
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+#ifdef GGML_CUDA_DMMV_F16
+            tmp += __hmul2(v, {
+                y[iybs + iqs + j/qr + 0],
+                y[iybs + iqs + j/qr + y_offset]
+            });
+#else
+            tmp += v.x * y[iybs + iqs + j/qr + 0];
+            tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+#endif // GGML_CUDA_DMMV_F16
+        }
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+#ifdef GGML_CUDA_DMMV_F16
+        dst[row] = tmp.x + tmp.y;
+#else
+        dst[row] = tmp;
+#endif // GGML_CUDA_DMMV_F16
+    }
+}
+
+static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+    const half * x = (half *) vx;
+
+    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
+    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+
+    const int nrows_y = ncols_x;
+    const int nrows_dst = nrows_x;
+    const int row_dst = row_x;
+
+    float tmp = 0.0f;
+
+    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
+        const int col_x = col_x0 + threadIdx.x;
+
+        if (col_x >= ncols_x) {
+            break;
+        }
+
+        // x is transposed and permuted
+        const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
+        const float xi = __half2float(x[ix]);
+
+        const int row_y = col_x;
+
+
+        // y is not transposed but permuted
+        const int iy = channel*nrows_y + row_y;
+
+        tmp += xi * y[iy];
+    }
+
+    // dst is not transposed and not permuted
+    const int idst = channel*nrows_dst + row_dst;
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[idst] = tmp;
+    }
+}
+
+static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+    const int row_stride_x, const int nchannels_x, const int channel_stride_x) {
+
+    const half * x = (half *) vx;
+
+    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
+    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+
+    const int nrows_y = ncols_x;
+    const int nrows_dst = nrows_x;
+    const int row_dst = row_x;
+
+    const int idst = channel*nrows_dst + row_dst;
+
+    float tmp = 0.0f;
+
+    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
+        const int col_x = col_x0 + threadIdx.x;
+
+        if (col_x >= ncols_x) {
+            break;
+        }
+
+        const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
+        const float xi = __half2float(x[ix]);
+
+        const int row_y = col_x;
+
+        const int iy = channel*nrows_y + row_y;
+
+        tmp += xi * y[iy];
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[idst] = tmp;
+    }
+}
+
+static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
+    const float * xi = (float *) cxi;
+    float * dsti = (float *) cdsti;
+
+    *dsti = *xi;
+}
+
+static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
+    const float * xi = (float *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = __float2half(*xi);
+}
+
+template <cpy_kernel_t cpy_1>
+static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+                                   const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+                                   const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+    // then combine those indices with the corresponding byte offsets to get the total offsets
+    const int i02 = i / (ne00*ne01);
+    const int i01 = (i - i02*ne01*ne00) / ne00;
+    const int i00 = i - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
+
+    const int i12 = i / (ne10*ne11);
+    const int i11 = (i - i12*ne10*ne11) / ne10;
+    const int i10 = i - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+
+    cpy_1(cx + x_offset, cdst + dst_offset);
+}
+
+// rope == RoPE == rotary positional embedding
+static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
+    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i = row*ncols + col;
+
+    const float theta = p*powf(theta_scale, col/2);
+    const float sin_theta = sinf(theta);
+    const float cos_theta = cosf(theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + 1];
+
+    dst[i + 0] = x0*cos_theta - x1*sin_theta;
+    dst[i + 1] = x0*sin_theta + x1*cos_theta;
+}
+
+static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
+    const int col = blockDim.x*blockIdx.x + threadIdx.x;
+    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int i = row*ncols + col;
+    // dst[i] = col > n_past + row ? -INFINITY : x[i];
+    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+}
+
+// the CUDA soft max implementation differs from the CPU implementation
+// instead of doubles floats are used
+// values are also not normalized to the maximum value by subtracting it in the exponential function
+// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
+static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int block_size = blockDim.x;
+    const int tid = threadIdx.x;
+
+    float tmp = 0.0;
+
+    for (int block_start = 0; block_start < ncols; block_start += block_size) {
+        const int col = block_start + tid;
+
+        if (col >= ncols) {
+            break;
+        }
+
+        const int i = row*ncols + col;
+        const float val = expf(x[i]);
+        tmp += val;
+        dst[i] = val;
+    }
+
+    // sum up partial sums
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    for (int block_start = 0; block_start < ncols; block_start += block_size) {
+        const int col = block_start + tid;
+
+        if (col >= ncols) {
+            break;
+        }
+
+        const int i = row*ncols + col;
+        dst[i] /= tmp;
+    }
+}
+
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = scale * x[i];
+}
+
+static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
+static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
+    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
+}
+
+static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+    silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
+static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const dim3 block_dims(32, 1, 1);
+    dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_f16>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_row_q5_0_cuda;
+        case GGML_TYPE_Q5_1:
+            return dequantize_row_q5_1_cuda;
+        case GGML_TYPE_Q8_0:
+            return dequantize_row_q8_0_cuda;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_F16:
+            return convert_fp16_to_fp32_cuda;
+        default:
+            return nullptr;
+    }
+}
+
+static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
+    const dim3 block_nums(1, nrows_x, nchannels_x);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x);
+}
+
+static void ggml_mul_mat_vec_nc_f16_f32_cuda(
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
+    const int nchannels_x, const int channel_stride_x, cudaStream_t stream) {
+
+    const dim3 block_nums(1, nrows_x, nchannels_x);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
+        (vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x);
+}
+
+static void ggml_cpy_f32_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void ggml_cpy_f32_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
+}
+
+static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) {
+    GGML_ASSERT(nrows % 2 == 0);
+    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(num_blocks_x, nrows, 1);
+    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
+}
+
+static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
+    const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+    const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
+    const dim3 block_nums(block_num_x, nrows_x, 1);
+    diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
+}
+
+static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(1, nrows_x, 1);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+}
+
+// buffer pool for cuda
+#define MAX_CUDA_BUFFERS 256
+
+struct scoped_spin_lock {
+    std::atomic_flag& lock;
+    scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
+        while (lock.test_and_set(std::memory_order_acquire)) {
+            ; // spin
+        }
+    }
+    ~scoped_spin_lock() {
+        lock.clear(std::memory_order_release);
+    }
+    scoped_spin_lock(const scoped_spin_lock&) = delete;
+    scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
+};
+
+struct cuda_buffer {
+    void * ptr = nullptr;
+    size_t size = 0;
+};
+
+static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
+static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+
+static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        cuda_buffer& b = g_cuda_buffer_pool[id][i];
+        if (b.size >= size && b.ptr != nullptr) {
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            b.ptr = nullptr;
+            b.size = 0;
+            return ptr;
+        }
+    }
+    void * ptr;
     CUDA_CHECK(cudaMalloc((void **) &ptr, size));
     *actual_size = size;
     return ptr;
@@ -425,9 +1515,11 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
 
 static void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
 
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        cuda_buffer& b = g_cuda_buffer_pool[i];
+        cuda_buffer& b = g_cuda_buffer_pool[id][i];
         if (b.ptr == nullptr) {
             b.ptr = ptr;
             b.size = size;
@@ -438,31 +1530,74 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
     CUDA_CHECK(cudaFree(ptr));
 }
 
-#define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
-#define GGML_CUDA_MAX_EVENTS 64
-static cublasHandle_t g_cublasH = nullptr;
-static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = { nullptr };
-static cudaStream_t g_cudaStreams2[GGML_CUDA_MAX_STREAMS] = { nullptr };
-static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_EVENTS] = { nullptr };
+
+static void * g_scratch_buffer = nullptr;
+static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
+static size_t g_scratch_offset = 0;
+
+static int g_device_count = -1;
+static int g_main_device = 0;
+static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
+
+static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+
+static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
 
 void ggml_init_cublas() {
-    if (g_cublasH == nullptr) {
-        // create streams
-        for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
-            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking));
-            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[i], cudaStreamNonBlocking));
+    static bool initialized = false;
+
+    if (!initialized) {
+        CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
+        GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
+        int64_t total_vram = 0;
+        fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count);
+        for (int id = 0; id < g_device_count; ++id) {
+            cudaDeviceProp prop;
+            CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
+            fprintf(stderr, "  Device %d: %s\n", id, prop.name);
+            g_tensor_split[id] = total_vram;
+            total_vram += prop.totalGlobalMem;
         }
-        // create events
-        for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
-            CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming));
+        for (int id = 0; id < g_device_count; ++id) {
+            g_tensor_split[id] /= total_vram;
         }
 
-        // create cublas handle
-        CUBLAS_CHECK(cublasCreate(&g_cublasH));
-        CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH));
+        for (int id = 0; id < g_device_count; ++id) {
+            CUDA_CHECK(cudaSetDevice(id));
+
+            // create main stream
+            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));
+
+            // create cublas handle
+            CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
+            CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
+        }
 
         // configure logging to stdout
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
+
+        initialized = true;
+    }
+}
+
+void ggml_cuda_set_tensor_split(const float * tensor_split) {
+    bool all_zero = true;
+    for (int i = 0; i < g_device_count; ++i) {
+        if (tensor_split[i] != 0.0f) {
+            all_zero = false;
+            break;
+        }
+    }
+    if (all_zero) {
+        return;
+    }
+    float split_sum = 0.0f;
+    for (int i = 0; i < g_device_count; ++i) {
+        g_tensor_split[i] = split_sum;
+        split_sum += tensor_split[i];
+    }
+    for (int i = 0; i < g_device_count; ++i) {
+        g_tensor_split[i] /= split_sum;
     }
 }
 
@@ -471,370 +1606,698 @@ void * ggml_cuda_host_malloc(size_t size) {
         return nullptr;
     }
 
-    void * ptr = nullptr;
-    cudaError_t err = cudaMallocHost((void **) &ptr, size);
-    if (err != cudaSuccess) {
-        fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
-            size/1024.0/1024.0, cudaGetErrorString(err));
-        return nullptr;
+    void * ptr = nullptr;
+    cudaError_t err = cudaMallocHost((void **) &ptr, size);
+    if (err != cudaSuccess) {
+        // The allocation error can be bypassed. A null ptr will assigned out of this function.
+        // This can fixed the OOM error in WSL.
+        cudaGetLastError();
+        fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
+            size/1024.0/1024.0, cudaGetErrorString(err));
+        return nullptr;
+    }
+
+    return ptr;
+}
+
+void ggml_cuda_host_free(void * ptr) {
+    CUDA_CHECK(cudaFreeHost(ptr));
+}
+
+static cudaError_t ggml_cuda_cpy_tensor_2d(
+    void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
+
+    cudaMemcpyKind kind;
+    char * src_ptr;
+    if (src->backend == GGML_BACKEND_CPU) {
+        kind = cudaMemcpyHostToDevice;
+        src_ptr = (char *) src->data;
+    } else if (src->backend == GGML_BACKEND_GPU) {
+        kind = cudaMemcpyDeviceToDevice;
+        struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
+        int id;
+        CUDA_CHECK(cudaGetDevice(&id));
+        src_ptr = (char *) extra->data_device[id];
+    } else {
+        GGML_ASSERT(false);
+    }
+    char * dst_ptr = (char *) dst;
+
+    const int64_t ne0 = src->ne[0];
+    const int64_t nb0 = src->nb[0];
+    const int64_t nb1 = src->nb[1];
+    const int64_t nb2 = src->nb[2];
+    const int64_t nb3 = src->nb[3];
+    const enum ggml_type type = src->type;
+    const int64_t ts = ggml_type_size(type);
+    const int64_t bs = ggml_blck_size(type);
+    int64_t i1_diff = i1_high - i1_low;
+
+    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
+    if (nb0 == ts && nb1 == ts*ne0/bs) {
+        return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
+    } else if (nb0 == ts) {
+        return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
+    } else {
+        for (int64_t i1 = 0; i1 < i1_diff; i1++) {
+            const void * rx = (const void *) ((const char *) x + i1*nb1);
+            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+            // pretend the row is a matrix with cols=1
+            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
+            if (r != cudaSuccess) return r;
+        }
+        return cudaSuccess;
+    }
+}
+
+inline void ggml_cuda_op_add(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne0 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) i02;
+    (void) i1;
+}
+
+inline void ggml_cuda_op_mul(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
+        const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
+
+        float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
+        float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
+        float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
+
+        // compute
+        mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) i02;
+}
+
+inline void ggml_cuda_op_silu(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
+inline void ggml_cuda_op_rms_norm(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
+inline void ggml_cuda_op_dequantize_mul_mat_vec(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddq_i != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows = i01_high - i01_low;
+
+// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+#ifdef GGML_CUDA_DMMV_F16
+    size_t ash;
+    dfloat * src1_dfloat = nullptr; // dfloat == half
+
+    bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+        src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+        src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+    if (src1_convert_f16) {
+        src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+        ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
+                                ne00, 1, sizeof(float), 0, 0,
+                                ne00, 1, sizeof(half),  0, 0, cudaStream_main);
+    }
+#else
+    dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
+#endif // GGML_CUDA_DMMV_F16
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q4_1:
+            dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q5_0:
+            dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q5_1:
+            dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q8_0:
+            dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q2_K:
+            dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q3_K:
+            dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q4_K:
+            dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q5_K:
+            dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_Q6_K:
+            dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        case GGML_TYPE_F16:
+            convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+    CUDA_CHECK(cudaGetLastError());
+
+#ifdef GGML_CUDA_DMMV_F16
+    if (src1_convert_f16) {
+        ggml_cuda_pool_free(src1_dfloat, ash);
     }
+#endif // GGML_CUDA_DMMV_F16
 
-    return ptr;
+    (void) src1;
+    (void) dst;
+    (void) src0_ddf_i;
+    (void) i02;
+    (void) i1;
 }
 
-void ggml_cuda_host_free(void * ptr) {
-    CUDA_CHECK(cudaFreeHost(ptr));
-}
+inline void ggml_cuda_op_mul_mat_cublas(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
 
-static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
-    const uint64_t ne0 = src->ne[0];
-    const uint64_t ne1 = src->ne[1];
-    const uint64_t nb0 = src->nb[0];
-    const uint64_t nb1 = src->nb[1];
-    const uint64_t nb2 = src->nb[2];
-    const uint64_t nb3 = src->nb[3];
-    const enum ggml_type type = src->type;
-    const size_t ts = ggml_type_size(type);
-    const size_t bs = ggml_blck_size(type);
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
 
-    const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
-    if (nb0 == ts && nb1 == ts*ne0/bs) {
-        return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
-    } else if (nb0 == ts) {
-        return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
-    } else {
-        for (uint64_t i1 = 0; i1 < ne1; i1++) {
-            const void * rx = (const void *) ((const char *) x + i1*nb1);
-            void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
-            // pretend the row is a matrix with cols=1
-            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
-            if (r != cudaSuccess) return r;
-        }
-        return cudaSuccess;
-    }
-}
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
 
-static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA);
     const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[2];
-    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
+
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
-    size_t x_size, d_size;
-
-    float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0
-    float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted.
-    float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            const int i0 = i03*ne02 + i02;
-            float * c_X2 = d_X + i0*ne01*ne00;
-            float * c_D2 = d_D + i0*ne01*ne00;
-
-            cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
-            cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
-            cudaEvent_t  cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];
-
-            // copy src0 to device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2));
-            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
-
-            // wait for data
-            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
-
-            for (int64_t i01 = 0; i01 < ne01; i01++) {
-                const int64_t i13 = i03%ne13;
-                const int64_t i12 = i02%ne12;
-                const int64_t i11 = i01%ne11;
-                const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
-
-                float * c_X1 = c_X2 + i01*ne00;
-                float * c_Y = d_Y + i1*ne10;
-                float * c_D1 = c_D2 + i01*ne00;
-
-                // compute
-                mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
-                CUDA_CHECK(cudaGetLastError());
-            }
 
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
-        }
-    }
-    CUDA_CHECK(cudaDeviceSynchronize());
-    ggml_cuda_pool_free(d_X, x_size);
-    ggml_cuda_pool_free(d_D, d_size);
+    const int64_t ne0 = dst->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // ldc == nrows of the matrix that cuBLAS writes into
+    int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
+
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], cudaStream_main));
+    CUBLAS_CHECK(
+        cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                i01_diff, ne11, ne10,
+                &alpha, src0_ddf_i, ne00,
+                        src1_ddf_i, ne10,
+                &beta,  dst_ddf_i,  ldc));
+
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) i02;
+    (void) i1;
+}
+
+inline void ggml_cuda_op_rope(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_dims = ((int32_t *) src1->data)[1];
+    const int mode   = ((int32_t *) src1->data)[2];
+    GGML_ASSERT(mode == 0);
+
+    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
+
+    // compute
+    rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i1;
 }
 
-static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+inline void ggml_cuda_op_diag_mask_inf(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
+    const int64_t i01_diff = i01_high - i01_low;
 
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
+    const int n_past = ((int32_t *) src1->data)[0];
 
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
+    // compute
+    diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
+    CUDA_CHECK(cudaGetLastError());
 
-    const float alpha = 1.0f;
-    const float beta = 0.0f;
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-    const int n_mm = ne03 * ne02;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
 
-    size_t x_size, y_size, d_size;
-    float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
-    float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
-    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
+inline void ggml_cuda_op_soft_max(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
 
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            int i = i03*ne02 + i02;
-            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
 
-            float * c_X = d_X + i * x_ne;
-            float * c_Y = d_Y + i * y_ne;
-            float * c_D = d_D + i * d_ne;
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
 
-            // copy data to device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+inline void ggml_cuda_op_scale(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
 
-            // compute
-            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-            CUBLAS_CHECK(
-                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                        ne01, ne11, ne10,
-                        &alpha, c_X, ne00,
-                                c_Y, ne10,
-                        &beta,  c_D, ne01));
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
 
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
-        }
-    }
+    const float scale = ((float *) src1->data)[0];
 
-    CUDA_CHECK(cudaDeviceSynchronize());
-    ggml_cuda_pool_free(d_X, x_size);
-    ggml_cuda_pool_free(d_Y, y_size);
-    ggml_cuda_pool_free(d_D, d_size);
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    scale_f32_cuda(src0_ddf_i, dst_ddf_i, scale, ne00*i01_diff, cudaStream_main);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
 }
 
-static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
+static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                         ggml_cuda_op_t op, bool src0_needs_f32, bool flatten_rows) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
+    const int64_t nrows0 = ggml_nrows(src0);
 
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
+    const bool use_src1 = src1 != nullptr;
+    const int64_t ne10 = use_src1 ? src1->ne[0] : 1;
+    const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
+    const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = use_src1 ? src1->ne[3] : 1;
 
-    const int nb10 = src1->nb[0];
-    const int nb11 = src1->nb[1];
-    const int nb12 = src1->nb[2];
-    const int nb13 = src1->nb[3];
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
 
     const int nb2  = dst->nb[2];
     const int nb3  = dst->nb[3];
 
-    const float alpha = 1.0f;
-    const float beta = 0.0f;
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-    const int n_mm = ne03 * ne02;
-
-    size_t x_size, y_size, d_size;
-    half  * d_X =  (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
-    half  * d_Y =  (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
-    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
-
-    bool src1_cont_rows = nb10 == sizeof(float);
-    bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            int i = i03*ne02 + i02;
-            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
-
-            half  * c_X = d_X + i * x_ne;
-            half  * c_Y = d_Y + i * y_ne;
-            float * c_D = d_D + i * d_ne;
-
-            // copy src0 to device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
-
-            // convert src1 to fp16
-            // TODO: use multiple threads
-            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
-            char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
-            if (src1_cont_rows) {
-                if (src1_cont_cols) {
-                    ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
-                }
-                else {
-                    for (int64_t i01 = 0; i01 < ne11; i01++) {
-                        ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
-                    }
-                }
-            }
-            else {
-                for (int64_t i01 = 0; i01 < ne11; i01++) {
-                    for (int64_t i00 = 0; i00 < ne10; i00++) {
-                        // very slow due to no inlining
-                        tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
-                    }
-                }
-            }
+    GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
 
-            // copy src1 to device
-            CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+    // strides for iteration over dims 3 and 2
+    const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03;
+    const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1;
+    const int64_t src0_stride = ne00 * ne01 * stride_mod;
+    const int64_t src1_stride = ne10 * ne11 * stride_mod;
+    const int64_t dst_stride = ne0 * ne1 * stride_mod;
 
-            // compute
-            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-            CUBLAS_CHECK(
-                cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                        ne01, ne11, ne10,
-                        &alpha, c_X, CUDA_R_16F, ne00,
-                                c_Y, CUDA_R_16F, ne10,
-                        &beta,  c_D, CUDA_R_32F, ne01,
-                        CUBLAS_COMPUTE_32F_FAST_16F,
-                        CUBLAS_GEMM_DEFAULT));
+    const size_t src0_ts = ggml_type_size(src0->type);
+    const size_t src0_bs = ggml_blck_size(src0->type);
 
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
-        }
+    struct ggml_tensor_extra_gpu * src0_extra =            (ggml_tensor_extra_gpu *) src0->extra;
+    struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+    struct ggml_tensor_extra_gpu * dst_extra  =            (ggml_tensor_extra_gpu *) dst->extra;
+
+    const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
+    const bool src0_is_contiguous = ggml_is_contiguous(src0);
+    const bool src0_is_f32 = src0->type == GGML_TYPE_F32;
+
+    const bool src1_is_contiguous = use_src1 && ggml_is_contiguous(src1);
+    const bool src1_stays_on_host = use_src1 && (
+        dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
+
+    const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
+
+    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+
+    // dd = data device
+    char  * src0_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // quantized
+    float * src0_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
+    float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
+    float *  dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
+
+    // asq = actual size quantized, asf = actual size float
+    size_t src0_asq[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t src0_asf[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t  dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
+
+    // if multiple GPUs are used they need to wait for the main GPU to finish
+    if (split && g_device_count > 1) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(cudaDeviceSynchronize());
     }
 
-    CUDA_CHECK(cudaDeviceSynchronize());
-    ggml_cuda_pool_free(d_X, x_size);
-    ggml_cuda_pool_free(d_Y, y_size);
-    ggml_cuda_pool_free(d_D, d_size);
-}
+    for (int id = 0; id < g_device_count; ++id) {
+        if (!split && id != g_main_device) {
+            continue;
+        }
 
-static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
+        const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU && id == g_main_device;
+        const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
 
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
+        int64_t row_low, row_high;
+        if (split) {
+            row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
+            row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
+        } else {
+            row_low = 0;
+            row_high = nrows0;
+        }
+        if (row_low == row_high) {
+            continue;
+        }
 
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
-    const ggml_type type = src0->type;
-    const bool mul_mat_vec = ne11 == 1;
+        int64_t row_diff = row_high - row_low;
 
-    const float alpha = 1.0f;
-    const float beta = 0.0f;
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-    const int n_mm = ne03 * ne02;
-    const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
-
-    size_t x_size, y_size, d_size, q_size;
-    float * d_X = nullptr;
-    if (!mul_mat_vec) {
-        d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
-    }
-    float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
-    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
-    char  * d_Q = (char  *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
-
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
-    dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
-    GGML_ASSERT(to_fp32_cuda != nullptr);
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            int i = i03*ne02 + i02;
-            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
-            cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
-            cudaEvent_t  cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
-
-            float * c_Y = d_Y + i * y_ne;
-            float * c_D = d_D + i * d_ne;
-            char  * c_Q = d_Q + i * q_sz;
-
-            // copy src0 to device if necessary
-            if (src0->backend == GGML_BACKEND_CPU) {
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
-            } else if (src0->backend == GGML_BACKEND_CUDA) {
-                c_Q = ((char *) src0->data) + i * q_sz;
+        cudaSetDevice(id);
+
+        if (src0_on_device && src0_is_contiguous) {
+            if (src0_is_f32) {
+                src0_ddf[id] = (float *) src0_extra->data_device[id];
+            } else {
+                src0_ddq[id] = (char *) src0_extra->data_device[id];
+            }
+        } else {
+            if (src0_is_f32) {
+                src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
             } else {
-                GGML_ASSERT(false);
+                src0_ddq[id] = (char *) ggml_cuda_pool_malloc(row_diff*ne00 * src0_ts/src0_bs, &src0_asq[id]);
             }
-            if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
-                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+        }
 
-                // copy src1 to device
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+        if (src0_needs_f32 && !src0_is_f32) {
+            src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
+        }
 
-                // wait for data
-                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+        if (use_src1 && !src1_stays_on_host) {
+            if (src1_on_device && src1_is_contiguous) {
+                src1_ddf[id] = (float *) src1_extra->data_device[id];
+            } else {
+                src1_ddf[id] = (float *) ggml_cuda_pool_malloc(num_iters*src1_stride * sizeof(float), &src1_asf[id]);
+            }
+        }
+        if (dst_on_device) {
+            dst_ddf[id] = (float *) dst_extra->data_device[id];
+        } else {
+            size_t size_dst_ddf = split ? row_diff*ne1 * sizeof(float) : num_iters*dst_stride * sizeof(float);
+            dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
+        }
 
-                // compute
-                dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
-                CUDA_CHECK(cudaGetLastError());
+        const int64_t i03_max = flatten_rows ? 1 : ne03;
+        const int64_t i02_max = flatten_rows ? 1 : ne02;
+        const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
 
-            } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
-                float * c_X = d_X + i * x_ne;
+        for (int64_t i03 = 0; i03 < i03_max; i03++) {
+            const int64_t i13 = i03 % ne13;
+            for (int64_t i02 = 0; i02 < i02_max; i02++) {
+                const int64_t i12 = i02 % ne12;
 
-                // convert src0 to fp32 on device
-                to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
-                CUDA_CHECK(cudaGetLastError());
-                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+                const int64_t i0 = i03*ne02 + i02;
 
-                // copy src1 to device
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+                // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
+                const int64_t i0_offset_low = row_low/rows_per_iter;
+                const int64_t i0_offset_high = row_high/rows_per_iter;
 
-                // wait for conversion
-                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+                int64_t i01_low = 0;
+                int64_t i01_high = rows_per_iter;
+                if (split) {
+                    if (i0 < i0_offset_low || i0 > i0_offset_high) {
+                        continue;
+                    }
+                    if (i0 == i0_offset_low) {
+                        i01_low = row_low % rows_per_iter;
+                    }
+                    if (i0 == i0_offset_high) {
+                        i01_high = row_high % rows_per_iter;
+                    }
+                }
 
-                // compute
-                CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-                CUBLAS_CHECK(
-                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                            ne01, ne11, ne10,
-                            &alpha, c_X, ne00,
-                                    c_Y, ne10,
-                            &beta,  c_D, ne01));
-            }
+                // There is possibly a bug in the Windows nvcc compiler regarding instruction reordering or optimizing out local variables.
+                // Removing the first assert or changing the order of the arguments causes the second assert to fail.
+                // Removing both asserts results in i01_high becoming 0 which in turn results in garbage output.
+                // The root cause seems to be a problem with i0_offset_high becoming 0 when it should always be >0 (for single GPU).
+                GGML_ASSERT(i01_low == 0 || g_device_count > 1);
+                GGML_ASSERT(i01_high == rows_per_iter || g_device_count > 1);
 
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                const int64_t i01_diff = i01_high - i01_low;
+                if (i01_diff == 0) {
+                    continue;
+                }
+                const int64_t i11 = i13*ne12 + i12;
+
+                cudaStream_t cudaStream_main = g_cudaStreams_main[id];
+
+                // for split tensors the data begins at i0 == i0_offset_low
+                char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
+                float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
+                float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
+                float * dst_ddf_i  =  dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
+
+                // for split tensors the data pointer needs to be rounded down
+                // to the bin edge for i03, i02 bins beyond the first
+                if (i0 - i0_offset_low > 0) {
+                    GGML_ASSERT(!flatten_rows);
+                    src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs;
+                    src0_ddf_i -= (row_low % ne01)*ne00;
+                    dst_ddf_i  -= (row_low % ne0)*ne1;
+                }
+
+                // the main device memory buffer can be on VRAM scratch, with space for all partial results
+                // in that case an offset on dst_ddf_i is needed
+                if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) {
+                    dst_ddf_i += i01_low; // offset is 0 if no tensor split
+                }
+
+                // copy src0, src1 to device if necessary
+                if (use_src1 && !src1_stays_on_host) {
+                    if (src1->backend == GGML_BACKEND_CPU) {
+                        GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
+                        int64_t nrows1 = flatten_rows ? nrows0 : ne11;
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
+                    } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
+                        if (id != g_main_device) {
+                            GGML_ASSERT(!flatten_rows);
+                            float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
+                            src1_ddf_i_source += i11*src1_stride;
+                            CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
+                                                    cudaMemcpyDeviceToDevice, cudaStream_main));
+                        }
+                    } else if (src1_on_device && !src1_is_contiguous) {
+                        GGML_ASSERT(!split);
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main));
+                    } else {
+                        GGML_ASSERT(false);
+                    }
+                }
+
+                if (!src0_on_device || !src0_is_contiguous) {
+                    if (src0_is_f32) {
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                    } else {
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                    }
+                }
+
+                // convert src0 to f32 if it is necessary for the ggml_cuda_op
+                if (src0_needs_f32 && !src0_is_f32) {
+                    to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main);
+                    CUDA_CHECK(cudaGetLastError());
+                }
+
+                // do the computation
+                op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
+
+                // copy dst to host or other device if necessary
+                if (!dst_on_device) {
+                    void * dst_off_device;
+                    cudaMemcpyKind kind;
+                    if (dst->backend == GGML_BACKEND_CPU) {
+                        dst_off_device = dst->data;
+                        kind = cudaMemcpyDeviceToHost;
+                    } else if (dst->backend == GGML_BACKEND_GPU) {
+                        dst_off_device = dst_extra->data_device[g_main_device];
+                        kind = cudaMemcpyDeviceToDevice;
+                    } else {
+                        GGML_ASSERT(false);
+                    }
+                    if (split) {
+                        // src0 = weight matrix is saved as a transposed matrix for better memory layout.
+                        // dst is NOT transposed.
+                        // The outputs of cuBLAS matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
+                        // Instead they need to be copied to the correct slice in ne0 = dst row index.
+                        // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
+                        for (int64_t j = 0; j < ne1; ++j) {
+                            float * dhf_dst_i = (float *) ((char *) dst_off_device + (j*ne0 + i01_low)*sizeof(float) + i02*nb2 + i03*nb3);
+                            CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i + j*i01_diff, i01_diff*sizeof(float), kind, cudaStream_main));
+                        }
+                    } else {
+                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
+                    }
+                }
+            }
         }
     }
 
-    CUDA_CHECK(cudaDeviceSynchronize());
-    if (!mul_mat_vec) {
-        ggml_cuda_pool_free(d_X, x_size);
+    // wait until each device is finished, then free their buffers
+    for (int id = 0; id < g_device_count; ++id) {
+        if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
+            continue;
+        }
+
+        CUDA_CHECK(cudaSetDevice(id));
+        CUDA_CHECK(cudaDeviceSynchronize());
+
+        if (src0_asq[id] > 0) {
+            ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
+        }
+        if (src0_asf[id] > 0) {
+            ggml_cuda_pool_free(src0_ddf[id], src0_asf[id]);
+        }
+        if (src1_asf[id] > 0) {
+            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+        }
+        if (dst_asf[id] > 0) {
+            ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
+        }
     }
-    ggml_cuda_pool_free(d_Y, y_size);
-    ggml_cuda_pool_free(d_D, d_size);
-    ggml_cuda_pool_free(d_Q, q_size);
 }
 
-void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_mul_f32(src0, src1, dst);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true);
+}
+
+void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten
+}
+
+void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
+}
+
+void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
@@ -847,111 +2310,414 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
         src1->type == GGML_TYPE_F32 &&
         dst->type == GGML_TYPE_F32 &&
-        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
+        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
         return true;
     }
 
     return false;
 }
 
-bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
-    size_t src0_sz = ggml_nbytes(src0);
-    size_t src1_sz = ggml_nbytes(src1);
+void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
+    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    CUDA_CHECK(cudaSetDevice(g_main_device));
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
-    // mul_mat_q: src0 is converted to fp32 on device
-    size_t mul_mat_q_transfer = src0_sz + src1_sz;
+    struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    void * src0_ddq = src0_extra->data_device[g_main_device];
 
-    // mul_mat_f16: src1 is converted to fp16 on cpu
-    size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1);
+    struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
 
-    // choose the smaller one to transfer to the device
-    // TODO: this is not always the best choice due to the overhead of converting to fp16
-    return mul_mat_f16_transfer < mul_mat_q_transfer;
+    struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
 }
 
-void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
-    GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
+void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+    GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
+    GGML_ASSERT(!ggml_is_permuted(src0));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
-    if (src0->type == GGML_TYPE_F32) {
-        ggml_cuda_mul_mat_f32(src0, src1, dst);
-    }
-    else if (src0->type == GGML_TYPE_F16) {
-        if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
-            ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
-        }
-        else {
-            ggml_cuda_mul_mat_q_f32(src0, src1, dst);
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2];
+
+    CUDA_CHECK(cudaSetDevice(g_main_device));
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
+
+    struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    void * src0_ddq = src0_extra->data_device[g_main_device];
+
+    struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+    struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+    const int row_stride_x = nb01 / sizeof(half);
+    const int channel_stride_x = nb02 / sizeof(half);
+
+    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
+}
+
+void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
+        src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
+
+    if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
+    } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
+        ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
+    }else if (src0->type == GGML_TYPE_F32) {
+        ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
+    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
+        if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) {
+            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false, false);
+        } else {
+            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
         }
+    } else {
+        GGML_ASSERT(false);
     }
-    else if (ggml_is_quantized(src0->type)) {
-        ggml_cuda_mul_mat_q_f32(src0, src1, dst);
-    }
-    else {
+}
+
+void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_scale, true, true);
+}
+
+void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const int64_t ne = ggml_nelements(src0);
+    GGML_ASSERT(ne == ggml_nelements(src1));
+
+    GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
+
+    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
+    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    GGML_ASSERT(src0->ne[3] == 1);
+
+    const int64_t nb00 = src0->nb[0];
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    GGML_ASSERT(src1->ne[3] == 1);
+
+    const int64_t nb10 = src1->nb[0];
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2];
+
+    CUDA_CHECK(cudaSetDevice(g_main_device));
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
+
+    const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+
+    char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+    char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
+
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, cudaStream_main);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, cudaStream_main);
+    } else {
         GGML_ASSERT(false);
     }
+
+    (void) dst;
+}
+
+void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true);
+}
+
+void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, true, true);
+}
+
+void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, false); // FIXME flatten changes results
+}
+
+void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    (void) src0;
+    (void) src1;
+    (void) dst;
+}
+
+void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
+    int nrows = ggml_nrows(tensor);
+    const size_t nb1 = tensor->nb[1];
+    ggml_backend backend = tensor->backend;
+    struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
+    memset(extra, 0, sizeof(*extra));
+
+    for (int id = 0; id < g_device_count; ++id) {
+        if (backend == GGML_BACKEND_GPU && id != g_main_device) {
+            continue;
+        }
+
+        cudaSetDevice(id);
+
+        int row_low, row_high;
+        if (backend == GGML_BACKEND_GPU) {
+            row_low = 0;
+            row_high = nrows;
+        } else if (backend == GGML_BACKEND_GPU_SPLIT) {
+            row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
+            row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1];
+        } else {
+            GGML_ASSERT(false);
+        }
+        if (row_low == row_high) {
+            continue;
+        }
+
+        int64_t nrows_split = row_high - row_low;
+
+        const size_t offset_split = row_low*nb1;
+        const size_t size = ggml_nbytes_split(tensor, nrows_split);
+
+        void * buf;
+        CUDA_CHECK(cudaMalloc(&buf, size));
+        void * buf_host = (char*)data + offset_split;
+
+        cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+
+        extra->data_device[id] = buf;
+    }
+
+    tensor->extra = extra;
 }
 
-size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-    if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
-        return ggml_nelements(src1) * sizeof(ggml_fp16_t);
+void ggml_cuda_free_data(struct ggml_tensor * tensor) {
+    if (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) {
+        return;
     }
-    else {
-        return 0;
+
+    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+
+    for (int id = 0; id < g_device_count; ++id) {
+        if (extra->data_device[id] == nullptr) {
+            continue;
+        }
+
+        CUDA_CHECK(cudaSetDevice(id));
+        CUDA_CHECK(cudaFree(extra->data_device[id]));
     }
+
+    delete extra;
 }
 
-void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
-    const int64_t ne0 = tensor->ne[0];
-    const int64_t ne1 = tensor->ne[1];
-    const int64_t ne2 = tensor->ne[2];
-    const int64_t ne3 = tensor->ne[3];
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
+    if (scratch && g_scratch_size == 0) {
+        return;
+    }
+
+    // recursively assign CUDA buffers until a compute tensor is found
+    if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
+        const ggml_op src0_op = tensor->src0->op;
+        if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
+            ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
+        }
+    }
+    if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
+        ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
+    }
 
-    const ggml_type type = tensor->type;
-    const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
+    tensor->backend = GGML_BACKEND_GPU;
+    struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
 
-    size_t q_size;
-    char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+    const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
+        tensor->op == GGML_OP_VIEW;
+    const size_t size = ggml_nbytes(tensor);
 
-    cudaStream_t cudaStream2 = g_cudaStreams2[0];
+    CUDA_CHECK(cudaSetDevice(g_main_device));
+    if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
+        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
+        char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+        size_t offset = 0;
+        if (tensor->op == GGML_OP_VIEW) {
+            memcpy(&offset, tensor->opt[0]->data, sizeof(size_t));
+        }
+        extra->data_device[g_main_device] = src0_ddc + offset;
+    } else if (tensor->op == GGML_OP_CPY) {
+        struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src1->extra;
+        void * src1_ddv = src1_extra->data_device[g_main_device];
+        extra->data_device[g_main_device] = src1_ddv;
+    } else if (scratch) {
+        GGML_ASSERT(size <= g_scratch_size);
+        if (g_scratch_offset + size > g_scratch_size) {
+            g_scratch_offset = 0;
+        }
 
-    // copy tensor to device
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
-            int i = i3*ne2 + i2;
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2));
+        char * data = (char *) g_scratch_buffer;
+        if (data == nullptr) {
+            CUDA_CHECK(cudaMalloc(&data, g_scratch_size));
+            g_scratch_buffer = data;
         }
+        extra->data_device[g_main_device] = data + g_scratch_offset;
+
+        g_scratch_offset += size;
+
+        GGML_ASSERT(g_scratch_offset <= g_scratch_size);
+    } else { // allocate new buffers outside of scratch
+        void * data;
+        CUDA_CHECK(cudaMalloc(&data, size));
+        CUDA_CHECK(cudaMemset(data, 0, size));
+        extra->data_device[g_main_device] = data;
     }
 
-    tensor->data = dst;
-    tensor->backend = GGML_BACKEND_CUDA;
+    tensor->extra = extra;
 }
 
-void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
-    FILE * fp = fopen(fname, "rb");
+void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, true);
+}
 
-    const size_t size = ggml_nbytes(tensor);
+void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, false);
+}
 
-    void * buf;
-    CUDA_CHECK(cudaMalloc(&buf, size));
-    void * buf_host = malloc(size);
+void ggml_cuda_set_main_device(int main_device) {
+    if (main_device >= g_device_count) {
+        fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
+                main_device, g_device_count, g_main_device);
+        return;
+    }
+    g_main_device = main_device;
+    if (g_device_count > 1) {
+        cudaDeviceProp prop;
+        CUDA_CHECK(cudaGetDeviceProperties(&prop, g_main_device));
+        fprintf(stderr, "%s: using device %d (%s) as main device\n", __func__, g_main_device, prop.name);
+    }
+}
 
-#ifdef _WIN32
-    int ret = _fseeki64(fp, (__int64) offset, SEEK_SET);
-#else
-    int ret = fseek(fp, (long) offset, SEEK_SET);
-#endif
-    GGML_ASSERT(ret == 0); // same
+void ggml_cuda_set_scratch_size(size_t scratch_size) {
+    g_scratch_size = scratch_size;
+}
 
-    size_t ret2 = fread(buf_host, size, 1, fp);
-    if (ret2 != 1) {
-        fprintf(stderr, "unexpectedly reached end of file");
-        exit(1);
+void ggml_cuda_free_scratch() {
+    if (g_scratch_buffer == nullptr) {
+        return;
     }
 
-    cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
-    cudaDeviceSynchronize();
+    CUDA_CHECK(cudaFree(g_scratch_buffer));
+    g_scratch_buffer = nullptr;
+}
+
+bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
+    ggml_cuda_func_t func;
+    const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
+        || tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT
+        || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);
+
+    switch (tensor->op) {
+        case GGML_OP_ADD:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_add;
+            break;
+        case GGML_OP_MUL:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_mul;
+            break;
+        case GGML_OP_SILU:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_silu;
+            break;
+        case GGML_OP_RMS_NORM:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_rms_norm;
+            break;
+        case GGML_OP_MUL_MAT:
+            if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src0, tensor->src1, tensor)) {
+                return false;
+            }
+            func = ggml_cuda_mul_mat;
+            break;
+        case GGML_OP_SCALE:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_scale;
+            break;
+        case GGML_OP_CPY:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_cpy;
+            break;
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_nop;
+            break;
+        case GGML_OP_DIAG_MASK_INF:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_diag_mask_inf;
+            break;
+        case GGML_OP_SOFT_MAX:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_soft_max;
+            break;
+        case GGML_OP_ROPE:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_rope;
+            break;
+        default:
+            return false;
+    }
 
-    tensor->data = buf;
-    free(buf_host);
-    fclose(fp);
+    if (params->ith != 0) {
+        return true;
+    }
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return true;
+    }
+    func(tensor->src0, tensor->src1, tensor);
+    return true;
 }
index 6a04dde6c37a9d3179747575ecb2dee455a29720..d32b4484267ab5a3f9028db07330302432c1fd93 100644 (file)
@@ -1,10 +1,19 @@
+#pragma once
+
 #include "ggml.h"
 
 #ifdef  __cplusplus
 extern "C" {
 #endif
 
+#define GGML_CUDA_MAX_DEVICES       16
+
+struct ggml_tensor_extra_gpu {
+    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+};
+
 void   ggml_init_cublas(void);
+void   ggml_cuda_set_tensor_split(const float * tensor_split);
 
 void   ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
@@ -15,8 +24,15 @@ void   ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
 void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
 
-void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
-void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);
+void   ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
+
+void   ggml_cuda_free_data(struct ggml_tensor * tensor);
+void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
+void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
+void   ggml_cuda_set_main_device(int main_device);
+void   ggml_cuda_set_scratch_size(size_t scratch_size);
+void   ggml_cuda_free_scratch(void);
+bool   ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
 
 #ifdef  __cplusplus
 }
diff --git a/src/ggml-metal.h b/src/ggml-metal.h
new file mode 100644 (file)
index 0000000..b9e50ac
--- /dev/null
@@ -0,0 +1,67 @@
+// An interface allowing to compute ggml_cgraph with Metal
+//
+// This is a fully functional interface that extends ggml with GPU support for Apple devices.
+// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.)
+//
+// How it works?
+//
+// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this
+// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you
+// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)
+//
+// You only need to make sure that all memory buffers that you used during the graph creation
+// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is
+// used during the graph evaluation to determine the arguments of the compute kernels.
+//
+// Synchronization between device and host memory (for example for input and output tensors)
+// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.
+//
+
+#pragma once
+
+#include <stddef.h>
+#include <stdbool.h>
+
+// max memory buffers that can be mapped to the device
+#define GGML_METAL_MAX_BUFFERS 16
+
+struct ggml_tensor;
+struct ggml_cgraph;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct ggml_metal_context;
+
+struct ggml_metal_context * ggml_metal_init(void);
+void ggml_metal_free(struct ggml_metal_context * ctx);
+
+// creates a mapping between a host memory buffer and a device memory buffer
+// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
+// - the mapping is used during computation to determine the arguments of the compute kernels
+// - you don't need to keep the host memory buffer allocated as it is never accessed by Metal
+// - max_size specifies the maximum size of a tensor and is used to create shared views such
+//   that it is guaranteed that the tensor will fit in at least one of the views
+//
+bool ggml_metal_add_buffer(
+        struct ggml_metal_context * ctx,
+                       const char * name,
+                             void * data,
+                           size_t   size,
+                           size_t   max_size);
+
+// set data from host memory into the device
+void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
+
+// get data from the device into host memory
+void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
+
+// same as ggml_graph_compute but uses Metal
+// creates gf->n_threads command buffers in parallel
+void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+#ifdef __cplusplus
+}
+#endif
+
diff --git a/src/ggml-metal.m b/src/ggml-metal.m
new file mode 100644 (file)
index 0000000..a7e104d
--- /dev/null
@@ -0,0 +1,972 @@
+#import "ggml-metal.h"
+
+#import "ggml.h"
+
+#import <Foundation/Foundation.h>
+
+#import <Metal/Metal.h>
+#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
+
+#ifdef GGML_METAL_NDEBUG
+#define metal_printf(...)
+#else
+#define metal_printf(...) fprintf(stderr, __VA_ARGS__)
+#endif
+
+#define UNUSED(x) (void)(x)
+
+struct ggml_metal_buffer {
+    const char * name;
+
+    void   * data;
+    size_t   size;
+
+    id<MTLBuffer> metal;
+};
+
+struct ggml_metal_context {
+    float * logits;
+
+    id<MTLDevice>       device;
+    id<MTLCommandQueue> queue;
+    id<MTLLibrary>      library;
+
+    int n_buffers;
+    struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
+
+    // custom kernels
+#define GGML_METAL_DECL_KERNEL(name) \
+    id<MTLFunction>             function_##name; \
+    id<MTLComputePipelineState> pipeline_##name
+
+    GGML_METAL_DECL_KERNEL(add);
+    GGML_METAL_DECL_KERNEL(mul);
+    GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+    GGML_METAL_DECL_KERNEL(scale);
+    GGML_METAL_DECL_KERNEL(silu);
+    GGML_METAL_DECL_KERNEL(relu);
+    GGML_METAL_DECL_KERNEL(gelu);
+    GGML_METAL_DECL_KERNEL(soft_max);
+    GGML_METAL_DECL_KERNEL(diag_mask_inf);
+    GGML_METAL_DECL_KERNEL(get_rows_f16);
+    GGML_METAL_DECL_KERNEL(get_rows_q4_0);
+    GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DECL_KERNEL(get_rows_q2_k);
+    GGML_METAL_DECL_KERNEL(get_rows_q3_k);
+    GGML_METAL_DECL_KERNEL(get_rows_q4_k);
+    GGML_METAL_DECL_KERNEL(get_rows_q5_k);
+    GGML_METAL_DECL_KERNEL(get_rows_q6_k);
+    GGML_METAL_DECL_KERNEL(rms_norm);
+    GGML_METAL_DECL_KERNEL(norm);
+    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
+    GGML_METAL_DECL_KERNEL(rope);
+    GGML_METAL_DECL_KERNEL(alibi_f32);
+    GGML_METAL_DECL_KERNEL(cpy_f32_f16);
+    GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+
+#undef GGML_METAL_DECL_KERNEL
+};
+
+// MSL code
+// TODO: move the contents here when ready
+//       for now it is easier to work in a separate file
+static NSString * const msl_library_source = @"see metal.metal";
+
+// Here to assist with NSBundle Path Hack
+@interface GGMLMetalClass : NSObject
+@end
+@implementation GGMLMetalClass
+@end
+
+struct ggml_metal_context * ggml_metal_init(void) {
+    fprintf(stderr, "%s: allocating\n", __func__);
+
+    struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+
+    ctx->device = MTLCreateSystemDefaultDevice();
+    ctx->queue  = [ctx->device newCommandQueue];
+    ctx->n_buffers = 0;
+
+    // determine if we can use MPS
+    if (MPSSupportsMTLDevice(ctx->device)) {
+        fprintf(stderr, "%s: using MPS\n", __func__);
+    } else {
+        fprintf(stderr, "%s: not using MPS\n", __func__);
+        GGML_ASSERT(false && "MPS not supported");
+    }
+
+#if 0
+    // compile from source string and show compile log
+    {
+        NSError * error = nil;
+
+        ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
+        if (error) {
+            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            exit(1);
+        }
+    }
+#else
+    UNUSED(msl_library_source);
+
+    // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
+    {
+        NSError * error = nil;
+
+        //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
+        NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+        NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+        fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
+
+        NSString * src  = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
+        if (error) {
+            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            exit(1);
+        }
+
+        ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+        if (error) {
+            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            exit(1);
+        }
+    }
+#endif
+
+    // load kernels
+    {
+#define GGML_METAL_ADD_KERNEL(name) \
+        ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
+        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \
+        fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
+
+        GGML_METAL_ADD_KERNEL(add);
+        GGML_METAL_ADD_KERNEL(mul);
+        GGML_METAL_ADD_KERNEL(mul_row);
+        GGML_METAL_ADD_KERNEL(scale);
+        GGML_METAL_ADD_KERNEL(silu);
+        GGML_METAL_ADD_KERNEL(relu);
+        GGML_METAL_ADD_KERNEL(gelu);
+        GGML_METAL_ADD_KERNEL(soft_max);
+        GGML_METAL_ADD_KERNEL(diag_mask_inf);
+        GGML_METAL_ADD_KERNEL(get_rows_f16);
+        GGML_METAL_ADD_KERNEL(get_rows_q4_0);
+        GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+        GGML_METAL_ADD_KERNEL(get_rows_q2_k);
+        GGML_METAL_ADD_KERNEL(get_rows_q3_k);
+        GGML_METAL_ADD_KERNEL(get_rows_q4_k);
+        GGML_METAL_ADD_KERNEL(get_rows_q5_k);
+        GGML_METAL_ADD_KERNEL(get_rows_q6_k);
+        GGML_METAL_ADD_KERNEL(rms_norm);
+        GGML_METAL_ADD_KERNEL(norm);
+        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
+        GGML_METAL_ADD_KERNEL(rope);
+        GGML_METAL_ADD_KERNEL(alibi_f32);
+        GGML_METAL_ADD_KERNEL(cpy_f32_f16);
+        GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+        GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+
+#undef GGML_METAL_ADD_KERNEL
+    }
+
+    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    fprintf(stderr, "%s: hasUnifiedMemory             = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    if (ctx->device.maxTransferRate != 0) {
+        fprintf(stderr, "%s: maxTransferRate              = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+    } else {
+        fprintf(stderr, "%s: maxTransferRate              = built-in GPU\n", __func__);
+    }
+
+    return ctx;
+}
+
+void ggml_metal_free(struct ggml_metal_context * ctx) {
+    fprintf(stderr, "%s: deallocating\n", __func__);
+
+    free(ctx);
+}
+
+// finds the Metal buffer that contains the tensor data on the GPU device
+// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
+// Metal buffer based on the host memory pointer
+//
+static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
+    //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+
+    const int64_t tsize = ggml_nbytes(t);
+
+    // find the view that contains the tensor fully
+    for (int i = 0; i < ctx->n_buffers; ++i) {
+        const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
+
+        if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
+            *offs = (size_t) ioffs;
+
+            //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+
+            return ctx->buffers[i].metal;
+        }
+    }
+
+    fprintf(stderr, "%s: error: buffer is nil\n", __func__);
+
+    return nil;
+}
+
+bool ggml_metal_add_buffer(
+        struct ggml_metal_context * ctx,
+                     const char * name,
+                           void * data,
+                         size_t   size,
+                         size_t   max_size) {
+    if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
+        fprintf(stderr, "%s: too many buffers\n", __func__);
+        return false;
+    }
+
+    if (data) {
+        // verify that the buffer does not overlap with any of the existing buffers
+        for (int i = 0; i < ctx->n_buffers; ++i) {
+            const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
+
+            if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
+                fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
+                return false;
+            }
+        }
+
+        const size_t size_page = getpagesize();
+
+        size_t size_aligned = size;
+        if ((size_aligned % size_page) != 0) {
+            size_aligned += (size_page - (size_aligned % size_page));
+        }
+
+        // the buffer fits into the max buffer size allowed by the device
+        if (size_aligned <= ctx->device.maxBufferLength) {
+            ctx->buffers[ctx->n_buffers].name = name;
+            ctx->buffers[ctx->n_buffers].data = data;
+            ctx->buffers[ctx->n_buffers].size = size;
+
+            ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+            if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+                return false;
+            }
+
+            fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+
+            ++ctx->n_buffers;
+        } else {
+            // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
+            // one of the views
+            const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
+            const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
+            const size_t size_view = ctx->device.maxBufferLength;
+
+            for (size_t i = 0; i < size; i += size_step) {
+                const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
+
+                ctx->buffers[ctx->n_buffers].name = name;
+                ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
+                ctx->buffers[ctx->n_buffers].size = size_step_aligned;
+
+                ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+                if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                    fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+                    return false;
+                }
+
+                fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+                if (i + size_step < size) {
+                    fprintf(stderr, "\n");
+                }
+
+                ++ctx->n_buffers;
+            }
+        }
+
+        fprintf(stderr, ", (%8.2f / %8.2f)",
+                ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
+                ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+
+        if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
+            fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
+        } else {
+            fprintf(stderr, "\n");
+        }
+    }
+
+    return true;
+}
+
+void ggml_metal_set_tensor(
+        struct ggml_metal_context * ctx,
+        struct ggml_tensor * t) {
+    metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
+
+    size_t offs;
+    id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
+
+    memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t));
+}
+
+void ggml_metal_get_tensor(
+        struct ggml_metal_context * ctx,
+        struct ggml_tensor * t) {
+    metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
+
+    size_t offs;
+    id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
+
+    memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
+}
+
+void ggml_metal_graph_compute(
+        struct ggml_metal_context * ctx,
+               struct ggml_cgraph * gf) {
+    metal_printf("%s: evaluating graph\n", __func__);
+
+    // create multiple command buffers and enqueue them
+    // then, we encode the graph into the command buffers in parallel
+
+    const int n_cb = gf->n_threads;
+
+    NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
+
+    for (int i = 0; i < n_cb; ++i) {
+        command_buffers[i] = [ctx->queue commandBuffer];
+
+        // enqueue the command buffers in order to specify their execution order
+        [command_buffers[i] enqueue];
+    }
+
+    // TODO: is this the best way to start threads?
+    dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+
+    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+        const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+
+        dispatch_async(queue, ^{
+            size_t offs_src0 = 0;
+            size_t offs_src1 = 0;
+            size_t offs_dst  = 0;
+
+            id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
+
+            id<MTLComputeCommandEncoder> encoder = nil;
+
+            const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+            for (int i = node_start; i < node_end; ++i) {
+                metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+
+                struct ggml_tensor * src0 = gf->nodes[i]->src0;
+                struct ggml_tensor * src1 = gf->nodes[i]->src1;
+                struct ggml_tensor * dst  = gf->nodes[i];
+
+                const int64_t  ne00 = src0 ? src0->ne[0] : 0;
+                const int64_t  ne01 = src0 ? src0->ne[1] : 0;
+                const int64_t  ne02 = src0 ? src0->ne[2] : 0;
+                const int64_t  ne03 = src0 ? src0->ne[3] : 0;
+
+                const uint64_t nb00 = src0 ? src0->nb[0] : 0;
+                const uint64_t nb01 = src0 ? src0->nb[1] : 0;
+                const uint64_t nb02 = src0 ? src0->nb[2] : 0;
+                const uint64_t nb03 = src0 ? src0->nb[3] : 0;
+
+                const int64_t  ne10 = src1 ? src1->ne[0] : 0;
+                const int64_t  ne11 = src1 ? src1->ne[1] : 0;
+                const int64_t  ne12 = src1 ? src1->ne[2] : 0;
+                const int64_t  ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+
+                const uint64_t nb10 = src1 ? src1->nb[0] : 0;
+                const uint64_t nb11 = src1 ? src1->nb[1] : 0;
+                const uint64_t nb12 = src1 ? src1->nb[2] : 0;
+                const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+
+                const int64_t  ne0  = dst ? dst->ne[0] : 0;
+                const int64_t  ne1  = dst ? dst->ne[1] : 0;
+                const int64_t  ne2  = dst ? dst->ne[2] : 0;
+                const int64_t  ne3  = dst ? dst->ne[3] : 0;
+
+                const uint64_t nb0  = dst ? dst->nb[0] : 0;
+                const uint64_t nb1  = dst ? dst->nb[1] : 0;
+                const uint64_t nb2  = dst ? dst->nb[2] : 0;
+                const uint64_t nb3  = dst ? dst->nb[3] : 0;
+
+                const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+                const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+                const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
+
+                id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
+                id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
+                id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(ctx, dst,  &offs_dst)  : nil;
+
+                //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+                //if (src0) {
+                //    metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+                //            ggml_is_contiguous(src0), src0->name);
+                //}
+                //if (src1) {
+                //    metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+                //            ggml_is_contiguous(src1), src1->name);
+                //}
+                //if (dst) {
+                //    metal_printf("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2,
+                //            dst->name);
+                //}
+
+                switch (dst->op) {
+                    case GGML_OP_RESHAPE:
+                    case GGML_OP_VIEW:
+                    case GGML_OP_TRANSPOSE:
+                    case GGML_OP_PERMUTE:
+                        {
+                            // noop
+                        } break;
+                    case GGML_OP_ADD:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_add];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_MUL:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            if (ggml_nelements(src1) == ne10) {
+                                // src1 is a row
+                                [encoder setComputePipelineState:ctx->pipeline_mul_row];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_mul];
+                            }
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_SCALE:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const float scale = *(const float *) src1->data;
+
+                            [encoder setComputePipelineState:ctx->pipeline_scale];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_SILU:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_silu];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_RELU:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_relu];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_GELU:
+                    {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_gelu];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                    case GGML_OP_SOFT_MAX:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const int nth = 32;
+
+                            [encoder setComputePipelineState:ctx->pipeline_soft_max];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    case GGML_OP_DIAG_MASK_INF:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const int n_past = ((int32_t *)(src1->data))[0];
+
+                            [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01   length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&n_past length:sizeof(int)  atIndex:4];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_MUL_MAT:
+                        {
+                            // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
+
+                            GGML_ASSERT(ne00 == ne10);
+                            GGML_ASSERT(ne02 == ne12);
+
+                            if (ggml_is_contiguous(src0) &&
+                                ggml_is_contiguous(src1) &&
+                                (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
+
+                                if (encoder != nil) {
+                                    [encoder endEncoding];
+                                    encoder = nil;
+                                }
+
+                                MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+                                MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+
+                                // for F32 x F32 we use MPS
+                                MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
+                                    matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
+
+                                MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
+                                    matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
+
+                                MPSMatrixDescriptor * desc  = [MPSMatrixDescriptor
+                                    matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
+
+                                MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
+                                    initWithDevice:ctx->device transposeLeft:false transposeRight:true
+                                        resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
+
+                                // we need to do ne02 multiplications
+                                // TODO: is there a way to do this in parallel - currently very slow ..
+                                // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
+                                for (int64_t i02 = 0; i02 < ne02; ++i02) {
+                                    size_t offs_src0_cur = offs_src0 + i02*nb02;
+                                    size_t offs_src1_cur = offs_src1 + i02*nb12;
+                                    size_t offs_dst_cur  = offs_dst  + i02*nb2;
+
+                                    MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
+                                    MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
+                                    MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst_cur  descriptor:desc ];
+
+                                    [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
+                                }
+                            } else {
+                                if (encoder == nil) {
+                                    encoder = [command_buffer computeCommandEncoder];
+                                }
+
+                                int nth0 = 32;
+                                int nth1 = 1;
+
+                                // use custom matrix x vector kernel
+                                switch (src0t) {
+                                    case GGML_TYPE_F16:
+                                        {
+                                            GGML_ASSERT(ne02 == ne12);
+
+                                            nth0 = 64;
+                                            nth1 = 1;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_0:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_1:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
+                                        } break;
+                                    case GGML_TYPE_Q2_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
+                                        } break;
+                                    case GGML_TYPE_Q3_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
+                                        } break;
+                                    case GGML_TYPE_Q5_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
+                                        } break;
+                                    case GGML_TYPE_Q6_K:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 4;
+                                            nth1 = 16;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
+                                        } break;
+                                    default:
+                                        {
+                                            fprintf(stderr, "Asserting on type %d\n",(int)src0t);
+                                            GGML_ASSERT(false && "not implemented");
+                                        }
+                                };
+
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
+                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
+                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
+                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
+                                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
+                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
+                                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
+
+                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
+                                    [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src0t == GGML_TYPE_Q2_K ||
+                                         src0t == GGML_TYPE_Q3_K ||
+                                         src0t == GGML_TYPE_Q4_K ||
+                                         src0t == GGML_TYPE_Q5_K ||
+                                         src0t == GGML_TYPE_Q6_K) {
+                                    [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                } else {
+                                    [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                            }
+                        } break;
+                    case GGML_OP_GET_ROWS:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            switch (src0->type) {
+                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
+                                case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
+                                case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
+                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
+                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
+                                case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
+                                case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
+                                default: GGML_ASSERT(false && "not implemented");
+                            }
+
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
+                            [encoder setBytes:&(dst->nb[1])  length:sizeof(uint64_t) atIndex:5];
+
+                            const int64_t n = ggml_nelements(src1);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_RMS_NORM:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const float eps = 1e-6f;
+
+                            const int nth = 256;
+
+                            [encoder setComputePipelineState:ctx->pipeline_rms_norm];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                            [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+                            const int64_t nrows = ggml_nrows(src0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    case GGML_OP_NORM:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const float eps = 1e-5f;
+
+                            const int nth = 256;
+
+                            [encoder setComputePipelineState:ctx->pipeline_norm];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                            [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+                            const int64_t nrows = ggml_nrows(src0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    case GGML_OP_ALIBI:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            GGML_ASSERT((src0t == GGML_TYPE_F32));
+
+                            const int   n_past   = ((int32_t *) src1->data)[0]; UNUSED(n_past);
+                            const int   n_head   = ((int32_t *) src1->data)[1];
+                            const float max_bias = ((float *)   src1->data)[2];
+
+                            if (__builtin_popcount(n_head) != 1) {
+                                GGML_ASSERT(false && "only power-of-two n_head implemented");
+                            }
+
+                            const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+                            const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+
+                            [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBytes:&m0  length:sizeof(    float) atIndex:18];
+                            const int nth = 32;
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    case GGML_OP_ROPE:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const int n_dims = ((int32_t *) src1->data)[1];
+                            const int mode   = ((int32_t *) src1->data)[2];
+
+                            const int n_past = ((int32_t *)(src1->data))[0];
+
+                            [encoder setComputePipelineState:ctx->pipeline_rope];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00   length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01   length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02   length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03   length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00   length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01   length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02   length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03   length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0    length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1    length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2    length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3    length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0    length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1    length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2    length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3    length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBytes:&n_past length:sizeof(     int) atIndex:18];
+                            [encoder setBytes:&n_dims length:sizeof(     int) atIndex:19];
+                            [encoder setBytes:&mode   length:sizeof(     int) atIndex:20];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_OP_CPY:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            const int nth = 32;
+
+                            switch (src0t) {
+                                case GGML_TYPE_F32:
+                                    {
+                                        switch (dstt) {
+                                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+                                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+                                            default: GGML_ASSERT(false && "not implemented");
+                                        };
+                                    } break;
+                                case GGML_TYPE_F16:
+                                    {
+                                        switch (dstt) {
+                                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
+                                            case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+                                            default: GGML_ASSERT(false && "not implemented");
+                                        };
+                                    } break;
+                                default: GGML_ASSERT(false && "not implemented");
+                            }
+
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    default:
+                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                        GGML_ASSERT(false);
+                }
+            }
+
+            if (encoder != nil) {
+                [encoder endEncoding];
+                encoder = nil;
+            }
+
+            [command_buffer commit];
+        });
+    }
+
+    // wait for all threads to finish
+    dispatch_barrier_sync(queue, ^{});
+
+    [command_buffers[n_cb - 1] waitUntilCompleted];
+
+    // check status of command buffers
+    // needed to detect if the device ran out-of-memory for example (#1881)
+    for (int i = 0; i < n_cb; i++) {
+        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
+        if (status != MTLCommandBufferStatusCompleted) {
+            fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
+            GGML_ASSERT(false);
+        }
+    }
+}
diff --git a/src/ggml-metal.metal b/src/ggml-metal.metal
new file mode 100644 (file)
index 0000000..d1e4922
--- /dev/null
@@ -0,0 +1,1585 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+
+#define QK4_0 32
+#define QR4_0 2
+typedef struct {
+    half    d;             // delta
+    uint8_t qs[QK4_0 / 2]; // nibbles / quants
+} block_q4_0;
+
+#define QK4_1 32
+typedef struct {
+    half d;          // delta
+    half m;          // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
+} block_q4_1;
+
+static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
+    const int qk = QK4_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const half d = x[i].d;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int x0 = (x[i].qs[j] & 0x0F) - 8;
+            const int x1 = (x[i].qs[j] >>   4) - 8;
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
+static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
+    const int qk = QK4_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const half d = x[i].d;
+        const half m = x[i].m;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int x0 = (x[i].qs[j] & 0x0F);
+            const int x1 = (x[i].qs[j] >>   4);
+
+            y[i*qk + j + 0   ] = x0*d + m;
+            y[i*qk + j + qk/2] = x1*d + m;
+        }
+    }
+}
+
+kernel void kernel_add(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] + src1[tpig];
+}
+
+kernel void kernel_mul(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src1[tpig];
+}
+
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_mul_row(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src1[tpig % ne00];
+}
+
+kernel void kernel_scale(
+        device const float * src0,
+        device       float * dst,
+        constant     float & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_silu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    float x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_relu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = max(0.0f, src0[tpig]);
+}
+
+constant float GELU_COEF_A    = 0.044715f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    float x = src0[tpig];
+    dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_soft_max(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    // parallel max
+    buf[tpitg[0]] = -INFINITY;
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+    }
+
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg[0]/2; i > 0; i /= 2) {
+        if (tpitg[0] < i) {
+            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // broadcast
+    if (tpitg[0] == 0) {
+        buf[0] = buf[0];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float max = buf[0];
+
+    // parallel sum
+    buf[tpitg[0]] = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        buf[tpitg[0]] += exp(psrc0[i00] - max);
+    }
+
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg[0]/2; i > 0; i /= 2) {
+        if (tpitg[0] < i) {
+            buf[tpitg[0]] += buf[tpitg[0] + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // broadcast
+    if (tpitg[0] == 0) {
+        buf[0] = buf[0];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float sum = buf[0];
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] = exp(psrc0[i00] - max) / sum;
+    }
+}
+
+kernel void kernel_diag_mask_inf(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant       int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+    const int64_t i02 = tpig[2];
+    const int64_t i01 = tpig[1];
+    const int64_t i00 = tpig[0];
+
+    if (i00 > n_past + i01) {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+    } else {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_get_rows_f16(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    for (int j = 0; j < ne00; j++) {
+        dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
+    }
+}
+
+kernel void kernel_get_rows_q4_0(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q4_0(
+            (device const block_q4_0 *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_get_rows_q4_1(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q4_1(
+            (device const block_q4_1 *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_norm(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    // MEAN
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float mean  = sum[0];
+
+    // recenter
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
+    }
+
+    // VARIANCE
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += y[i00] * y[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float variance = sum[0];
+
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
+
+kernel void kernel_rms_norm(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00] * x[i00];
+    }
+
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float mean  = sum[0];
+    const float scale = 1.0f/sqrt(mean + eps);
+
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] * scale;
+    }
+}
+
+kernel void kernel_mul_mat_q4_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+    const int nb = ne00/QK4_0;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int ix = tpitg.y/4;           // 0 or 1
+    const int iy = tpitg.y - 4*ix;      // 0...3
+
+    const int first = 4 * iy;
+
+    float sumf = 0;
+
+    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
+
+        const float d = (float)x[i].d;
+
+        device const uint8_t * xl = x[i].qs + first;
+        device const float   * yl = y + i * QK4_0 + first;
+
+        float2 acc = {0.0f, 0.0f};
+
+        for (int j = 0; j < 4; ++j) {
+
+            acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
+            acc[1] += yl[j] + yl[j+16];
+
+        }
+
+        sumf += d * (acc[0] - 8.f*acc[1]);
+    }
+
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+}
+
+kernel void kernel_mul_mat_q4_1_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+    const int nb = ne00/QK4_1;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+
+    const uint nth = tptg.x*tptg.y;
+    const uint ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int ix = tpitg.y/4;           // 0 or 1
+    const int iy = tpitg.y - 4*ix;      // 0...3
+
+    const int first = 4 * iy;
+
+    float sumf = 0;
+
+    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
+
+        const float d = (float)x[i].d;
+        const float m = (float)x[i].m;
+
+        device const uint8_t * xl = x[i].qs + first;
+        device const float   * yl = y + i * QK4_1 + first;
+
+        float2 acc = {0.0f, 0.0f};
+
+        for (int j = 0; j < 4; ++j) {
+
+            acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
+            acc[1] += yl[j+16] * (d * (xl[j] >>  4) + m);
+
+        }
+
+        sumf += acc[0] + acc[1];
+    }
+
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+}
+
+kernel void kernel_mul_mat_f16_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tpig[[thread_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3  tptg[[threads_per_threadgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+    const int64_t im = tgpig.z;
+
+    device const half  * x = (device const half  *) (src0 + r0*nb01 + im*nb02);
+    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+    sum[tpitg.x] = 0.0f;
+
+    for (int i = tpitg.x; i < ne00; i += tptg.x) {
+        sum[tpitg.x] += (float) x[i] * (float) y[i];
+    }
+
+    // accumulate the sum from all threads in the threadgroup
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = tptg.x/2; i > 0; i /= 2) {
+        if (tpitg.x < i) {
+            sum[tpitg.x] += sum[tpitg.x + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    if (tpitg.x == 0) {
+        dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+    }
+}
+
+kernel void kernel_alibi_f32(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        constant      float & m0,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    float m_k = pow(m0, i2 + 1);
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+    }
+}
+
+kernel void kernel_rope(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        constant       int & n_past,
+        constant       int & n_dims,
+        constant       int & mode,
+        uint3 tpig[[thread_position_in_grid]]) {
+    const int64_t i3 = tpig[2];
+    const int64_t i2 = tpig[1];
+    const int64_t i1 = tpig[0];
+
+    const bool is_neox = mode & 2;
+    const float theta_scale = pow(10000.0, -2.0f/n_dims);
+
+    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+
+    float theta = (float)p;
+
+    if (!is_neox) {
+        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+            const float cos_theta = cos(theta);
+            const float sin_theta = sin(theta);
+
+            theta *= theta_scale;
+
+            device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        }
+    } else {
+        // TODO: implement
+    }
+}
+
+kernel void kernel_cpy_f16_f16(
+        device const half * src0,
+        device       half * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f32_f16(
+        device const float * src0,
+        device        half * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f32_f32(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
+
+//============================================ k-quants ======================================================
+
+#define QK_K 256
+
+typedef struct {
+    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+    uint8_t qs[QK_K/4];      // quants
+    half d;           // super-block scale for quantized scales
+    half dmin;        // super-block scale for quantized mins
+} block_q2_k;
+// 84 bytes / block
+
+typedef struct {
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+    half d;                    // super-block scale
+} block_q3_k;
+// 110 bytes / block
+
+typedef struct {
+    half d;             // super-block scale for quantized scales
+    half dmin;          // super-block scale for quantized mins
+    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+    uint8_t qs[QK_K/2];        // 4--bit quants
+} block_q4_k;
+// 144 bytes / block
+
+typedef struct {
+    half d;                      // super-block scale for quantized scales
+    half dmin;                   // super-block scale for quantized mins
+    uint8_t scales[3*QK_K/64];   // scales and mins, quantized with 6 bits
+    uint8_t qh[QK_K/8];          // quants, high bit
+    uint8_t qs[QK_K/2];          // quants, low 4 bits
+} block_q5_k;
+// 176 bytes / block
+
+typedef struct {
+    uint8_t ql[QK_K/2];      // quants, lower 4 bits
+    uint8_t qh[QK_K/4];      // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
+    half d;                  // super-block scale
+} block_q6_k;
+// 210 bytes / block
+
+static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
+    uchar4 r;
+    if (j < 4) {
+        r[0] = q[j+0] & 63;
+        r[2] = q[j+1] & 63;
+        r[1] = q[j+4] & 63;
+        r[3] = q[j+5] & 63;
+    } else {
+        r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
+        r[1] = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+        r[3] = (q[j+5] >>  4) | ((q[j+1] >> 6) << 4);
+    }
+    return r;
+}
+
+//========================================== dequantization =============================
+
+static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = x[i].d;
+        const float min = x[i].dmin;
+
+        device const uint8_t * q = x[i].qs;
+
+        int is = 0;
+        float dl, ml;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                uint8_t sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+                sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+
+                shift += 2;
+            }
+            q += 32;
+        }
+
+    }
+}
+
+static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    uint16_t aux[8];
+    thread const int8_t * scales = (thread const int8_t*)aux;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs;
+        device const uint8_t * h = x[i].hmask;
+        uint8_t m = 1;
+
+        device const uint16_t * a = (device const uint16_t *)x[i].scales;
+        aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
+        aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
+        aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
+        aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
+        aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
+        aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
+        aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
+        aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
+
+        int is = 0;
+        float dl;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
+                }
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
+                }
+
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+
+    }
+
+}
+
+static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = x[i].d;
+        const float min = x[i].dmin;
+
+        device const uint8_t * q = x[i].qs;
+        device const uint8_t * scales = x[i].scales;
+
+        int is = 0;
+        for (int j = 0; j < QK_K; j += 64) {
+            const uchar4 sc = get_scale_min_k4(is, scales);
+            const float d1 = d * sc[0]; const float m1 = min * sc[1];
+            const float d2 = d * sc[2]; const float m2 = min * sc[3];
+            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l]  >> 4) - m2;
+            q += 32; is += 2;
+        }
+
+    }
+}
+
+static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+   for (int i = 0; i < nb; i++) {
+
+        const float d = (float)(x[i].d);
+        const float min = (float)(x[i].dmin);
+
+        device const uint8_t * ql = x[i].qs;
+        device const uint8_t * qh = x[i].qh;
+
+        int is = 0;
+        uint8_t u1 = 1, u2 = 2;
+        for (int j = 0; j < QK_K; j += 64) {
+            const uchar4 sc = get_scale_min_k4(is, x[i].scales);
+            const float d1 = d * sc[0]; const float m1 = min * sc[1];
+            const float d2 = d * sc[2]; const float m2 = min * sc[3];
+            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
+            ql += 32; is += 2;
+            u1 <<= 2; u2 <<= 2;
+        }
+    }
+
+}
+
+static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        device const uint8_t * ql = x[i].ql;
+        device const uint8_t * qh = x[i].qh;
+        device const int8_t  * sc = x[i].scales;
+
+        const float d = x[i].d;
+
+        for (int n = 0; n < QK_K; n += 128) {
+            for (int l = 0; l < 32; ++l) {
+                int is = l/16;
+                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+                y[l +  0] = d * sc[is + 0] * q1;
+                y[l + 32] = d * sc[is + 2] * q2;
+                y[l + 64] = d * sc[is + 4] * q3;
+                y[l + 96] = d * sc[is + 6] * q4;
+            }
+            y  += 128;
+            ql += 64;
+            qh += 32;
+            sc += 8;
+        }
+    }
+}
+
+kernel void kernel_get_rows_q2_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q2_k(
+            (device const block_q2_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_get_rows_q3_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q3_k(
+            (device const block_q3_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_get_rows_q4_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q4_k(
+            (device const block_q4_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_get_rows_q5_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q5_k(
+            (device const block_q5_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_get_rows_q6_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q6_k(
+            (device const block_q6_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+//====================================== dot products =========================
+
+kernel void kernel_mul_mat_q2_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y;    // 0...16
+    const int il  = tid/4;      // 0...3
+    const int ir  = tid%4;      // 0...3
+    const int ip  = il/2;       // 0 or 1
+    const int shift1 = 4*(il%2);// 0 or 4
+    const int shift2 = shift1+2;// 2 or 6
+    const int n   = 8;
+    const int is  = 4*il + (n*ir)/16;
+
+    const int y_offset = 64*il + n*ir;
+    const int q_offset = 32*ip + n*ir;
+
+    sum[ith] = 0.0f;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * q = x[i].qs + q_offset;
+        device const uint8_t * scales = x[i].scales + is;
+
+        uint8_t d1 = scales[0] & 0xF;
+        uint8_t d2 = scales[2] & 0xF;
+        uint8_t m1 = scales[0] >>  4;
+        uint8_t m2 = scales[2] >>  4;
+
+        device const float   * y = yy + i*QK_K + y_offset;
+
+        //float4 s = {0.f, 0.f, 0.f, 0.f};
+        float2 s = {0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
+            s[1] += y[l+32] * ((q[l] >> shift2) & 3);
+            smin += y[l+ 0] * m1 + y[l+32] * m2;
+        }
+
+        const float dall = (float)x[i].d;
+        const float dmin = (float)x[i].dmin;
+
+        sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
+
+    }
+    sum[ith] = sumf;
+
+    //int mask1 = (ith%4 == 0);
+    //int mask2 = (ith%16 == 0);
+
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //if (ith == 0) {
+    //    for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+    //    dst[r1*ne0 + r0] = sum[0];
+    //}
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    // This version is slightly faster than the commented out one below,
+    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+}
+
+kernel void kernel_mul_mat_q3_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    const uint8_t m3 = 3;
+    const int8_t  m4 = 4;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y;        // expecting 16
+    const int ip  = tid/8;          // 0 or 1
+    const int il  = tid/2 - 4*ip;   // 0...3
+    const int ir  = tid%2;
+    const int n   = 8;
+    const int l0  = n*ir;
+
+    const uint8_t m = 1 << (4*ip + il);
+
+    const int shift = 2*il;
+
+    const uint16_t s_shift1 = 4*ip;
+    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
+    const int ik = 4 + (il%2);
+
+    const int q_offset = 32*ip + l0;
+    const int y_offset = 128*ip + 32*il + l0;
+
+    //float sumf = 0;
+    float sumf1 = 0, sumf2 = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs + q_offset;
+        device const uint8_t * h = x[i].hmask + l0;
+        device const float   * y = yy + i * QK_K + y_offset;
+
+        device const uint16_t * a = (device const uint16_t *)x[i].scales;
+        const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+
+        float s = 0;
+        for (int l = 0; l < n; ++l) {
+            s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
+        }
+        float d = d_all * s;
+        sumf1 += d * scales[0];
+        sumf2 += d;
+        //sumf += d_all * s * (scales[0] - 32);
+
+        s = 0;
+        for (int l = 0; l < n; ++l) {
+            s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
+        }
+        d = d_all * s;
+        sumf1 += d * scales[1];
+        sumf2 += d;
+        //sumf += d_all * s * (scales[1] - 32);
+
+    }
+
+    //sum[ith] = sumf;
+    sum[ith] = sumf1 - 32.f*sumf2;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+}
+
+kernel void kernel_mul_mat_q4_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y;   // 0...16
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 4;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    sum[ith] = 0.0f;
+
+    uchar2 sc1, sc2, sc3, sc4;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * q1 = (x + i)->qs + q_offset;
+        device const uint8_t * q2 = q1 + 64;
+        device const float   * y1 = yy + i*QK_K + y_offset;
+        device const float   * y2 = y1 + 128;
+
+        const float dall = (float)((x + i)->d);
+        const float dmin = (float)((x + i)->dmin);
+
+        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
+        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
+        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
+        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
+        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+
+            s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
+            s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
+            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
+        }
+        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+
+    }
+
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    // This version is slightly faster than the commented out one below,
+    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+    //// accumulate the sum from all threads in the threadgroup
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (uint i = nth/2; i > 0; i /= 2) {
+    //    if (ith < i) {
+    //        sum[ith] += sum[ith + i];
+    //    }
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+
+    //if (ith == 0) {
+    //    dst[r1*ne0 + r0] = sum[0];
+    //}
+}
+
+kernel void kernel_mul_mat_q5_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y;   // 0...16
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 4;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1 = 1u << (2*im);
+    const uint8_t hm2 = hm1 << 1;
+    const uint8_t hm3 = hm1 << 4;
+    const uint8_t hm4 = hm2 << 4;
+
+    uchar2 sc1, sc2, sc3, sc4;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * q1 = (x + i)->qs + q_offset;
+        device const uint8_t * q2 = q1 + 64;
+        device const uint8_t * qh = (x + i)->qh + l0;
+        device const float   * y1 = yy + i*QK_K + y_offset;
+        device const float   * y2 = y1 + 128;
+
+        const float dall = (float)((x + i)->d);
+        const float dmin = (float)((x + i)->dmin);
+
+        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
+        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
+        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
+        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
+        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+
+            s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
+            s[1] += y1[l+32] * ((q1[l] >>  4) + (qh[l] & hm2 ? 16 : 0));
+            s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
+            s[3] += y2[l+32] * ((q2[l] >>  4) + (qh[l] & hm4 ? 16 : 0));
+            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
+        }
+        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+
+    }
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+}
+
+kernel void kernel_mul_mat_q6_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+
+    const uint8_t kmask1 = 0x03;
+    const uint8_t kmask2 = 0x0C;
+    const uint8_t kmask3 = 0x30;
+    const uint8_t kmask4 = 0xC0;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
+    const int iqs  = 16 * tpitg.y;
+    const int ip   = iqs / 128;         // 0 or 1
+    const int il   = (iqs - 128*ip)/16; // 0...7
+    const int n    = 4;
+    const int l0   = n*il;
+    const int is   = 8*ip + l0/16;
+
+    const int y_offset = 128*ip + l0;
+    const int q_offset_l = 64*ip + l0;
+    const int q_offset_h = 32*ip + l0;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * ql = x[i].ql + q_offset_l;
+        device const uint8_t * qh = x[i].qh + q_offset_h;
+        device const int8_t  * sc = x[i].scales + is;
+
+        device const float * y = yy + i * QK_K + y_offset;
+
+        const float dall = x[i].d;
+
+        float4 sums = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < n; ++l) {
+            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+64] * ((int8_t)((ql[l+ 0]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+            sums[3] += y[l+96] * ((int8_t)((ql[l+32]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+        }
+
+        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+    }
+
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+}
diff --git a/src/ggml-opencl.cpp b/src/ggml-opencl.cpp
new file mode 100644 (file)
index 0000000..95f4cec
--- /dev/null
@@ -0,0 +1,1684 @@
+#include "ggml-opencl.h"
+
+#include <array>
+#include <atomic>
+#include <sstream>
+#include <vector>
+#include <limits>
+
+#define CL_TARGET_OPENCL_VERSION 110
+#include <clblast.h>
+
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+
+#include "ggml.h"
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+#define CL_DMMV_BLOCK_SIZE 32
+
+#define MULTILINE_QUOTE(...) #__VA_ARGS__
+static std::string program_source = MULTILINE_QUOTE(
+
+typedef char int8_t;
+typedef uchar uint8_t;
+typedef int int32_t;
+typedef uint uint32_t;
+
+struct __attribute__ ((packed)) block_q4_0
+{
+    half d;
+    uint8_t qs[QK4_0 / 2];
+};
+
+struct __attribute__ ((packed)) block_q4_1
+{
+    half d;
+    half m;
+    uint8_t qs[QK4_1 / 2];
+};
+
+struct __attribute__ ((packed)) block_q5_0
+{
+    half d;
+    uint32_t qh;
+    uint8_t qs[QK5_0 / 2];
+};
+
+struct __attribute__ ((packed)) block_q5_1
+{
+    half d;
+    half m;
+    uint32_t qh;
+    uint8_t qs[QK5_1 / 2];
+};
+
+struct __attribute__ ((packed)) block_q8_0
+{
+    half d;
+    int8_t qs[QK8_0];
+};
+
+struct __attribute__((packed)) block_q2_K
+{
+    uint8_t scales[16];
+    uint8_t qs[64];
+    half d;
+    half dmin;
+};
+
+struct __attribute__((packed)) block_q3_K
+{
+    uint8_t hmask[32];
+    uint8_t qs[64];
+    uint8_t scales[12];
+    half d;
+};
+
+struct __attribute__((packed)) block_q4_K
+{
+    half d;
+    half dmin;
+    uint8_t scales[12];
+    uint8_t qs[128];
+};
+
+struct __attribute__((packed)) block_q5_K
+{
+    half d;
+    half dmin;
+    uint8_t scales[12];
+    uint8_t qh[32];
+    uint8_t qs[128];
+};
+
+struct __attribute__((packed)) block_q6_K
+{
+    uint8_t ql[128];
+    uint8_t qh[64];
+    int8_t scales[16];
+    half d;
+};
+
+__kernel void convert_fp16_to_fp32(__global half* x, __global float* y) {
+    const uint i = get_global_id(0);
+
+    y[i] = vload_half(0, &x[i]);
+}
+
+void dequantize_q4_0(__global const struct block_q4_0* x, const int ib, const int iqs, float* v0, float* v1) {
+    const float d = vload_half(0, &x[ib].d);
+
+    const uint8_t vui = x[ib].qs[iqs];
+
+    const int8_t vi0 = vui & 0xF;
+    const int8_t vi1 = vui >> 4;
+
+    *v0 = (vi0 - 8)*d;
+    *v1 = (vi1 - 8)*d;
+}
+void dequantize_q4_1(__global const struct block_q4_1* x, const int ib, const int iqs, float* v0, float* v1) {
+    const float d = vload_half(0, &x[ib].d);
+    const float m = vload_half(0, &x[ib].m);
+
+    const uint8_t vui = x[ib].qs[iqs];
+
+    const int8_t vi0 = vui & 0xF;
+    const int8_t vi1 = vui >> 4;
+
+    *v0 = vi0*d + m;
+    *v1 = vi1*d + m;
+}
+void dequantize_q5_0(__global const struct block_q5_0* x, const int ib, const int iqs, float* v0, float* v1) {
+    const float d = vload_half(0, &x[ib].d);
+
+    uint32_t qh = x[ib].qh;
+
+    const uint8_t xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const uint8_t xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+
+    const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
+    const int32_t x1 = ((x[ib].qs[iqs] >>  4) | xh_1) - 16;
+
+    *v0 = x0*d;
+    *v1 = x1*d;
+}
+void dequantize_q5_1(__global const struct block_q5_1* x, const int ib, const int iqs, float* v0, float* v1) {
+    const float d = vload_half(0, &x[ib].d);
+    const float m = vload_half(0, &x[ib].m);
+
+    uint32_t qh = x[ib].qh;
+
+    const uint8_t xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const uint8_t xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+
+    const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    const int32_t x1 = ((x[ib].qs[iqs] >>  4) | xh_1);
+
+    *v0 = x0*d + m;
+    *v1 = x1*d + m;
+}
+void dequantize_q8_0(__global const struct block_q8_0* x, const int ib, const int iqs, float* v0, float* v1) {
+    const float d = vload_half(0, &x[ib].d);
+
+    const int8_t vi0 = x[ib].qs[iqs + 0];
+    const int8_t vi1 = x[ib].qs[iqs + 1];
+
+    *v0 = vi0*d;
+    *v1 = vi1*d;
+}
+void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float* v1){
+    *v0 = vload_half(0, &x[ib + 0]);
+    *v1 = vload_half(0, &x[ib + 1]);
+}
+
+inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m)
+{
+    if (j < 4)
+    {
+        *d = q[j] & 63;
+        *m = q[j + 4] & 63;
+    }
+    else
+    {
+        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
+        *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
+    }
+}
+
+__kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __global float *yy)
+{
+    const int i = get_group_id(0);
+    const int tid = get_local_id(0);
+    const int n = tid / 32;
+    const int l = tid - 32 * n;
+    const int is = 8 * n + l / 16;
+
+    const uint8_t q = x[i].qs[32 * n + l];
+    __global float *y = yy + i * 256 + 128 * n;
+
+    const float dall = vload_half(0, &x[i].d);
+    const float dmin = vload_half(0, &x[i].dmin);
+
+    y[l + 0] = dall * (x[i].scales[is + 0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is + 0] >> 4);
+    y[l + 32] = dall * (x[i].scales[is + 2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is + 2] >> 4);
+    y[l + 64] = dall * (x[i].scales[is + 4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is + 4] >> 4);
+    y[l + 96] = dall * (x[i].scales[is + 6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is + 6] >> 4);
+}
+
+__kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __global float *yy)
+{
+    int r = get_local_id(0) / 4;
+    int i = get_group_id(0);
+    int tid = r / 2;
+    int is0 = r % 2;
+    int l0 = 16 * is0 + 4 * (get_local_id(0) % 4);
+    int n = tid / 4;
+    int j = tid - 4 * n;
+
+    uint8_t m = 1 << (4 * n + j);
+    int is = 8 * n + 2 * j + is0;
+    int shift = 2 * j;
+
+    int8_t us = is < 4 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 8] >> 0) & 3) << 4)
+              : is < 8 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 4] >> 2) & 3) << 4)
+              : is < 12  ? (x[i].scales[is - 8] >> 4) | (((x[i].scales[is + 0] >> 4) & 3) << 4)
+              : (x[i].scales[is - 8] >> 4) | (((x[i].scales[is - 4] >> 6) & 3) << 4);
+    float d_all = vload_half(0, &x[i].d);
+    float dl = d_all * (us - 32);
+
+    __global float *y = yy + i * 256 + 128 * n + 32 * j;
+    const __global uint8_t *q = x[i].qs + 32 * n;
+    const __global uint8_t *hm = x[i].hmask;
+
+    for (int l = l0; l < l0 + 4; ++l)
+        y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+}
+
+__kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __global float *yy)
+{
+    const int i = get_group_id(0);
+    const int tid = get_local_id(0);
+    const int il = tid / 8;
+    const int ir = tid % 8;
+    const int is = 2 * il;
+    const int n = 4;
+
+    __global float *y = yy + i * 256 + 64 * il + n * ir;
+
+    const float dall = vload_half(0, &x[i].d);
+    const float dmin = vload_half(0, &x[i].dmin);
+
+    __global const uint8_t *q = x[i].qs + 32 * il + n * ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+    float d1 = dall * sc;
+    float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+    float d2 = dall * sc;
+    float m2 = dmin * m;
+    for (int l = 0; l < n; ++l)
+    {
+        y[l + 0] = d1 * (q[l] & 0xF) - m1;
+        y[l + 32] = d2 * (q[l] >> 4) - m2;
+    }
+}
+
+__kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __global float *yy)
+{
+    const int i = get_group_id(0);
+    const int tid = get_local_id(0);
+    const int il = tid / 16;
+    const int ir = tid % 16;
+    const int is = 2 * il;
+
+    __global float *y = yy + i * 256 + 64 * il + 2 * ir;
+
+    const float dall = vload_half(0, &x[i].d);
+    const float dmin = vload_half(0, &x[i].dmin);
+
+    __global const uint8_t *ql = x[i].qs + 32 * il + 2 * ir;
+    __global const uint8_t *qh = x[i].qh + 2 * ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+    const float d1 = dall * sc;
+    const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+    const float d2 = dall * sc;
+    const float m2 = dmin * m;
+
+    uint8_t hm = 1 << (2 * il);
+    y[0] = d1 * ((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0)) - m1;
+    y[1] = d1 * ((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0)) - m1;
+    hm <<= 1;
+    y[32] = d2 * ((ql[0] >> 4) + (qh[0] & hm ? 16 : 0)) - m2;
+    y[33] = d2 * ((ql[1] >> 4) + (qh[1] & hm ? 16 : 0)) - m2;
+}
+
+__kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __global float *yy)
+{
+    const int i = get_group_id(0);
+    const int tid = get_local_id(0);
+    const int ip = tid / 32;
+    const int il = tid - 32 * ip;
+    const int is = 8 * ip + il / 16;
+
+    __global float *y = yy + i * 256 + 128 * ip + il;
+
+    const float d = vload_half(0, &x[i].d);
+
+    __global const uint8_t *ql = x[i].ql + 64 * ip + il;
+    const uint8_t qh = x[i].qh[32 * ip + il];
+    __global const int8_t *sc = x[i].scales + is;
+
+    y[0] = d * sc[0] * ((int8_t)((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
+    y[64] = d * sc[4] * ((int8_t)((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+    y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+}
+
+
+void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+
+    int n = iqs / 128;
+    int r = iqs - 128 * n;
+    int l = r / 8;
+
+    __global const float *y = yy + 128 * n + l;
+    __global const uint8_t *q = x[ib].qs + 32 * n + l;
+    __global const uint8_t *s = x[ib].scales + 8 * n;
+
+    const float dall = vload_half(0, &x[ib].d);
+    const float dmin = vload_half(0, &x[ib].dmin);
+
+    float sum = y[  0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
+              + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
+              + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
+              + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
+              + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
+              + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
+              + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
+              + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
+
+    *result = sum;
+}
+
+void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    int n = iqs/128;
+    int r = iqs - 128*n;
+    int l = r/8;
+
+    __global const float   * y = yy + 128*n + l;
+    __global const uint8_t * q = x[ib].qs + 32*n + l;
+    __global const uint8_t * hm = x[ib].hmask + l;
+    const int8_t * s = (const int8_t *)utmp + 8*n;
+
+    aux[0] = x[ib].scales[0] | x[ib].scales[1] << 8 | x[ib].scales[2] << 16 | x[ib].scales[3] << 24;
+    aux[1] = x[ib].scales[4] | x[ib].scales[5] << 8 | x[ib].scales[6] << 16 | x[ib].scales[7] << 24;
+    aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24;
+
+    utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+    utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+    utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+    utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+    const float dall = vload_half(0, &x[ib].d);
+    const uint8_t m = 1 << (4*n);
+
+    float sum = y[  0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
+              + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
+              + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
+              + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
+              + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
+              + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
+              + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
+              + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
+
+    *result = sum * dall;
+
+}
+
+void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+
+    const int j  = iqs / 64;        // j  is in 0...3
+    const int ir = (iqs - 64*j)/2;  // ir is in 0...28 in steps of 4
+    const int is = 2*j;             // is is in 0...6 in steps of 2
+
+    __global const float   * y = yy + 64*j + ir;
+    __global const uint8_t * q = x[ib].qs + 32*j + ir;
+
+    const float dall = vload_half(0, &x[ib].d);
+    const float dmin = vload_half(0, &x[ib].dmin);
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
+    const float d1 = dall * sc;
+    const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
+    const float d2 = dall * sc;
+    const float m2 = dmin * m;
+
+    float sum = 0;
+    for (int k = 0; k < 4; ++k) {
+        sum += y[k +  0] * (d1 * (q[k] & 0xF) - m1);
+        sum += y[k + 32] * (d2 * (q[k] >>  4) - m2);
+    }
+
+    *result = sum;
+}
+
+void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+
+    const int j  = iqs / 64;
+    const int ir = (iqs - 64*j)/2;
+    const int is = 2*j;
+
+    __global const float   * y  = yy + 64*j + ir;
+    __global const uint8_t * ql = x[ib].qs + 32*j + ir;
+    __global const uint8_t * qh = x[ib].qh + ir;
+
+    const float dall = vload_half(0, &x[ib].d);
+    const float dmin = vload_half(0, &x[ib].dmin);
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
+    const float d1 = dall * sc;
+    const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
+    const float d2 = dall * sc;
+    const float m2 = dmin * m;
+
+    uint8_t hm  = 1 << is;
+    float sum = 0;
+    for (int k = 0; k < 4; ++k) {
+        sum += y[k +  0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
+    }
+    hm <<= 1;
+    for (int k = 0; k < 4; ++k) {
+        sum += y[k + 32] * (d2 * ((ql[k] >>  4) + (qh[k] & hm ? 16 : 0)) - m2);
+    }
+    *result = sum;
+
+}
+
+void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+
+
+    const int ip = iqs / 128;        // 0 or 1
+    const int il = (iqs - 128*ip)/8; // 0...15
+    const int is = 8*ip;
+
+    __global const float * y = yy + 128*ip + il;
+
+    const float d = vload_half(0, &x[ib].d);
+
+    __global const uint8_t * ql = x[ib].ql + 64*ip + il;
+    __global const uint8_t * qh = x[ib].qh + 32*ip + il;
+    __global const int8_t  * sc = x[ib].scales + is;
+
+    *result = y[  0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
+           + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
+           + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
+           + y[ 96] * d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
+           + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
+           + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
+           + y[ 80] * d * sc[5] * ((int8_t)((ql[16]  >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
+           + y[112] * d * sc[7] * ((int8_t)((ql[48]  >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
+
+}
+
+);
+
+
+std::string dequant_template = MULTILINE_QUOTE(
+__kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
+    const int i = get_group_id(0)*get_local_size(0) + get_local_id(0)*2;
+
+    if (i >= get_global_size(0)) {
+        return;
+    }
+
+    const uint qk = QUANT_K;
+    const uint qr = QUANT_R;
+
+    const int ib = i/qk; // block index
+    const int iqs = (i%qk)/qr; // quant index
+    const int iybs = i - i%qk; // y block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    float v0, v1;
+    DEQUANT_FUNC(x, ib, iqs, &v0, &v1);
+    y[iybs + iqs + 0] = v0;
+    y[iybs + iqs + y_offset] = v1;
+}
+);
+
+std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE(
+__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
+    const int block_size = get_local_size(0);
+    const int row = get_group_id(0);
+    const int tid = get_local_id(0);
+
+    const uint qk = QUANT_K;
+    const uint qr = QUANT_R;
+
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    tmp[tid] = 0;
+
+    for (int i = 0; i < ncols/block_size; i += 2) {
+        const int col = i*block_size + 2*tid;
+        const int ib = (row*ncols + col)/qk; // block index
+        const int iqs = (col%qk)/qr; // quant index
+        const int iybs = col - col%qk; // y block start index
+
+        // dequantize
+        float v0, v1;
+        DEQUANT_FUNC(x, ib, iqs, &v0, &v1);
+
+        // matrix multiplication
+        tmp[tid] += v0 * y[iybs + iqs + 0];
+        tmp[tid] += v1 * y[iybs + iqs + y_offset];
+    }
+
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=block_size/2; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}
+);
+
+std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
+__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
+    const int block_size = get_local_size(0);
+    const int row = get_group_id(0);
+    const int tid = get_local_id(0);
+
+    const int iter_stride = 256;
+    const int vals_per_iter = iter_stride / block_size;
+    const int num_blocks_per_row = ncols / 256;
+    const int ib0 = row*num_blocks_per_row;
+
+    tmp[tid] = 0;
+
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = ib0 + col/256; // x block index
+        const int iqs = col%256; // x quant index
+        const int iybs = col - col%256; // y block start index
+
+        // dequantize
+        float v;
+        DOT_KERNEL(x, ib, iqs, y + iybs, &v);
+        tmp[tid] += v;
+    }
+
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=block_size/2; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}
+);
+
+std::string mul_template = MULTILINE_QUOTE(
+__kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {
+    const int i = get_group_id(0)*get_local_size(0) + get_local_id(0);
+
+    if (i >= get_global_size(0)) {
+        return;
+    }
+
+    dst[dst_offset + i] = x[x_offset + i] * y[y_offset + i%ky];
+}
+);
+
+#define CL_CHECK(err)                                               \
+    do {                                                            \
+        cl_int err_ = (err);                                        \
+        if (err_ != CL_SUCCESS) {                                   \
+            fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n",  \
+                #err, err_, __FILE__, __LINE__);                    \
+            exit(1);                                                \
+        }                                                           \
+    } while (0)
+
+#define CLBLAST_CHECK(err)                                          \
+    do {                                                            \
+        CLBlastStatusCode err_ = (err);                             \
+        if (err_ != CLBlastSuccess) {                               \
+            fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n",  \
+                #err, err_, __FILE__, __LINE__);                    \
+            exit(1);                                                \
+        }                                                           \
+    } while (0)
+
+std::array<std::string, 5> dequant_str_keys = {
+    "KERNEL_NAME", "X_TYPE", "QUANT_K", "QUANT_R", "DEQUANT_FUNC"
+};
+
+std::array<std::string, 30> dequant_str_values = {
+    "dequantize_row_q4_0", "struct block_q4_0", "QK4_0", "QR4_0", "dequantize_q4_0",
+    "dequantize_row_q4_1", "struct block_q4_1", "QK4_1", "QR4_1", "dequantize_q4_1",
+    "dequantize_row_q5_0", "struct block_q5_0", "QK5_0", "QR5_0", "dequantize_q5_0",
+    "dequantize_row_q5_1", "struct block_q5_1", "QK5_1", "QR5_1", "dequantize_q5_1",
+    "dequantize_row_q8_0", "struct block_q8_0", "QK8_0", "QR8_0", "dequantize_q8_0",
+    "convert_row_f16", "half", "1", "1", "convert_f16"
+};
+
+std::array<std::string, 30> dequant_mul_mat_vec_str_values = {
+    "dequantize_mul_mat_vec_q4_0", "struct block_q4_0", "QK4_0", "QR4_0", "dequantize_q4_0",
+    "dequantize_mul_mat_vec_q4_1", "struct block_q4_1", "QK4_1", "QR4_1", "dequantize_q4_1",
+    "dequantize_mul_mat_vec_q5_0", "struct block_q5_0", "QK5_0", "QR5_0", "dequantize_q5_0",
+    "dequantize_mul_mat_vec_q5_1", "struct block_q5_1", "QK5_1", "QR5_1", "dequantize_q5_1",
+    "dequantize_mul_mat_vec_q8_0", "struct block_q8_0", "QK8_0", "QR8_0", "dequantize_q8_0",
+    "convert_mul_mat_vec_f16", "half", "1", "1", "convert_f16"
+};
+
+std::array<std::string, 2> mul_str_keys = {
+    "KERNEL_NAME", "TYPE"
+};
+std::array<std::string, 2> mul_str_values = {
+    "mul_f32", "float"
+};
+
+std::array<std::string, 3> dmmv_k_str_keys = {
+    "KERNEL_NAME", "X_TYPE", "DOT_KERNEL"
+};
+
+std::array<std::string, 15> dmmv_k_str_values = {
+    "dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
+    "dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
+    "dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
+    "dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
+    "dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K",
+};
+
+std::string& replace(std::string& s, const std::string& from, const std::string& to) {
+    size_t pos = 0;
+    while ((pos = s.find(from, pos)) != std::string::npos) {
+         s.replace(pos, from.length(), to);
+         pos += to.length();
+    }
+    return s;
+}
+
+std::string generate_kernels() {
+    std::stringstream src;
+    src << program_source << '\n';
+    for (size_t i = 0; i < dequant_str_values.size(); i += dequant_str_keys.size()) {
+        std::string dequant_kernel = dequant_template;
+        std::string dmmv_kernel = dequant_mul_mat_vec_template;
+        for (size_t j = 0; j < dequant_str_keys.size(); j++) {
+            replace(dequant_kernel, dequant_str_keys[j], dequant_str_values[i + j]);
+            replace(dmmv_kernel, dequant_str_keys[j], dequant_mul_mat_vec_str_values[i + j]);
+        }
+        src << dequant_kernel << '\n';
+        src << dmmv_kernel << '\n';
+    }
+    for (size_t i = 0; i < mul_str_values.size(); i += mul_str_keys.size()) {
+        std::string mul_kernel = mul_template;
+        for (size_t j = 0; j < mul_str_keys.size(); j++) {
+            replace(mul_kernel, mul_str_keys[j], mul_str_values[i + j]);
+        }
+        src << mul_kernel << '\n';
+    }
+    for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) {
+        std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template;
+        for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) {
+            replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]);
+        }
+        src << dmmv_k_kernel << '\n';
+    }
+
+    return src.str();
+}
+
+static cl_platform_id platform;
+static cl_device_id device;
+static cl_context context;
+static cl_command_queue queue;
+static cl_program program;
+static cl_kernel convert_row_f16_cl;
+static cl_kernel dequantize_row_q4_0_cl, dequantize_row_q4_1_cl, dequantize_row_q5_0_cl, dequantize_row_q5_1_cl, dequantize_row_q8_0_cl;
+static cl_kernel dequantize_mul_mat_vec_q4_0_cl, dequantize_mul_mat_vec_q4_1_cl, dequantize_mul_mat_vec_q5_0_cl, dequantize_mul_mat_vec_q5_1_cl, dequantize_mul_mat_vec_q8_0_cl, convert_mul_mat_vec_f16_cl;
+static cl_kernel dequantize_block_q2_k_cl, dequantize_block_q3_k_cl, dequantize_block_q4_k_cl, dequantize_block_q5_k_cl, dequantize_block_q6_k_cl;
+static cl_kernel dequantize_mul_mat_vec_q2_K_cl, dequantize_mul_mat_vec_q3_K_cl, dequantize_mul_mat_vec_q4_K_cl, dequantize_mul_mat_vec_q5_K_cl, dequantize_mul_mat_vec_q6_K_cl;
+static cl_kernel mul_f32_cl;
+static bool fp16_support;
+
+static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) {
+    cl_program p;
+    char *program_log;
+    size_t program_size;
+    size_t log_size;
+    int err;
+
+    program_size = strlen(program_buffer);
+
+    p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err);
+    if(err < 0) {
+        fprintf(stderr, "OpenCL error creating program");
+        exit(1);
+    }
+
+    const char* compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
+                               "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1";
+
+    err = clBuildProgram(p, 0, NULL, compile_opts, NULL, NULL);
+    if(err < 0) {
+
+        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
+        program_log = (char*) malloc(log_size + 1);
+        program_log[log_size] = '\0';
+        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL);
+        fprintf(stderr, "ggml_opencl: kernel compile error:\n\n%s\n", program_log);
+        free(program_log);
+        exit(1);
+    }
+
+    return p;
+}
+
+void ggml_cl_init(void) {
+    cl_int err;
+
+    struct cl_device;
+    struct cl_platform {
+        cl_platform_id id;
+        unsigned number;
+        char name[128];
+        char vendor[128];
+        struct cl_device * devices;
+        unsigned n_devices;
+        struct cl_device * default_device;
+    };
+
+    struct cl_device {
+        struct cl_platform * platform;
+        cl_device_id id;
+        unsigned number;
+        cl_device_type type;
+        char name[128];
+    };
+
+    enum { NPLAT = 16, NDEV = 16 };
+
+    struct cl_platform platforms[NPLAT];
+    unsigned n_platforms = 0;
+    struct cl_device devices[NDEV];
+    unsigned n_devices = 0;
+    struct cl_device * default_device = NULL;
+
+    platform = NULL;
+    device = NULL;
+
+    cl_platform_id platform_ids[NPLAT];
+    CL_CHECK(clGetPlatformIDs(NPLAT, platform_ids, &n_platforms));
+
+    for (unsigned i = 0; i < n_platforms; i++) {
+        struct cl_platform * p = &platforms[i];
+        p->number = i;
+        p->id = platform_ids[i];
+        CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL));
+        CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL));
+
+        cl_device_id device_ids[NDEV];
+        cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices);
+        if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) {
+            p->n_devices = 0;
+        } else {
+            CL_CHECK(clGetDeviceIDsError);
+        }
+        p->devices = p->n_devices > 0 ? &devices[n_devices] : NULL;
+        p->default_device = NULL;
+
+        for (unsigned j = 0; j < p->n_devices; j++) {
+            struct cl_device * d = &devices[n_devices];
+            d->number = n_devices++;
+            d->id = device_ids[j];
+            d->platform = p;
+            CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL));
+            CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL));
+
+            if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) {
+                p->default_device = d;
+            }
+        }
+
+        if (default_device == NULL && p->default_device != NULL) {
+            default_device = p->default_device;
+        }
+    }
+
+    if (n_devices == 0) {
+        fprintf(stderr, "ggml_opencl: could find any OpenCL devices.\n");
+        exit(1);
+    }
+
+    char * user_platform_string = getenv("GGML_OPENCL_PLATFORM");
+    char * user_device_string = getenv("GGML_OPENCL_DEVICE");
+    int user_platform_number = -1;
+    int user_device_number = -1;
+
+    unsigned n;
+    if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) {
+        user_platform_number = (int)n;
+    }
+    if (user_device_string != NULL && sscanf(user_device_string, " %u", &n) == 1 && n < n_devices) {
+        user_device_number = (int)n;
+    }
+    if (user_platform_number != -1 && user_device_number != -1) {
+        cl_platform* platform = &platforms[user_platform_number];
+        if ((unsigned)user_device_number >= platform->n_devices) {
+            fprintf(stderr, "ggml_opencl: invalid device number %d\n", user_device_number);
+            exit(1);
+        }
+        default_device = &platform->devices[user_device_number];
+    } else {
+
+        struct cl_device * selected_devices = devices;
+        unsigned n_selected_devices = n_devices;
+
+        if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) {
+            for (unsigned i = 0; i < n_platforms; i++) {
+                struct cl_platform * p = &platforms[i];
+                if (strstr(p->name, user_platform_string) != NULL ||
+                    strstr(p->vendor, user_platform_string) != NULL) {
+                    user_platform_number = (int)i;
+                    break;
+                }
+            }
+            if (user_platform_number == -1) {
+                fprintf(stderr, "ggml_opencl: no platform matching '%s' was found.\n", user_platform_string);
+                exit(1);
+            }
+        }
+        if (user_platform_number != -1) {
+            struct cl_platform * p = &platforms[user_platform_number];
+            selected_devices = p->devices;
+            n_selected_devices = p->n_devices;
+            default_device = p->default_device;
+            if (n_selected_devices == 0) {
+                fprintf(stderr, "ggml_opencl: selected platform '%s' does not have any devices.\n", p->name);
+                exit(1);
+            }
+        }
+
+        if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) {
+            for (unsigned i = 0; i < n_selected_devices; i++) {
+                struct cl_device * d = &selected_devices[i];
+                if (strstr(d->name, user_device_string) != NULL) {
+                    user_device_number = d->number;
+                    break;
+                }
+            }
+            if (user_device_number == -1) {
+                fprintf(stderr, "ggml_opencl: no device matching '%s' was found.\n", user_device_string);
+                exit(1);
+            }
+        }
+        if (user_device_number != -1) {
+            selected_devices = &devices[user_device_number];
+            n_selected_devices = 1;
+            default_device = &selected_devices[0];
+        }
+
+        GGML_ASSERT(n_selected_devices > 0);
+
+        if (default_device == NULL) {
+            default_device = &selected_devices[0];
+        }
+    }
+
+    fprintf(stderr, "ggml_opencl: selecting platform: '%s'\n", default_device->platform->name);
+    fprintf(stderr, "ggml_opencl: selecting device: '%s'\n", default_device->name);
+    if (default_device->type != CL_DEVICE_TYPE_GPU) {
+        fprintf(stderr, "ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name);
+    }
+
+    platform = default_device->platform->id;
+    device = default_device->id;
+
+    size_t ext_str_size;
+    clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size);
+    char *ext_buffer = (char *)alloca(ext_str_size + 1);
+    clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL);
+    ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated
+    // Check if ext_buffer contains cl_khr_fp16
+    fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL;
+    fprintf(stderr, "ggml_opencl: device FP16 support: %s\n", fp16_support ? "true" : "false");
+
+    cl_context_properties properties[] = {
+        (intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)platform, 0
+    };
+
+    CL_CHECK((context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err));
+
+    CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err),
+        (err != CL_INVALID_QUEUE_PROPERTIES && err != CL_INVALID_VALUE ? err :
+        (queue = clCreateCommandQueue(context, device, 0, &err), err)
+    )));
+
+    const std::string kernel_src = generate_kernels();
+
+    program = build_program_from_source(context, device, kernel_src.c_str());
+
+    // FP16 to FP32 kernel
+    CL_CHECK((convert_row_f16_cl = clCreateKernel(program, "convert_row_f16", &err), err));
+
+    // Dequantize kernels
+    CL_CHECK((dequantize_row_q4_0_cl = clCreateKernel(program, "dequantize_row_q4_0", &err), err));
+    CL_CHECK((dequantize_row_q4_1_cl = clCreateKernel(program, "dequantize_row_q4_1", &err), err));
+    CL_CHECK((dequantize_row_q5_0_cl = clCreateKernel(program, "dequantize_row_q5_0", &err), err));
+    CL_CHECK((dequantize_row_q5_1_cl = clCreateKernel(program, "dequantize_row_q5_1", &err), err));
+    CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
+    CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
+    CL_CHECK((dequantize_block_q2_k_cl = clCreateKernel(program, "dequantize_block_q2_K", &err), err));
+    CL_CHECK((dequantize_block_q3_k_cl = clCreateKernel(program, "dequantize_block_q3_K", &err), err));
+    CL_CHECK((dequantize_block_q4_k_cl = clCreateKernel(program, "dequantize_block_q4_K", &err), err));
+    CL_CHECK((dequantize_block_q5_k_cl = clCreateKernel(program, "dequantize_block_q5_K", &err), err));
+    CL_CHECK((dequantize_block_q6_k_cl = clCreateKernel(program, "dequantize_block_q6_K", &err), err));
+
+    // dequant mul mat kernel
+    CL_CHECK((dequantize_mul_mat_vec_q4_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_0", &err), err));
+    CL_CHECK((dequantize_mul_mat_vec_q4_1_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_1", &err), err));
+    CL_CHECK((dequantize_mul_mat_vec_q5_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_0", &err), err));
+    CL_CHECK((dequantize_mul_mat_vec_q5_1_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_1", &err), err));
+    CL_CHECK((dequantize_mul_mat_vec_q8_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q8_0", &err), err));
+    CL_CHECK((convert_mul_mat_vec_f16_cl = clCreateKernel(program, "convert_mul_mat_vec_f16", &err), err));
+    CL_CHECK((dequantize_mul_mat_vec_q2_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q2_K", &err), err));
+    CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K", &err), err));
+    CL_CHECK((dequantize_mul_mat_vec_q4_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_K", &err), err));
+    CL_CHECK((dequantize_mul_mat_vec_q5_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_K", &err), err));
+    CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K", &err), err));
+
+    // mul kernel
+    CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err));
+}
+
+static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return &dequantize_row_q4_0_cl;
+        case GGML_TYPE_Q4_1:
+            return &dequantize_row_q4_1_cl;
+        case GGML_TYPE_Q5_0:
+            return &dequantize_row_q5_0_cl;
+        case GGML_TYPE_Q5_1:
+            return &dequantize_row_q5_1_cl;
+        case GGML_TYPE_Q8_0:
+            return &dequantize_row_q8_0_cl;
+        case GGML_TYPE_Q2_K:
+            return &dequantize_block_q2_k_cl;
+        case GGML_TYPE_Q3_K:
+            return &dequantize_block_q3_k_cl;
+        case GGML_TYPE_Q4_K:
+            return &dequantize_block_q4_k_cl;
+        case GGML_TYPE_Q5_K:
+            return &dequantize_block_q5_k_cl;
+        case GGML_TYPE_Q6_K:
+            return &dequantize_block_q6_k_cl;
+        case GGML_TYPE_F16:
+            return &convert_row_f16_cl;
+        default:
+            return nullptr;
+    }
+}
+
+static size_t ggml_cl_global_denom(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+            return 1;
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+            return 4;
+        case GGML_TYPE_Q4_K:
+            return 8;
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+            return 4;
+        case GGML_TYPE_F16:
+        default:
+            return 1;
+    }
+}
+
+static size_t ggml_cl_local_size(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+            return 0;
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+            return 64;
+        case GGML_TYPE_Q4_K:
+            return 32;
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+            return 64;
+        case GGML_TYPE_F16:
+        default:
+            return 0;
+    }
+}
+
+static cl_kernel* ggml_get_dequantize_mul_mat_vec_cl(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return &dequantize_mul_mat_vec_q4_0_cl;
+        case GGML_TYPE_Q4_1:
+            return &dequantize_mul_mat_vec_q4_1_cl;
+        case GGML_TYPE_Q5_0:
+            return &dequantize_mul_mat_vec_q5_0_cl;
+        case GGML_TYPE_Q5_1:
+            return &dequantize_mul_mat_vec_q5_1_cl;
+        case GGML_TYPE_Q8_0:
+            return &dequantize_mul_mat_vec_q8_0_cl;
+        case GGML_TYPE_F16:
+            return &convert_mul_mat_vec_f16_cl;
+        case GGML_TYPE_Q2_K:
+            return &dequantize_mul_mat_vec_q2_K_cl;
+        case GGML_TYPE_Q3_K:
+            return &dequantize_mul_mat_vec_q3_K_cl;
+        case GGML_TYPE_Q4_K:
+            return &dequantize_mul_mat_vec_q4_K_cl;
+        case GGML_TYPE_Q5_K:
+            return &dequantize_mul_mat_vec_q5_K_cl;
+        case GGML_TYPE_Q6_K:
+            return &dequantize_mul_mat_vec_q6_K_cl;
+        default:
+            return nullptr;
+    }
+}
+
+// buffer pool for cl
+#define MAX_CL_BUFFERS 256
+
+struct scoped_spin_lock {
+    std::atomic_flag& lock;
+    scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
+        while (lock.test_and_set(std::memory_order_acquire)) {
+            ; // spin
+        }
+    }
+    ~scoped_spin_lock() {
+        lock.clear(std::memory_order_release);
+    }
+    scoped_spin_lock(const scoped_spin_lock&) = delete;
+    scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
+};
+
+struct cl_buffer {
+    cl_mem mem;
+    size_t size = 0;
+};
+
+static cl_buffer g_cl_buffer_pool[MAX_CL_BUFFERS];
+static std::atomic_flag g_cl_pool_lock = ATOMIC_FLAG_INIT;
+
+static cl_mem ggml_cl_pool_malloc(size_t size, size_t * actual_size) {
+    scoped_spin_lock lock(g_cl_pool_lock);
+    cl_int err;
+
+    int best_i = -1;
+    size_t best_size = std::numeric_limits<size_t>::max(); //smallest unused buffer that fits our needs
+    int worst_i = -1;
+    size_t worst_size = 0; //largest unused buffer seen so far
+    for (int i = 0; i < MAX_CL_BUFFERS; ++i) {
+        cl_buffer &b = g_cl_buffer_pool[i];
+        if (b.size > 0 && b.size >= size && b.size < best_size)
+        {
+            best_i = i;
+            best_size = b.size;
+        }
+        if (b.size > 0 && b.size > worst_size)
+        {
+            worst_i = i;
+            worst_size = b.size;
+        }
+    }
+    if(best_i!=-1) //found the smallest buffer that fits our needs
+    {
+        cl_buffer& b = g_cl_buffer_pool[best_i];
+        cl_mem mem = b.mem;
+        *actual_size = b.size;
+        b.size = 0;
+        return mem;
+    }
+    if(worst_i!=-1) //no buffer that fits our needs, resize largest one to save memory
+    {
+         cl_buffer& b = g_cl_buffer_pool[worst_i];
+         cl_mem mem = b.mem;
+         b.size = 0;
+         clReleaseMemObject(mem);
+    }
+    cl_mem mem;
+    CL_CHECK((mem = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &err), err));
+    *actual_size = size;
+    return mem;
+}
+
+static void ggml_cl_pool_free(cl_mem mem, size_t size) {
+    scoped_spin_lock lock(g_cl_pool_lock);
+
+    for (int i = 0; i < MAX_CL_BUFFERS; ++i) {
+        cl_buffer& b = g_cl_buffer_pool[i];
+        if (b.size == 0) {
+            b.mem = mem;
+            b.size = size;
+            return;
+        }
+    }
+    fprintf(stderr, "WARNING: cl buffer pool full, increase MAX_CL_BUFFERS\n");
+    clReleaseMemObject(mem);
+}
+
+void ggml_cl_free_data(const struct ggml_tensor* tensor) {
+    if (tensor->backend != GGML_BACKEND_GPU) {
+        return;
+    }
+
+    cl_mem mem = (cl_mem)tensor->data;
+    clReleaseMemObject(mem);
+}
+
+static cl_int ggml_cl_h2d_tensor_2d(cl_command_queue queue, cl_mem dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cl_event* ev) {
+    cl_int err;
+    const uint64_t ne0 = src->ne[0];
+    const uint64_t ne1 = src->ne[1];
+    const uint64_t nb0 = src->nb[0];
+    const uint64_t nb1 = src->nb[1];
+    const uint64_t nb2 = src->nb[2];
+    const uint64_t nb3 = src->nb[3];
+    const enum ggml_type type = src->type;
+    const size_t ts = ggml_type_size(type);
+    const size_t bs = ggml_blck_size(type);
+
+    const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
+    if (nb0 == ts && nb1 == ts*ne0/bs) {
+        err = clEnqueueWriteBuffer(queue, dst, CL_FALSE, offset, ne1*nb1, x, 0, NULL, ev);
+        return err;
+    }
+    if (nb0 == ts) {
+        const size_t buffer_origin[3] = { offset, 0, 0 };
+        const size_t host_origin[3] = { 0, 0, 0 };
+        const size_t region[3] = { ts*ne0/bs, ne1, 1 };
+        err = clEnqueueWriteBufferRect(queue, dst, CL_FALSE, buffer_origin, host_origin, region, ts*ne0/bs, 0, nb1, 0, x, 0, NULL, ev);
+        return err;
+    }
+    for (uint64_t i1 = 0; i1 < ne1; i1++) {
+        // pretend the row is a matrix with cols=1
+        const size_t buffer_origin[3] = { offset, i1, 0 };
+        const size_t host_origin[3] = { 0, 0, 0 };
+        const size_t region[3] = { ts/bs, ne0, 1 };
+        err = clEnqueueWriteBufferRect(queue, dst, CL_FALSE, buffer_origin, host_origin, region, 0, 0, nb0, 0, ((const char *)x) + i1*nb0, 0, NULL, ev);
+        if (err != CL_SUCCESS) {
+            break;
+        }
+    }
+    return err;
+}
+
+static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[2];
+    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    const int64_t nb10 = src1->nb[0];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+    size_t x_size;
+    size_t d_size;
+
+    cl_mem d_X = ggml_cl_pool_malloc(ne0 * sizeof(float), &x_size); // src0
+    cl_mem d_Y = (cl_mem) src1->data; // src1 is already on device, broadcasted.
+    cl_mem d_D = ggml_cl_pool_malloc(ne0 * sizeof(float), &d_size); // dst
+
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            const int i0 = i03*ne02 + i02;
+
+            cl_event ev;
+
+            // copy src0 to device
+            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, i0, src0, i03, i02, &ev));
+
+            if (nb10 == sizeof(float)) {
+                // Contiguous, avoid overhead from queueing many kernel runs
+                const int64_t i13 = i03%ne13;
+                const int64_t i12 = i02%ne12;
+                const int i1 = i13*ne12*ne11 + i12*ne11;
+
+                cl_int x_offset = 0;
+                cl_int y_offset = i1*ne10;
+                cl_int d_offset = 0;
+
+                size_t global = ne00 * ne01;
+                cl_int ky = ne10;
+                CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
+                CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
+                CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
+                CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
+                CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
+                CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
+                CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
+                CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
+            } else {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const int64_t i13 = i03%ne13;
+                    const int64_t i12 = i02%ne12;
+                    const int64_t i11 = i01%ne11;
+                    const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
+
+                    cl_int x_offset = i01*ne00;
+                    cl_int y_offset = i1*ne10;
+                    cl_int d_offset = i01*ne00;
+
+                    // compute
+                    size_t global = ne00;
+                    cl_int ky = ne10;
+                    CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
+                    CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
+                    CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
+                    CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
+                    CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
+                    CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
+                    CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
+                    CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
+                }
+            }
+
+            CL_CHECK(clReleaseEvent(ev));
+            CL_CHECK(clFinish(queue));
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * ne00*ne01, d, 0, NULL, NULL));
+        }
+    }
+    ggml_cl_pool_free(d_X, x_size);
+    ggml_cl_pool_free(d_D, d_size);
+}
+
+void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cl_mul_f32(src0, src1, dst);
+}
+
+static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+    const int x_ne = ne01 * ne00;
+    const int y_ne = ne11 * ne10;
+    const int d_ne = ne11 * ne01;
+
+    size_t x_size;
+    size_t y_size;
+    size_t d_size;
+    cl_mem d_X;
+    if (src0->backend == GGML_BACKEND_GPU) { // NOLINT
+        d_X = (cl_mem) src0->data;
+    } else {
+        d_X = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &x_size);
+    }
+    cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
+    cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            // copy data to device
+            if (src0->backend != GGML_BACKEND_GPU) {
+                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
+            }
+            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));
+
+            CL_CHECK(clFinish(queue));
+
+            // compute
+            cl_event ev_sgemm;
+            clblast::StatusCode status = clblast::Gemm<cl_float>(clblast::Layout::kColMajor,
+                                                       clblast::Transpose::kYes, clblast::Transpose::kNo,
+                                                       ne01, ne11, ne10,
+                                                       alpha,
+                                                       d_X, 0, ne00,
+                                                       d_Y, 0, ne10,
+                                                       beta,
+                                                       d_D, 0, ne01,
+                                                       &queue, &ev_sgemm);
+
+            if (status != clblast::StatusCode::kSuccess) {
+                GGML_ASSERT(false);
+            }
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL));
+        }
+    }
+
+    if (src0->backend != GGML_BACKEND_GPU) {
+        ggml_cl_pool_free(d_X, x_size);
+    }
+    ggml_cl_pool_free(d_Y, y_size);
+    ggml_cl_pool_free(d_D, d_size);
+}
+
+static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
+    GGML_ASSERT(fp16_support);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const ggml_fp16_t alpha = ggml_fp32_to_fp16(1.0f);
+    const ggml_fp16_t beta = ggml_fp32_to_fp16(0.0f);
+    const int x_ne = ne01 * ne00;
+    const int y_ne = ne11 * ne10;
+    const int d_ne = ne11 * ne01;
+
+    size_t x_size;
+    size_t y_size;
+    size_t d_size;
+    cl_mem d_X;
+    if (src0->backend == GGML_BACKEND_GPU) { // NOLINT
+        d_X = (cl_mem) src0->data;
+    } else {
+        d_X = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &x_size);
+    }
+    cl_mem d_Y = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * y_ne, &y_size);
+    cl_mem d_D = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * d_ne, &d_size);
+
+    bool src1_cont_rows = nb10 == sizeof(float);
+    bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            // copy src0 to device
+            if (src0->backend != GGML_BACKEND_GPU) {
+                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
+            }
+
+            // convert src1 to fp16
+            // TODO: use multiple threads
+            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
+            char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
+            if (src1_cont_rows) {
+                if (src1_cont_cols) {
+                    ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
+                }
+                else {
+                    for (int64_t i01 = 0; i01 < ne11; i01++) {
+                        ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
+                    }
+                }
+            }
+            else {
+                for (int64_t i01 = 0; i01 < ne11; i01++) {
+                    for (int64_t i00 = 0; i00 < ne10; i00++) {
+                        // very slow due to no inlining
+                        tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
+                    }
+                }
+            }
+
+            // copy src1 to device
+            CL_CHECK(clEnqueueWriteBuffer(queue, d_Y, false, 0, sizeof(ggml_fp16_t) * y_ne, tmp, 0, NULL, NULL));
+
+            CL_CHECK(clFinish(queue));
+
+            // compute
+            cl_event ev_sgemm;
+            clblast::StatusCode status = clblast::Gemm<cl_half>(clblast::Layout::kColMajor,
+                                                       clblast::Transpose::kYes, clblast::Transpose::kNo,
+                                                       ne01, ne11, ne10,
+                                                       alpha,
+                                                       d_X, 0, ne00,
+                                                       d_Y, 0, ne10,
+                                                       beta,
+                                                       d_D, 0, ne01,
+                                                       &queue, &ev_sgemm);
+
+            if (status != clblast::StatusCode::kSuccess) {
+                GGML_ASSERT(false);
+            }
+
+            // copy dst to host, then convert to float
+            CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL));
+
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+            ggml_fp16_to_fp32_row(tmp, d, d_ne);
+        }
+    }
+
+    if (src0->backend != GGML_BACKEND_GPU) {
+        ggml_cl_pool_free(d_X, x_size);
+    }
+    ggml_cl_pool_free(d_Y, y_size);
+    ggml_cl_pool_free(d_D, d_size);
+}
+
+static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+    const ggml_type type = src0->type;
+    const bool mul_mat_vec = ne11 == 1;
+
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+    const int x_ne = ne01 * ne00;
+    const int y_ne = ne11 * ne10;
+    const int d_ne = ne11 * ne01;
+    const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
+
+    size_t x_size;
+    size_t y_size;
+    size_t d_size;
+    size_t q_size;
+    cl_mem d_X;
+    if (!mul_mat_vec) {
+        d_X = ggml_cl_pool_malloc(sizeof(float) * x_ne, &x_size);
+    }
+    cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
+    cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
+    cl_mem d_Q;
+    if (src0->backend == GGML_BACKEND_CPU) {
+        d_Q = ggml_cl_pool_malloc(q_sz, &q_size);
+    }
+
+    cl_kernel* to_fp32_cl = ggml_get_to_fp32_cl(type);
+    cl_kernel* dmmv = ggml_get_dequantize_mul_mat_vec_cl(type);
+    GGML_ASSERT(to_fp32_cl != nullptr);
+
+    const size_t global_denom = ggml_cl_global_denom(type);
+    const size_t local = ggml_cl_local_size(type);
+
+    size_t ev_idx = 0;
+    std::vector<cl_event> events;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            // copy src0 to device if necessary
+            if (src0->backend == GGML_BACKEND_CPU) {
+                events.emplace_back();
+                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
+            } else if (src0->backend == GGML_BACKEND_GPU) {
+                d_Q = (cl_mem) src0->data;
+            } else {
+                GGML_ASSERT(false);
+            }
+            if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
+                // copy src1 to device
+                events.emplace_back();
+                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, events.data() + ev_idx++));
+
+                // compute
+                const size_t global = ne01 * CL_DMMV_BLOCK_SIZE;
+                const size_t local = CL_DMMV_BLOCK_SIZE;
+                const cl_int ncols = ne00;
+                events.emplace_back();
+                CL_CHECK(clSetKernelArg(*dmmv, 0, sizeof(cl_mem), &d_Q));
+                CL_CHECK(clSetKernelArg(*dmmv, 1, sizeof(float) * local, NULL));
+                CL_CHECK(clSetKernelArg(*dmmv, 2, sizeof(cl_mem), &d_Y));
+                CL_CHECK(clSetKernelArg(*dmmv, 3, sizeof(cl_mem), &d_D));
+                CL_CHECK(clSetKernelArg(*dmmv, 4, sizeof(cl_int), &ncols));
+                CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
+            } else { // general dequantization kernel + CLBlast matrix matrix multiplication
+                // convert src0 to fp32 on device
+                const size_t global = x_ne / global_denom;
+                CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q));
+                CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X));
+                CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
+
+                // copy src1 to device
+                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));
+
+                events.emplace_back();
+
+                // wait for conversion
+                CL_CHECK(clFinish(queue));
+
+                // compute
+                clblast::StatusCode status = clblast::Gemm<cl_float>(clblast::Layout::kColMajor,
+                                                           clblast::Transpose::kYes, clblast::Transpose::kNo,
+                                                           ne01, ne11, ne10,
+                                                           alpha,
+                                                           d_X, 0, ne00,
+                                                           d_Y, 0, ne10,
+                                                           beta,
+                                                           d_D, 0, ne01,
+                                                           &queue, events.data() + ev_idx++);
+
+                if (status != clblast::StatusCode::kSuccess) {
+                    GGML_ASSERT(false);
+                }
+            }
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &events[events.size() - 1], NULL));
+            for (auto *event : events) {
+                clReleaseEvent(event);
+            }
+
+            ev_idx = 0;
+            events.clear();
+        }
+    }
+
+    if (!mul_mat_vec) {
+        ggml_cl_pool_free(d_X, x_size);
+    }
+    ggml_cl_pool_free(d_Y, y_size);
+    ggml_cl_pool_free(d_D, d_size);
+    if (src0->backend == GGML_BACKEND_CPU) {
+        ggml_cl_pool_free(d_Q, q_size);
+    }
+}
+
+
+bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    const int64_t ne10 = src1->ne[0];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+
+    // TODO: find the optimal values for these
+    if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+        src1->type == GGML_TYPE_F32 &&
+        dst->type == GGML_TYPE_F32 &&
+        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU)) {
+        return true;
+    }
+
+    return false;
+}
+
+bool ggml_cl_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
+    // If device doesn't support FP16
+    if (!fp16_support) {
+        return false;
+    }
+
+    size_t src0_sz = ggml_nbytes(src0);
+    size_t src1_sz = ggml_nbytes(src1);
+
+    // mul_mat_q: src0 is converted to fp32 on device
+    size_t mul_mat_q_transfer = src0_sz + src1_sz;
+
+    // mul_mat_f16: src1 is converted to fp16 on cpu
+    size_t mul_mat_f16_transfer = src0_sz + sizeof(ggml_fp16_t) * ggml_nelements(src1);
+
+    // choose the smaller one to transfer to the device
+    // TODO: this is not always the best choice due to the overhead of converting to fp16
+    return mul_mat_f16_transfer < mul_mat_q_transfer;
+}
+
+void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize) {
+    GGML_ASSERT(ggml_cl_can_mul_mat(src0, src1, dst));
+
+    if (src0->type == GGML_TYPE_F32) {
+        ggml_cl_mul_mat_f32(src0, src1, dst);
+    }
+    else if (src0->type == GGML_TYPE_F16) {
+        if (ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
+            ggml_cl_mul_mat_f16(src0, src1, dst, wdata, wsize);
+        }
+        else {
+            ggml_cl_mul_mat_q_f32(src0, src1, dst);
+        }
+    }
+    else if (ggml_is_quantized(src0->type)) {
+        ggml_cl_mul_mat_q_f32(src0, src1, dst);
+    }
+    else {
+        GGML_ASSERT(false);
+    }
+}
+
+size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    if (ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
+        return ggml_nelements(src1) * sizeof(ggml_fp16_t);
+    }
+    return 0;
+}
+
+void ggml_cl_transform_tensor(void * data, ggml_tensor * tensor) {
+    const int64_t ne0 = tensor->ne[0];
+    const int64_t ne1 = tensor->ne[1];
+    const int64_t ne2 = tensor->ne[2];
+    const int64_t ne3 = tensor->ne[3];
+
+    const ggml_type type = tensor->type;
+    const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
+
+    size_t q_size;
+    cl_mem dst = ggml_cl_pool_malloc(q_sz, &q_size);
+
+    tensor->data = data;
+    // copy tensor to device
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            int i = i3*ne2 + i2;
+            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, dst, i*ne0*ne1, tensor, i3, i2, NULL));
+        }
+    }
+
+    CL_CHECK(clFinish(queue));
+
+    tensor->data = dst;
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+}
index 5a1a500930b9aa3bc8912e6d41fa8a5b741e8da7..a92b445c9d7660da6ed5dfcbc4cce9ae7a5b9827 100644 (file)
@@ -8,6 +8,7 @@ extern "C" {
 
 void ggml_cl_init(void);
 
+void   ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 bool   ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 void   ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
@@ -15,7 +16,9 @@ void   ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor
 void * ggml_cl_host_malloc(size_t size);
 void   ggml_cl_host_free(void * ptr);
 
-void ggml_cl_transform_tensor(struct ggml_tensor * tensor);
+void ggml_cl_free_data(const struct ggml_tensor* tensor);
+
+void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
 
 #ifdef  __cplusplus
 }
index f92e7d5d66f695713eadf200ffe937069dd90cbf..14e08f9d61fd73e6e6512020d6b0922a375790cb 100644 (file)
@@ -3,6 +3,10 @@
 
 #include "ggml.h"
 
+#ifdef GGML_USE_K_QUANTS
+#include "k_quants.h"
+#endif
+
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
 #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
 #include <float.h>
 #include <limits.h>
 
+#ifdef GGML_USE_METAL
+#include <unistd.h>
+#endif
+
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
 #ifndef static_assert
 #define static_assert(cond, msg) struct global_scope_noop_trick
 #endif
 
+#if defined(_MSC_VER)
+// disable "possible loss of data" to avoid hundreds of casts
+// we should just be careful :)
+#pragma warning(disable: 4244 4267)
+#endif
+
 #if defined(_WIN32)
 
 #include <windows.h>
@@ -122,7 +136,11 @@ typedef void* thread_ret_t;
 #else
 inline static void* ggml_aligned_malloc(size_t size) {
     void* aligned_memory = NULL;
+#ifdef GGML_USE_METAL
+    int result = posix_memalign(&aligned_memory, getpagesize(), size);
+#else
     int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
+#endif
     if (result != 0) {
         // Handle allocation failure
         return NULL;
@@ -407,21 +425,27 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
 //
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
-static int64_t timer_freq;
+static int64_t timer_freq, timer_start;
 void ggml_time_init(void) {
-    LARGE_INTEGER frequency;
-    QueryPerformanceFrequency(&frequency);
-    timer_freq = frequency.QuadPart;
+    LARGE_INTEGER t;
+    QueryPerformanceFrequency(&t);
+    timer_freq = t.QuadPart;
+
+    // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq
+    // and the uptime is high enough.
+    // We subtract the program start time to reduce the likelihood of that happening.
+    QueryPerformanceCounter(&t);
+    timer_start = t.QuadPart;
 }
 int64_t ggml_time_ms(void) {
     LARGE_INTEGER t;
     QueryPerformanceCounter(&t);
-    return (t.QuadPart * 1000) / timer_freq;
+    return ((t.QuadPart-timer_start) * 1000) / timer_freq;
 }
 int64_t ggml_time_us(void) {
     LARGE_INTEGER t;
     QueryPerformanceCounter(&t);
-    return (t.QuadPart * 1000000) / timer_freq;
+    return ((t.QuadPart-timer_start) * 1000000) / timer_freq;
 }
 #else
 void ggml_time_init(void) {}
@@ -478,6 +502,8 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 // multiply int8_t, add results pairwise twice
 static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
@@ -537,7 +563,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
 static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
 {
     const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
-    const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp);
+    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
     const __m256i lowMask = _mm256_set1_epi8( 0xF );
     return _mm256_and_si256(lowMask, bytes);
 }
@@ -610,7 +636,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
     bytesh = _mm_or_si128(bytesh, bit_mask);
     bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
     bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
-    return _mm256_set_m128i(bytesh, bytesl);
+    return MM256_SET_M128I(bytesh, bytesl);
 }
 
 // Unpack 32 4-bit fields into 32 bytes
@@ -623,7 +649,7 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
     const __m128i lowMask = _mm_set1_epi8(0xF);
     tmpl = _mm_and_si128(lowMask, tmpl);
     tmph = _mm_and_si128(lowMask, tmph);
-    return _mm256_set_m128i(tmph, tmpl);
+    return MM256_SET_M128I(tmph, tmpl);
 }
 
 // add int16_t pairwise and return as float vector
@@ -631,7 +657,7 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
     const __m128i ones = _mm_set1_epi16(1);
     const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
     const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
-    const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
+    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
     return _mm256_cvtepi32_ps(summed_pairs);
 }
 
@@ -1569,6 +1595,48 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .vec_dot_q                = NULL,   // TODO
         .vec_dot_type             = GGML_TYPE_Q8_1,
     },
+#ifdef GGML_USE_K_QUANTS
+    [GGML_TYPE_Q2_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q2_K,
+        .quantize_row_q           = quantize_row_q2_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q2_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q2_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q3_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q3_K,
+        .quantize_row_q           = quantize_row_q3_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q3_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q3_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q4_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q4_K,
+        .quantize_row_q           = quantize_row_q4_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q4_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q5_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q5_K,
+        .quantize_row_q           = quantize_row_q5_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q5_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q6_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q6_K,
+        .quantize_row_q           = quantize_row_q6_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q6_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q6_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+#endif
 };
 
 // For internal test use
@@ -2315,7 +2383,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
 
         // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps(_mm256_set_m128i(i32_0, i32_1));
+        __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
 
         // Apply the scale, and accumulate
         acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
@@ -2791,7 +2859,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         __m128i bxh = _mm256_extractf128_si256(bx, 1);
         bxl = _mm_or_si128(bxl, bxhil);
         bxh = _mm_or_si128(bxh, bxhih);
-        bx = _mm256_set_m128i(bxh, bxl);
+        bx = MM256_SET_M128I(bxh, bxl);
 
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
@@ -3047,7 +3115,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         __m128i bxh = _mm256_extractf128_si256(bx, 1);
         bxl = _mm_or_si128(bxl, bxhil);
         bxh = _mm_or_si128(bxh, bxhih);
-        bx = _mm256_set_m128i(bxh, bxl);
+        bx = MM256_SET_M128I(bxh, bxl);
 
         const __m256 dy = _mm256_set1_ps(y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
@@ -3348,12 +3416,12 @@ inline static float ggml_gelu_quick_f32(float x) {
     return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
 }
 
-inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
-    const uint16_t * i16 = (const uint16_t *) x;
-    for (int i = 0; i < n; ++i) {
-        y[i] = table_gelu_quick_f16[i16[i]];
-    }
-}
+//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+//    const uint16_t * i16 = (const uint16_t *) x;
+//    for (int i = 0; i < n; ++i) {
+//        y[i] = table_gelu_quick_f16[i16[i]];
+//    }
+//}
 
 #ifdef GGML_GELU_QUICK_FP16
 inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
@@ -3498,11 +3566,19 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = QK5_1,
     [GGML_TYPE_Q8_0] = QK8_0,
     [GGML_TYPE_Q8_1] = QK8_1,
+#ifdef GGML_USE_K_QUANTS
+    [GGML_TYPE_Q2_K] = QK_K,
+    [GGML_TYPE_Q3_K] = QK_K,
+    [GGML_TYPE_Q4_K] = QK_K,
+    [GGML_TYPE_Q5_K] = QK_K,
+    [GGML_TYPE_Q6_K] = QK_K,
+    [GGML_TYPE_Q8_K] = QK_K,
+#endif
     [GGML_TYPE_I8]   = 1,
     [GGML_TYPE_I16]  = 1,
     [GGML_TYPE_I32]  = 1,
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = sizeof(float),
@@ -3513,11 +3589,19 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
     [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
     [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
+#ifdef GGML_USE_K_QUANTS
+    [GGML_TYPE_Q2_K] = sizeof(block_q2_K),
+    [GGML_TYPE_Q3_K] = sizeof(block_q3_K),
+    [GGML_TYPE_Q4_K] = sizeof(block_q4_K),
+    [GGML_TYPE_Q5_K] = sizeof(block_q5_K),
+    [GGML_TYPE_Q6_K] = sizeof(block_q6_K),
+    [GGML_TYPE_Q8_K] = sizeof(block_q8_K),
+#endif
     [GGML_TYPE_I8]   = sizeof(int8_t),
     [GGML_TYPE_I16]  = sizeof(int16_t),
     [GGML_TYPE_I32]  = sizeof(int32_t),
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_SIZE is outdated");
 
 
 static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3529,11 +3613,17 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = "q5_1",
     [GGML_TYPE_Q8_0] = "q8_0",
     [GGML_TYPE_Q8_1] = "q8_1",
+    [GGML_TYPE_Q2_K] = "q2_K",
+    [GGML_TYPE_Q3_K] = "q3_K",
+    [GGML_TYPE_Q4_K] = "q4_K",
+    [GGML_TYPE_Q5_K] = "q5_K",
+    [GGML_TYPE_Q6_K] = "q6_K",
+    [GGML_TYPE_Q8_K] = "q8_K",
     [GGML_TYPE_I8]   = "i8",
     [GGML_TYPE_I16]  = "i16",
     [GGML_TYPE_I32]  = "i32",
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_NAME is outdated");
 
 static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = false,
@@ -3544,11 +3634,17 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = true,
     [GGML_TYPE_Q8_0] = true,
     [GGML_TYPE_Q8_1] = true,
+    [GGML_TYPE_Q2_K] = true,
+    [GGML_TYPE_Q3_K] = true,
+    [GGML_TYPE_Q4_K] = true,
+    [GGML_TYPE_Q5_K] = true,
+    [GGML_TYPE_Q6_K] = true,
+    [GGML_TYPE_Q8_K] = true,
     [GGML_TYPE_I8]   = false,
     [GGML_TYPE_I16]  = false,
     [GGML_TYPE_I32]  = false,
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_IS_QUANTIZED is outdated");
 
 static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "NONE",
@@ -3567,6 +3663,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SUM_ROWS",
     "MEAN",
     "REPEAT",
+    "REPEAT_BACK",
     "ABS",
     "SGN",
     "NEG",
@@ -3581,6 +3678,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "RMS_NORM_BACK",
 
     "MUL_MAT",
+    "OUT_PROD",
 
     "SCALE",
     "SET",
@@ -3596,6 +3694,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
     "SOFT_MAX",
+    "SOFT_MAX_BACK",
     "ROPE",
     "ROPE_BACK",
     "ALIBI",
@@ -3606,14 +3705,18 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "FLASH_ATTN",
     "FLASH_FF",
+    "FLASH_ATTN_BACK",
     "WIN_PART",
     "WIN_UNPART",
 
     "MAP_UNARY",
     "MAP_BINARY",
+
+    "CROSS_ENTROPY_LOSS",
+    "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 55, "GGML_OP_COUNT != 55");
+static_assert(GGML_OP_COUNT == 61, "GGML_OP_COUNT != 61");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3632,6 +3735,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "Σx_k",
     "Σx/n",
     "repeat(x)",
+    "repeat_back(x)",
     "abs(x)",
     "sgn(x)",
     "-x",
@@ -3645,6 +3749,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rms_norm(x)",
     "rms_norm_back(x)",
 
+    "X*Y",
     "X*Y",
 
     "x*v",
@@ -3661,6 +3766,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
     "soft_max(x)",
+    "soft_max_back(x)",
     "rope(x)",
     "rope_back(x)",
     "alibi(x)",
@@ -3671,14 +3777,18 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "flash_attn(x)",
     "flash_ff(x)",
+    "flash_attn_back(x)",
     "win_part(x)",
     "win_unpart(x)",
 
     "f(x)",
     "f(x,y)",
+
+    "cross_entropy_loss(x,y)",
+    "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 55, "GGML_OP_COUNT != 55");
+static_assert(GGML_OP_COUNT == 61, "GGML_OP_COUNT != 61");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3692,6 +3802,7 @@ struct ggml_context {
     void * mem_buffer;
     bool   mem_buffer_owned;
     bool   no_alloc;
+    bool   no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
 
     int    n_objects;
 
@@ -3708,26 +3819,6 @@ struct ggml_context_container {
     struct ggml_context context;
 };
 
-//
-// compute types
-//
-
-enum ggml_task_type {
-    GGML_TASK_INIT = 0,
-    GGML_TASK_COMPUTE,
-    GGML_TASK_FINALIZE,
-};
-
-struct ggml_compute_params {
-    enum ggml_task_type type;
-
-    int ith, nth;
-
-    // work buffer for all threads
-    size_t wsize;
-    void * wdata;
-};
-
 //
 // ggml state
 //
@@ -3784,7 +3875,7 @@ int64_t ggml_nelements(const struct ggml_tensor * tensor) {
     return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
 }
 
-int ggml_nrows(const struct ggml_tensor * tensor) {
+int64_t ggml_nrows(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@@ -3793,7 +3884,20 @@ int ggml_nrows(const struct ggml_tensor * tensor) {
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
+    // this should handle cases where the tensor is not contiguous in memory
+    // probaby just:
+    //
+    //     return tensor->ne[3]*tensor->nb[3]
+    //
+    // is enough, but just in case, adding the second part
+
+    return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]);
+}
+
+size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return (nrows_split*tensor->ne[0]*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
 }
 
 int ggml_blck_size(enum ggml_type type) {
@@ -3847,6 +3951,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
         (t0->ne[3] == t1->ne[3]);
 }
 
+static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t0->ne[1] == t1->ne[1])  &&
+        (t0->ne[2] == t1->ne[2])  &&
+        (t0->ne[3] == t1->ne[3]);
+}
+
 bool ggml_is_quantized(enum ggml_type type) {
     return GGML_IS_QUANTIZED[type];
 }
@@ -3862,6 +3975,11 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q5_0:          wtype = GGML_TYPE_Q5_0;  break;
         case GGML_FTYPE_MOSTLY_Q5_1:          wtype = GGML_TYPE_Q5_1;  break;
         case GGML_FTYPE_MOSTLY_Q8_0:          wtype = GGML_TYPE_Q8_0;  break;
+        case GGML_FTYPE_MOSTLY_Q2_K:          wtype = GGML_TYPE_Q2_K;  break;
+        case GGML_FTYPE_MOSTLY_Q3_K:          wtype = GGML_TYPE_Q3_K;  break;
+        case GGML_FTYPE_MOSTLY_Q4_K:          wtype = GGML_TYPE_Q4_K;  break;
+        case GGML_FTYPE_MOSTLY_Q5_K:          wtype = GGML_TYPE_Q5_K;  break;
+        case GGML_FTYPE_MOSTLY_Q6_K:          wtype = GGML_TYPE_Q6_K;  break;
         case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT; break;
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
     }
@@ -3875,11 +3993,11 @@ size_t ggml_tensor_overhead(void) {
     return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16;
 }
 
-static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
+bool ggml_is_transposed(const struct ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1];
 }
 
-static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
+bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -3889,6 +4007,12 @@ static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+bool ggml_is_permuted(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -4029,6 +4153,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         /*.mem_buffer         =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
         /*.mem_buffer_owned   =*/ params.mem_buffer ? false : true,
         /*.no_alloc           =*/ params.no_alloc,
+        /*.no_alloc_save      =*/ params.no_alloc,
         /*.n_objects          =*/ 0,
         /*.objects_begin      =*/ NULL,
         /*.objects_end        =*/ NULL,
@@ -4092,25 +4217,52 @@ void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {
     ctx->no_alloc = no_alloc;
 }
 
-void * ggml_get_mem_buffer(struct ggml_context * ctx) {
+void * ggml_get_mem_buffer(const struct ggml_context * ctx) {
     return ctx->mem_buffer;
 }
 
-size_t ggml_get_mem_size(struct ggml_context * ctx) {
+size_t ggml_get_mem_size(const struct ggml_context * ctx) {
     return ctx->mem_size;
 }
 
+size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
+    size_t max_size = 0;
+
+    struct ggml_object * obj = ctx->objects_begin;
+
+    while (obj != NULL) {
+        struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs);
+
+        const size_t size = ggml_nbytes(tensor);
+
+        if (max_size < size) {
+            max_size = size;
+        }
+
+        obj = obj->next;
+    }
+
+    return max_size;
+}
+
 // IMPORTANT:
 // when creating "opt" tensors, always save and load the scratch buffer
 // this is an error prone process, but it is necessary to support inplace
 // operators when using scratch buffers
 // TODO: implement a better way
 void ggml_scratch_save(struct ggml_context * ctx) {
+    // this is needed to allow opt tensors to store their data
+    // TODO: again, need to find a better way
+    ctx->no_alloc_save = ctx->no_alloc;
+    ctx->no_alloc      = false;
+
     ctx->scratch_save = ctx->scratch;
     ctx->scratch.data = NULL;
 }
 
 void ggml_scratch_load(struct ggml_context * ctx) {
+    ctx->no_alloc = ctx->no_alloc_save;
+
     ctx->scratch = ctx->scratch_save;
 }
 
@@ -4219,6 +4371,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.perf_time_us =*/ 0,
         /*.data         =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
         /*.name         =*/ { 0 },
+        /*.extra        =*/ NULL,
         /*.pad          =*/ { 0 },
     };
 
@@ -4658,7 +4811,7 @@ struct ggml_tensor * ggml_add_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -4698,7 +4851,7 @@ struct ggml_tensor * ggml_add1_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5124,6 +5277,34 @@ struct ggml_tensor * ggml_repeat(
     return result;
 }
 
+// ggml_repeat_back
+
+struct ggml_tensor * ggml_repeat_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_can_repeat(b, a));
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    if (ggml_are_same_shape(a, b) && !is_node) {
+        return a;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+
+    result->op   = GGML_OP_REPEAT_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_abs
 
 struct ggml_tensor * ggml_abs_impl(
@@ -5535,6 +5716,32 @@ struct ggml_tensor * ggml_mul_mat(
     return result;
 }
 
+// ggml_out_prod
+
+struct ggml_tensor * ggml_out_prod(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_can_out_prod(a, b));
+    GGML_ASSERT(!ggml_is_transposed(a));
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+
+    result->op   = GGML_OP_OUT_PROD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_scale
 
 struct ggml_tensor * ggml_scale_impl(
@@ -5547,7 +5754,7 @@ struct ggml_tensor * ggml_scale_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5590,7 +5797,7 @@ struct ggml_tensor * ggml_set_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5899,14 +6106,18 @@ struct ggml_tensor * ggml_view_1d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -5931,6 +6142,13 @@ struct ggml_tensor * ggml_view_2d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->nb[1] = nb1;
     result->nb[2] = result->nb[1]*ne1;
     result->nb[3] = result->nb[2];
@@ -5939,10 +6157,7 @@ struct ggml_tensor * ggml_view_2d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -5969,6 +6184,13 @@ struct ggml_tensor * ggml_view_3d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->nb[1] = nb1;
     result->nb[2] = nb2;
     result->nb[3] = result->nb[2]*ne2;
@@ -5977,10 +6199,7 @@ struct ggml_tensor * ggml_view_3d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -6009,6 +6228,13 @@ struct ggml_tensor * ggml_view_4d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->nb[1] = nb1;
     result->nb[2] = nb2;
     result->nb[3] = nb3;
@@ -6017,10 +6243,7 @@ struct ggml_tensor * ggml_view_4d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -6083,10 +6306,18 @@ struct ggml_tensor * ggml_permute(
     result->src1 = NULL;
 
     if (is_node) {
-        result->padding[0] = axis0;
-        result->padding[1] = axis1;
-        result->padding[2] = axis2;
-        result->padding[3] = axis3;
+        ggml_scratch_save(ctx);
+
+        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+
+        ((int32_t *) b->data)[0] = axis0;
+        ((int32_t *) b->data)[1] = axis1;
+        ((int32_t *) b->data)[2] = axis2;
+        ((int32_t *) b->data)[3] = axis3;
+
+        ggml_scratch_load(ctx);
+
+        result->opt[0] = b;
     }
 
     return result;
@@ -6326,6 +6557,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
     return ggml_soft_max_impl(ctx, a, true);
 }
 
+
+// ggml_soft_max_back
+
+struct ggml_tensor * ggml_soft_max_back_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true; // TODO : implement backward pass
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SOFT_MAX_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_soft_max_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_soft_max_back_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, true);
+}
+
 // ggml_rope
 
 struct ggml_tensor * ggml_rope_impl(
@@ -6338,7 +6607,7 @@ struct ggml_tensor * ggml_rope_impl(
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
 
-    if (!inplace && a->grad) {
+    if (a->grad) {
         is_node = true;
     }
 
@@ -6392,8 +6661,7 @@ struct ggml_tensor * ggml_rope_back(
     bool is_node = false;
 
     if (a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
+        is_node = false; // TODO: implement backward
     }
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6586,7 +6854,6 @@ struct ggml_tensor * ggml_flash_attn(
     bool is_node = false;
 
     if (q->grad || k->grad || v->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -6618,7 +6885,6 @@ struct ggml_tensor * ggml_flash_ff(
     bool is_node = false;
 
     if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -6636,6 +6902,70 @@ struct ggml_tensor * ggml_flash_ff(
     return result;
 }
 
+// ggml_flash_attn_back
+
+struct ggml_tensor * ggml_flash_attn_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * d,
+        bool                  masked) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+
+    // d shape [D,N,ne2,ne3]
+    // q shape [D,N,ne2,ne3]
+    // k shape [D,M,ne2,ne3]
+    // v shape [M,D,ne2,ne3]
+
+    const int64_t   D = q->ne[0];
+    const int64_t   N = q->ne[1];
+    const int64_t   M = k->ne[1];
+    const int64_t ne2 = q->ne[2];
+    const int64_t ne3 = q->ne[3];
+
+    GGML_ASSERT(k->ne[0] == D);
+    GGML_ASSERT(v->ne[0] == M);
+    GGML_ASSERT(v->ne[1] == D);
+    GGML_ASSERT(d->ne[0] == D);
+    GGML_ASSERT(d->ne[1] == N);
+    GGML_ASSERT(k->ne[2] == ne2);
+    GGML_ASSERT(k->ne[3] == ne3);
+    GGML_ASSERT(v->ne[2] == ne2);
+    GGML_ASSERT(v->ne[3] == ne3);
+    GGML_ASSERT(d->ne[2] == ne2);
+    GGML_ASSERT(d->ne[3] == ne3);
+
+    bool is_node = false;
+
+    if (q->grad || k->grad || v->grad) {
+        // when using this operation (in backwards pass) these grads are set.
+        // we don't want to create (big) grad of our result, so is_node is false.
+        is_node = false;
+    }
+
+    // store gradients of q, k and v as continuous tensors concatenated in result.
+    // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
+    // gradq->data = result->data
+    // gradk->data = result->data + nb0*D*N*ne2*ne3
+    // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
+    // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
+    int64_t ne[4] = {D,M+N+M,ne2,ne3};
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_FLASH_ATTN_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = q;
+    result->src1 = k;
+    result->opt[0] = v;
+    result->opt[1] = d;
+    result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
+
+    return result;
+}
+
 // ggml_win_part
 
 struct ggml_tensor * ggml_win_part(
@@ -6804,6 +7134,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
     return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
 }
 
+// ggml_cross_entropy_loss
+
+struct ggml_tensor * ggml_cross_entropy_loss(
+        struct ggml_context         * ctx,
+        struct ggml_tensor          * a,
+        struct ggml_tensor          * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+    result->op   = GGML_OP_CROSS_ENTROPY_LOSS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_cross_entropy_loss_back
+
+struct ggml_tensor * ggml_cross_entropy_loss_back(
+        struct ggml_context         * ctx,
+        struct ggml_tensor          * a,
+        struct ggml_tensor          * b,
+        struct ggml_tensor          * c) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_is_scalar(c));
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
+    result->grad = NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = c;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_set_param(
@@ -7753,7 +8127,7 @@ static void ggml_compute_forward_add_q_f32(
 
         void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
         float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
-        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb0));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
 
         assert(ne00 % 32 == 0);
 
@@ -7793,6 +8167,11 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
@@ -8096,6 +8475,11 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
             } break;
@@ -8218,6 +8602,11 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
         default:
             {
                 GGML_ASSERT(false);
@@ -8336,10 +8725,10 @@ static void ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-#ifdef GGML_USE_CUBLAS
-    if (src1->backend == GGML_BACKEND_CUDA) {
+#ifdef GGML_USE_CLBLAST
+    if (src1->backend == GGML_BACKEND_GPU) {
         if (ith == 0) {
-            ggml_cuda_mul(src0, src1, dst);
+            ggml_cl_mul(src0, src1, dst);
         }
         return;
     }
@@ -8939,6 +9328,99 @@ static void ggml_compute_forward_repeat(
     }
 }
 
+// ggml_compute_forward_repeat_back
+
+static void ggml_compute_forward_repeat_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const size_t nb0  = dst->nb[0];
+    const size_t nb1  = dst->nb[1];
+    const size_t nb2  = dst->nb[2];
+    const size_t nb3  = dst->nb[3];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne00/ne0);
+    const int nr1 = (int)(ne01/ne1);
+    const int nr2 = (int)(ne02/ne2);
+    const int nr3 = (int)(ne03/ne3);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (ggml_is_contiguous(dst)) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    } else {
+        for         (int k3 = 0; k3 < ne3; k3++) {
+            for     (int k2 = 0; k2 < ne2; k2++) {
+                for (int k1 = 0; k1 < ne1; k1++) {
+                    ggml_vec_set_f32(ne0,
+                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
+                        0);
+                }
+            }
+        }
+    }
+
+    // TODO: maybe this is not optimal?
+    for                         (int i3 = 0; i3 < nr3; i3++) {
+        for                     (int k3 = 0; k3 < ne3; k3++) {
+            for                 (int i2 = 0; i2 < nr2; i2++) {
+                for             (int k2 = 0; k2 < ne2; k2++) {
+                    for         (int i1 = 0; i1 < nr1; i1++) {
+                        for     (int k1 = 0; k1 < ne1; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                ggml_vec_acc_f32(ne0,
+                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
+                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_repeat_back_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_abs
 
 static void ggml_compute_forward_abs_f32(
@@ -9206,8 +9688,6 @@ static void ggml_compute_forward_gelu(
                 GGML_ASSERT(false);
             } break;
     }
-
-    //printf("XXXXXXXX gelu\n");
 }
 
 // ggml_compute_forward_gelu_quick
@@ -9267,8 +9747,6 @@ static void ggml_compute_forward_gelu_quick(
                 GGML_ASSERT(false);
             } break;
     }
-
-    //printf("XXXXXXXX quick gelu\n");
 }
 
 // ggml_compute_forward_silu
@@ -9515,7 +9993,7 @@ static void ggml_compute_forward_rms_norm_f32(
                     sum += (ggml_float)(x[i00] * x[i00]);
                 }
 
-                float mean = sum/ne00;
+                const float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
@@ -9838,14 +10316,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -10010,14 +10481,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -10222,14 +10686,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -10372,6 +10829,11 @@ static void ggml_compute_forward_mul_mat(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
@@ -10390,6 +10852,176 @@ static void ggml_compute_forward_mul_mat(
     }
 }
 
+// ggml_compute_forward_out_prod
+
+
+static void ggml_compute_forward_out_prod_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    //const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
+    // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+
+    if (params->type == GGML_TASK_INIT) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        // dst indices
+        const int64_t i3 = ir/(ne2*ne1);
+        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int64_t i02 = i2;
+        const int64_t i03 = i3;
+
+        //const int64_t i10 = i1;
+        const int64_t i12 = i2;
+        const int64_t i13 = i3;
+
+        for (int64_t i01 = 0; i01 < ne01; ++i01) {
+            const int64_t i11 = i01;
+
+            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+            ggml_vec_mad_f32(ne0, d, s0, *s1);
+            // for (int64_t i0 = 0; i0 < ne0; ++i0) {
+            //     d[i0] += s0[i0] * s1[i1];
+            // }
+        }
+    }
+
+    //int64_t t1 = ggml_perf_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_out_prod(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+            {
+                GGML_ASSERT(false); // todo
+                // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(false); // todo
+                // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_scale
 
 static void ggml_compute_forward_scale_f32(
@@ -10555,6 +11187,11 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
         default:
             {
                 GGML_ASSERT(false);
@@ -10720,6 +11357,11 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -10802,7 +11444,11 @@ static void ggml_compute_forward_get_rows_back_f32(
     GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
-    ggml_compute_forward_dup_same_cont(params, opt0, dst);
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -10946,8 +11592,8 @@ static void ggml_compute_forward_diag_mask_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst,
         const float value) {
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 2);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 2);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -10955,7 +11601,7 @@ static void ggml_compute_forward_diag_mask_f32(
     const int  n_past  =       ((int32_t *) src1->data)[0];
     const bool inplace = (bool)((int32_t *) src1->data)[1];
 
-    assert(n_past >= 0);
+    GGML_ASSERT(n_past >= 0);
 
     if (!inplace && (params->type == GGML_TASK_INIT)) {
         // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -10979,8 +11625,8 @@ static void ggml_compute_forward_diag_mask_f32(
     const int nr = src0->ne[1];
     const int nz = n/nr;
 
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     for (int k = 0; k < nz; k++) {
         for (int j = ith; j < nr; j += nth) {
@@ -11116,42 +11762,138 @@ static void ggml_compute_forward_soft_max(
     }
 }
 
-// ggml_compute_forward_alibi
+// ggml_compute_forward_soft_max_back
 
-static void ggml_compute_forward_alibi_f32(
+static void ggml_compute_forward_soft_max_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 3);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src1, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int   n_past   = ((int32_t *) src1->data)[0];
-    const int   n_head   = ((int32_t *) src1->data)[1];
-    const float max_bias = ((float *)   src1->data)[2];
+    // TODO: handle transposed/permuted matrices
 
-    assert(n_past >= 0);
+    const int ith = params->ith;
+    const int nth = params->nth;
 
-    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
-    const int ne1 = src0->ne[1]; // seq_len_without_past
-    //const int ne2 = src0->ne[2]; // n_head -> this is k
-    //const int ne3 = src0->ne[3]; // 1 -> bsz
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
 
-    const int n  = ggml_nrows(src0);
-    const int ne2_ne3 = n/ne1; // ne2*ne3
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    //const int nb3 = src0->nb[3];
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
 
-    assert(nb0 == sizeof(float));
-    assert(ne1 + n_past == ne0); (void) n_past;
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(dy[i]));
+            assert(!isnan(y[i]));
+        }
+#endif
+        // Jii = yi - yi*yi
+        // Jij = -yi*yj
+        // J = diag(y)-y.T*y
+        // dx = J * dy
+        // dxk = sum_i(Jki * dyi)
+        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
+        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
+        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
+        // dxk = -yk * dot(y, dy) + yk*dyk
+        // dxk = yk * (- dot(y, dy) + dyk)
+        // dxk = yk * (dyk - dot(y, dy))
+        //
+        // post-order:
+        // dot_y_dy := dot(y, dy)
+        // dx := dy
+        // dx := dx - dot_y_dy
+        // dx := dx * y
+
+        // linear runtime, no additional memory
+        float dot_y_dy = 0;
+        ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
+        ggml_vec_cpy_f32 (nc, dx, dy);
+        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32 (nc, dx, dx, y);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dx[i]));
+            assert(!isinf(dx[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_alibi
+
+static void ggml_compute_forward_alibi_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 3);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int   n_past   = ((int32_t *) src1->data)[0];
+    const int   n_head   = ((int32_t *) src1->data)[1];
+    const float max_bias = ((float *)   src1->data)[2];
+
+    assert(n_past >= 0);
+
+    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
+    const int ne1 = src0->ne[1]; // seq_len_without_past
+    //const int ne2 = src0->ne[2]; // n_head -> this is k
+    //const int ne3 = src0->ne[3]; // 1 -> bsz
+
+    const int n  = ggml_nrows(src0);
+    const int ne2_ne3 = n/ne1; // ne2*ne3
+
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    //const int nb3 = src0->nb[3];
+
+    assert(nb0 == sizeof(float));
+    assert(ne1 + n_past == ne0); (void) n_past;
 
     // add alibi to src0 (KQ_scaled)
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11188,8 +11930,9 @@ static void ggml_compute_forward_alibi_f16(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 3);
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 3);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -11266,6 +12009,12 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -11285,8 +12034,9 @@ static void ggml_compute_forward_clamp_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
-    assert(src1->type == GGML_TYPE_F32);
-    assert(ggml_nelements(src1) == 2);
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nelements(src1) == 2);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -11337,6 +12087,12 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -11426,7 +12182,7 @@ static void ggml_compute_forward_rope_f32(
                         theta *= theta_scale;
 
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
+                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float x0 = src[0];
                         const float x1 = src[1];
@@ -11447,7 +12203,7 @@ static void ggml_compute_forward_rope_f32(
                             const int64_t i0 = ib*n_dims + ic/2;
 
                             const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
+                                  float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                             const float x0 = src[0];
                             const float x1 = src[n_dims/2];
@@ -13194,6 +13950,414 @@ static void ggml_compute_forward_flash_ff(
     }
 }
 
+// ggml_compute_forward_flash_attn_back
+
+static void ggml_compute_forward_flash_attn_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int64_t neq0 = q->ne[0];
+    const int64_t neq1 = q->ne[1];
+    const int64_t neq2 = q->ne[2];
+    const int64_t neq3 = q->ne[3];
+
+    const int64_t nek0 = k->ne[0];
+    const int64_t nek1 = k->ne[1];
+    //const int64_t nek2 = k->ne[2];
+    //const int64_t nek3 = k->ne[3];
+
+    const int64_t nev0 = v->ne[0];
+    const int64_t nev1 = v->ne[1];
+    //const int64_t nev2 = v->ne[2];
+    //const int64_t nev3 = v->ne[3];
+
+    const int64_t ned0 = d->ne[0];
+    const int64_t ned1 = d->ne[1];
+    //const int64_t ned2 = d->ne[2];
+    //const int64_t ned3 = d->ne[3];
+
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
+
+    const int nbk0 = k->nb[0];
+    const int nbk1 = k->nb[1];
+    const int nbk2 = k->nb[2];
+    const int nbk3 = k->nb[3];
+
+    const int nbq0 = q->nb[0];
+    const int nbq1 = q->nb[1];
+    const int nbq2 = q->nb[2];
+    const int nbq3 = q->nb[3];
+
+    const int nbv0 = v->nb[0];
+    const int nbv1 = v->nb[1];
+    const int nbv2 = v->nb[2];
+    const int nbv3 = v->nb[3];
+
+    const int nbd0 = d->nb[0];
+    const int nbd1 = d->nb[1];
+    const int nbd2 = d->nb[2];
+    const int nbd3 = d->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+    const int64_t P = nek1 - N;
+    const int64_t M = P + N;
+
+    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+    const int mxDM = MAX(D, Mup);
+
+    // GGML_ASSERT(ne0 == D);
+    // GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(float));
+    GGML_ASSERT(nbv0 == sizeof(float));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
+        }
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0f/sqrtf(D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2);
+        const int iq2 = ir - iq3*neq2;
+        for ( int iq1 = 0; iq1 < neq1; ++iq1) {
+
+
+            // not sure about CACHE_LINE_SIZE_F32..
+            // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
+            float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+            float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+
+            for (int i = M; i < Mup; ++i) {
+                S[i] = -INFINITY;
+            }
+
+            for (int64_t ic = 0; ic < nek1; ++ic) {
+                // k indices
+                const int ik3 = iq3;
+                const int ik2 = iq2;
+                const int ik1 = ic;
+
+                // S indices
+                const int i1 = ik1;
+
+                ggml_vec_dot_f32(neq0,
+                        S + i1,
+                        (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                        (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+            }
+
+            // scale
+            ggml_vec_scale_f32(nek1, S, scale);
+
+            if (masked) {
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = -INFINITY;
+                    }
+                }
+            }
+
+            // softmax
+            {
+                float max = -INFINITY;
+                ggml_vec_max_f32(M, &max, S);
+
+                ggml_float sum = 0.0;
+                {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                    max = -max;
+                    vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+                    vvexpf(SM, SM, &Mup);
+                    ggml_vec_sum_f32(Mup, &sum, SM);
+#else
+                    uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                    ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
+
+                    for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                        float * SR =  S + i;
+                        float * SW = SM + i;
+
+                        for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+                            if (SR[j] == -INFINITY) {
+                                SW[j] = 0.0f;
+                            } else {
+                                ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
+                                memcpy(&scvt[j], &s, sizeof(uint16_t));
+                                const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                                sump[j] += (ggml_float)val;
+                                SW[j] = val;
+                            }
+                        }
+                    }
+
+                    for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+                        sum += sump[i];
+                    }
+#endif
+                }
+
+                assert(sum > 0.0);
+
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(M, SM, sum);
+
+            }
+
+            // step-by-step explanation
+            {
+                // forward-process                   shape      grads from backward process
+                // parallel_for iq2,iq3:
+                //  k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,iq2,iq3]  += grad[kcur]
+                //  q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+                //  v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iq2,iq3]  += grad[vcur]
+                //  for iq1:
+                //   kcur   = k[:D,:M,iq2,iq3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
+                //   qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
+                //   vcur   = v[:M,:D,iq2,iq3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
+                //   S0     = -Inf                   [D,1,1,1]
+                //  ~S1[i]  = dot(kcur[:D,i], qcur)
+                //   S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
+                //   S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
+                //   S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                //   S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
+                //  ~S5[i]  = dot(vcur[:,i], S4)
+                //   S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,iq1,iq2,iq3]
+                //  ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
+                //   dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
+                // dst                               backward-/ grad[dst]                 = d
+                //
+                // output gradients with their dependencies:
+                //
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S4]   = grad[S5] @ vcur
+                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[qcur] = grad[S1]   @ kcur
+                // grad[vcur] = grad[S5].T @ S4
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // in post-order:
+                //
+                // S1         = qcur @ kcur.T
+                // S2         = S1 * scale
+                // S3         = diag_mask_inf(S2, P)
+                // S4         = softmax(S3)
+                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                // grad[qcur] = grad[S1]   @ kcur
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // using less variables (SM=S4):
+                //
+                // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
+                // SM            = softmax(S)
+                // S             = d[:D,iq1,iq2,iq3] @ vcur
+                // dot_SM_gradSM = dot(SM, S)
+                // S             = SM * (S - dot(SM, S))
+                // S             = diag_mask_zero(S, P) * scale
+                //
+                // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
+                // grad[k][:D,:M,iq2,iq3]  += S.T @ qcur
+                // grad[v][:M,:D,iq2,iq3]  += d[:D,iq1,iq2,iq3].T @ SM
+            }
+
+            // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
+            // S = d[:D,iq1,iq2,iq3] @ vcur
+            // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
+            ggml_vec_set_f32(M, S, 0);
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(M,
+                        S,
+                         (float *) ((char *) v->data + (          ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+            }
+
+            // S = SM * (S - dot(SM, S))
+            float dot_SM_gradSM = 0;
+            ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
+            ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+            ggml_vec_mul_f32 (M, S, S, SM);
+
+            // S = diag_mask_zero(S, P) * scale
+            if (masked) {
+                // for (int64_t i = P + iq1 + 1; i < M; i++) {
+                //     S[i] = 0;
+                // }
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = 0;
+                    }
+                }
+            }
+            ggml_vec_scale_f32(M, S, scale);
+
+            void * grad_q = (char *) dst->data;
+            void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
+            void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
+
+            const size_t nbgq1 = nb0*neq0;
+            const size_t nbgq2 = nb0*neq0*neq1;
+            const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+            const size_t nbgk1 = nb0*nek0;
+            const size_t nbgk2 = nb0*nek0*nek1;
+            const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+            const size_t nbgv1 = nb0*nev0;
+            const size_t nbgv2 = nb0*nev0*nev1;
+            const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+            // S    shape [M,1]
+            // SM   shape [M,1]
+            // kcur shape [D,M]
+            // qcur shape [D,1]
+            // vcur shape [M,D]
+            //
+            // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+            // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+            // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
+            //
+            //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
+            //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_q  + (i1*nbgq1  + i2*nbgq2  + i3*nbgq3)),
+                        (float *) ((char *) k->data + (ic*nbk1   + i2*nbk2   + i3*nbk3)),
+                        S[ic]);
+            }
+
+            // grad[k][:D,:M,iq2,iq3] += S.T       @ qcur
+            // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+            // grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(D,
+                //         (float *) ((char *) grad_k  + (ic*nbgk1  + i2*nbgk2  + i3*nbgk3)),
+                //         0);
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_k  + (ic*nbgk1  + i2*nbgk2  + i3*nbgk3)),
+                        (float *) ((char *) q->data + (i1*nbq1   + i2*nbq2   + i3*nbq3)),
+                        S[ic]);
+            }
+
+            // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T       @ SM
+            // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
+            // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3]         * SM[:M]
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(M,
+                //         (float *) ((char *) grad_v   + (          ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                //         0);
+                ggml_vec_mad_f32(M,
+                        (float *) ((char *) grad_v   + (          ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                        SM,
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1  + i2*nbd2  + i3*nbd3)));
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_flash_attn_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    switch (q->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_win_part
 
 static void ggml_compute_forward_win_part_f32(
@@ -13205,16 +14369,15 @@ static void ggml_compute_forward_win_part_f32(
         return;
     }
 
-    const int64_t ne00 = src0->ne[0];
+    const int64_t ne00 = src0->ne[0]; UNUSED(ne00);
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
-    //const int64_t ne03 = src0->ne[3];
-    UNUSED(ne00);
+    const int64_t ne03 = src0->ne[3]; UNUSED(ne03);
 
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
     const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
+    const int64_t ne3 = dst->ne[3]; UNUSED(ne3);
 
     const int32_t nep0 = ((const int32_t *)(opt0->data))[0];
     const int32_t nep1 = ((const int32_t *)(opt0->data))[1];
@@ -13393,32 +14556,311 @@ static void ggml_compute_forward_map_binary_f32(
         return;
     }
 
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss
+
+static void ggml_compute_forward_cross_entropy_loss_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    float * sums = (float *) params->wdata;
+
+    // TODO: handle transposed/permuted matrices
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(sums, 0, sizeof(float) * (nth + nth * nc));
+        }
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        if (ith == 0) {
+            float * dp = (float *) dst->data;
+            ggml_vec_sum_f32(nth, dp, sums);
+            dp[0] *= -1.0f;
+        }
+        return;
+    }
+
+    const double eps = 1e-9;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * st = (float *) params->wdata + nth + ith*nc;
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+        // soft_max
+        ggml_float sum = 0.0;
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(nc, &max, s0);
+
+            uint16_t scvt;
+            for (int i = 0; i < nc; i++) {
+                if (s0[i] == -INFINITY) {
+                    st[i] = 0.0f;
+                } else {
+                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
+                    memcpy(&scvt, &s, sizeof(scvt));
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    sum += (ggml_float)val;
+                    st[i] = val;
+                }
+            }
+
+            assert(sum > 0.0);
+            // sum = 1.0/sum;
+        }
+        // avoid log(0) by rescaling from [0..1] to [eps..1]
+        sum = (1.0 - eps) / sum;
+        ggml_vec_scale_f32(nc, st, sum);
+        ggml_vec_add1_f32(nc, st, st, eps);
+        ggml_vec_log_f32(nc, st, st);
+        ggml_vec_mul_f32(nc, st, st, s1);
+
+        ggml_vec_sum_f32(nc, sums + ith, st);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(st[i]));
+            assert(!isinf(st[i]));
+        }
+#endif
+    }
+
+}
+
+static void ggml_compute_forward_cross_entropy_loss(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss_back
+
+static void ggml_compute_forward_cross_entropy_loss_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int64_t ith = params->ith;
+    const int64_t nth = params->nth;
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const float eps = 1e-9f;
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    float * d   = (float *) opt0->data;
+
+    for (int64_t i1 = ir0; i1 < ir1; i1++) {
+        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
+        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * sm  = (float *) params->wdata + ith*nc;
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+        // step by step explanation:
+        {
+            //float * sums = (float *) params->wdata;
+
+            // forward pass with annotated gradients from backward pass
+            // (built by going in reverse operation order, adding to gradients of current operation args)
+            // st0 = exp(s0-max(s0))                                                       grad[st0] = grad[st1]*(1.0 - eps)/sum
+                                                          // from softmax_back:            grad[s0]  = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
+            // ggml_vec_scale_f32(nc, st, sum);           // st1 = st0*/sum = softmax(s0)  grad[st1] = grad[st2]*(1.0 - eps)
+            // ggml_vec_scale_f32(nc, st, (1.0f - eps));  // st2 = st1*(1.0 - eps)         grad[st2] = grad[st3]
+            // ggml_vec_add1_f32(nc, st, st, eps);        // st3 = st2 + eps               grad[st3] = grad[st4]/st3
+            // ggml_vec_log_f32(nc, st, st);              // st4 = log(st3)                grad[st4] = grad[st5] * s1
+            // ggml_vec_mul_f32(nc, st, st, s1);          // st5 = st4 * s1                grad[st5] = grad[sums[ith]]
+            // ggml_vec_sum_f32(nc, sums + ith, st);      // sums[ith] = st5               grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
+
+            // substitute into grad[st1], because we can reuse softmax_back from this point on
+            // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
+            // postorder:
+            // grad[st1] := softmax(s0)
+            // grad[st1] := grad[st1]*(1.0 - eps)
+            // grad[st1] := grad[st1] + eps
+            // grad[st1] := s1 / grad[st1]
+            // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
+
+            // src0 gradients by going through softmax_back
+            // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
+            //   from softmax_back:
+            //   dxk = yk * (dyk - dot(y, dy))
+            //   dot_y_dy := dot(y, dy)
+            //   dx := dy
+            //   dx := dx - dot_y_dy
+            //   dx := dx * y
+            //   postorder:
+            //   dot_st1_dst1 := dot(st1, grad[st1])
+            //   grad[s0] := grad[st1]
+            //   grad[s0] := grad[s0] - dot_st1_dst1
+            //   grad[s0] := grad[s0] * st1
+
+            // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
+            // sm           := softmax(s0)
+            // grad[s0]     := sm*(1.0 - eps)
+            // grad[s0]     := grad[s0] + eps
+            // grad[s0]     := s1 / grad[s0]
+            // grad[s0]     := grad[s0]*(1.0-eps)*-grad[cel]
+            // dot_st1_dst1 := dot(sm, grad[s0])
+            // grad[s0]     := grad[s0] - dot_st1_dst1
+            // grad[s0]     := grad[s0] * sm
+        }
+
+        // soft_max
+        ggml_float sum = 0.0;
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(nc, &max, s0);
+
+            uint16_t scvt;
+            for (int i = 0; i < nc; i++) {
+                if (s0[i] == -INFINITY) {
+                    sm[i] = 0.0f;
+                } else {
+                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
+                    memcpy(&scvt, &s, sizeof(scvt));
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    sum += (ggml_float)val;
+                    sm[i] = val;
+                }
+            }
+
+            assert(sum > 0.0);
+            sum = 1.0/sum;
+        }
 
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
-    assert(src1->nb[0] == sizeof(float));
+        float dot_st1_dst1 = 0;
+        ggml_vec_scale_f32(nc, sm, sum);
+        ggml_vec_cpy_f32  (nc, ds0, sm);
+        ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
+        ggml_vec_add1_f32 (nc, ds0, ds0, eps);
+        ggml_vec_div_f32  (nc, ds0, s1, ds0);
+        ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
+        ggml_vec_dot_f32  (nc, &dot_st1_dst1, sm, ds0);
+        ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
+        ggml_vec_mul_f32  (nc, ds0, ds0, sm);
 
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])),
-                (float *) ((char *) src1->data + i*(src1->nb[1])));
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(sm[i]));
+            assert(!isinf(sm[i]));
+            assert(!isnan(ds0[i]));
+            assert(!isinf(ds0[i]));
+        }
+#endif
     }
 }
 
-
-static void ggml_compute_forward_map_binary(
+static void ggml_compute_forward_cross_entropy_loss_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+                ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
             } break;
         default:
             {
@@ -13427,11 +14869,21 @@ static void ggml_compute_forward_map_binary(
     }
 }
 
+
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
     GGML_ASSERT(params);
 
+#ifdef GGML_USE_CUBLAS
+    bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
+    if (skip_cpu) {
+        return;
+    }
+    GGML_ASSERT(tensor->src0->backend == GGML_BACKEND_CPU);
+    GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU);
+#endif // GGML_USE_CUBLAS
+
     switch (tensor->op) {
         case GGML_OP_DUP:
             {
@@ -13489,6 +14941,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_repeat(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_REPEAT_BACK:
+            {
+                ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_ABS:
             {
                 ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -13541,6 +14997,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_OUT_PROD:
+            {
+                ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_SCALE:
             {
                 ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
@@ -13597,6 +15057,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_soft_max(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_ROPE:
             {
                 ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13636,6 +15100,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
             } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
+                GGML_ASSERT(t == 0 || t == 1);
+                bool masked = t != 0;
+                ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
+            } break;
         case GGML_OP_WIN_PART:
             {
                 ggml_compute_forward_win_part(params, tensor->src0, tensor->opt[0], tensor);
@@ -13656,6 +15127,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
             }
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
+            }
+            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -13794,11 +15275,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad =
                         ggml_add_impl(ctx,
                                 src0->grad,
-                                ggml_mul(ctx,
-                                    tensor->grad, // this was not catched by test_grad because in test_grad tensor->grad is 1
+                                ggml_scale(ctx,
                                     ggml_div(ctx,
-                                        ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
-                                        tensor)),
+                                        tensor->grad,
+                                        tensor),
+                                    ggml_new_f32(ctx, 0.5f)),
                                 inplace);
                 }
             } break;
@@ -13844,43 +15325,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
-                    const int nc  = tensor->ne[0];
-                    const int nr  = tensor->ne[1];
-                    const int nc0 = src0->ne[0];
-                    const int nr0 = src0->ne[1];
-                    const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
-                    const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
-                    // tensor->grad [nc,nr,1,1]
-                    // reshape      [nc0,nc/nc0,nr0,nr/nr0]
-                    // permute      [nc0,nr0,nc/nc0,nr/nr0]
-                    // substitute   [nc0,nr0,ncr,nrr]
-                    // reshape      [nc0*nr0,ncr*nrr,1,1]
-                    // transpose    [ncr*nrr,nc0*nr0,1,1]
-                    // sum rows     [1,nc0*nr0,1,1]
-                    // transpose    [nc0*nr0,1,1]
-                    // reshape      [nc0,nr0,1,1] reshape_1d or reshape_2d
-                    // add to src0->grad
-
-                    int64_t ne[4]  = {nc0,ncr,nr0,nrr};
-
-                    struct ggml_tensor* F00 = tensor->grad;
-                    struct ggml_tensor* F01 = ggml_reshape   (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
-                    struct ggml_tensor* F02 = ggml_permute   (ctx, F01, 0,2,1,3);
-                    struct ggml_tensor* F03 = ggml_cont      (ctx, F02);
-                    struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
-                    struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
-                    struct ggml_tensor* F06 = ggml_cont      (ctx, F05);
-                    struct ggml_tensor* F07 = ggml_sum_rows  (ctx, F06);
-                    struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
-                    struct ggml_tensor* F09 = ggml_cont      (ctx, F08);
-                    struct ggml_tensor* F10 = ggml_reshape   (ctx, F09, src0->grad);
-
-                    src0->grad =
-                        ggml_add_impl(ctx,
-                                src0->grad,
-                                F10,
-                                inplace);
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_repeat_back(ctx, tensor->grad, src0->grad),
+                            inplace);
+                }
+            } break;
+        case GGML_OP_REPEAT_BACK:
+            {
+                if (src0->grad) {
+                    // TODO: test this
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_repeat(ctx, tensor->grad, src0->grad),
+                            inplace);
                 }
             } break;
         case GGML_OP_ABS:
@@ -13991,38 +15449,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 
                 // necessary for llama
                 if (src0->grad) {
-                    // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
                     src0->grad =
                         ggml_add_impl(ctx,
                                 src0->grad,
-                                // ds0 = dt.dot(s1.T)
-                                // ggml_out_prod(ctx, // [n,m]
-                                //     src1,          // [n,p]
-                                //     tensor->grad), // [m,p]
-                                // for now just using A*B==(B.T*A.T).T
-                                ggml_cont(ctx,                      // [n,m]
-                                    ggml_transpose(ctx,             // [n,m]
-                                        ggml_mul_mat(ctx,           // [m,n]
-                                            ggml_cont(ctx,          // [p,m]
-                                                ggml_transpose(ctx, // [p,m]
-                                                    tensor->grad)), // [m,p]
-                                            ggml_cont(ctx,          // [p,n]
-                                                ggml_transpose(ctx, // [p,n]
-                                                    src1))))),      // [n,p]
+                                ggml_out_prod(ctx, // [n,m]
+                                    src1,          // [n,p]
+                                    tensor->grad), // [m,p]
                                 inplace);
                 }
                 if (src1->grad) {
                     src1->grad =
                         ggml_add_impl(ctx,
                                 src1->grad,
-                                // ds1 = s0.T.dot(dt):
-                                ggml_mul_mat(ctx,                   // [n,p]
-                                    ggml_cont(ctx,                  // [m,n]
-                                        ggml_transpose(ctx, src0)), // [m,n]
-                                    tensor->grad),                  // [m,p]
+                                // ggml_mul_mat(ctx,                   // [n,p]
+                                //     ggml_cont(ctx,                  // [m,n]
+                                //         ggml_transpose(ctx, src0)), // [m,n]
+                                //     tensor->grad),                  // [m,p]
+
+                                // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
+                                // // avoid transpose of src0, rather transpose smaller tensor->grad
+                                // // and then use ggml_out_prod
+                                ggml_out_prod(ctx,                  // [n,p]
+                                    src0,                           // [n,m]
+                                    ggml_transpose(ctx,             // [p,m]
+                                        tensor->grad)),             // [m,p]
                                 inplace);
                 }
             } break;
+        case GGML_OP_OUT_PROD:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SCALE:
             {
                 // necessary for llama
@@ -14124,7 +15581,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     size_t offset;
-                    memcpy(&offset, tensor->padding, sizeof(offset));
+
+                    GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
+                    memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
 
                     size_t nb1     = tensor->nb[1];
                     size_t nb2     = tensor->nb[2];
@@ -14151,10 +15610,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    int axis0 = tensor->padding[0] & 0x3;
-                    int axis1 = tensor->padding[1] & 0x3;
-                    int axis2 = tensor->padding[2] & 0x3;
-                    int axis3 = tensor->padding[3] & 0x3;
+                    int32_t * axes = (int32_t *) tensor->opt[0]->data;
+                    int axis0 = axes[0] & 0x3;
+                    int axis1 = axes[1] & 0x3;
+                    int axis2 = axes[2] & 0x3;
+                    int axis3 = axes[3] & 0x3;
                     int axes_backward[4] = {0,0,0,0};
                     axes_backward[axis0] = 0;
                     axes_backward[axis1] = 1;
@@ -14238,50 +15698,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    // y = softmax(x)
-                    //
-                    // Jii = yi - yi*yi
-                    // Jij = -yi*yj
-                    // J = diag(y)-y.*y
-                    // dx = J * dy
-                    // dxk = sum(Jkj * dyk)
-
-                    int64_t ne2[4] = {
-                        tensor->ne[0],
-                        1,
-                        tensor->ne[1]*tensor->ne[2],
-                        tensor->ne[3]
-                    };
-                    struct ggml_tensor * tensor2 = ggml_cont(ctx,
-                        ggml_reshape_4d(ctx,
-                            ggml_cont(ctx, tensor),
-                            ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                    struct ggml_tensor * grad2 = ggml_cont(ctx,
-                        ggml_reshape_4d(ctx,
-                            ggml_cont(ctx, tensor->grad),
-                            ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                    struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
-                        ggml_permute(ctx,                           // [1,ne0,ne1*ne2,ne3]
-                            tensor2,                                // [ne0,1,ne1*ne2,ne3]
-                            1, 0, 2, 3));
-
                     src0->grad =
-                        ggml_add_impl(ctx,
-                            src0->grad,                   // [ne0,ne1,ne2,ne3]
-                            ggml_reshape(ctx,             // [ne0,ne1,ne2,ne3]
-                                ggml_mul_mat(ctx,         // [ne0,1,ne1*ne2,ne3]
-                                    ggml_sub(ctx,         // [ne0,ne0,ne1*ne2,ne3]
-                                        ggml_diag(ctx,    // [ne0,ne0,ne1*ne2,ne3]
-                                            tensor2),     // [ne0,1,ne1*ne2,ne3]
-                                        ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
-                                            tensor2_t,    // [1,ne0,ne1*ne2,ne3]
-                                            tensor2_t)),  // [1,ne0,ne1*ne2,ne3]
-                                    grad2),               // [ne0,1,ne1*ne2,ne3]
-                                src0->grad),
-                            inplace);
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_soft_max_back(ctx, tensor->grad, tensor),
+                        inplace);
                 }
+
+            } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_ROPE:
             {
@@ -14340,12 +15766,169 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_FLASH_ATTN:
             {
-                GGML_ASSERT(false); // not supported
+                struct ggml_tensor * flash_grad = NULL;
+                if (src0->grad || src1->grad || tensor->opt[0]->grad) {
+                    int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+                    GGML_ASSERT(t == 0 || t == 1);
+                    bool masked = t != 0;
+                    flash_grad =
+                        ggml_flash_attn_back(ctx,
+                            src0,
+                            src1,
+                            tensor->opt[0],
+                            tensor->grad,
+                            masked);
+                }
+
+                if (src0->grad) {
+                    struct ggml_tensor * grad_q = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = 0;
+                    switch(src0->n_dims) {
+                        case 2:
+                            {
+                                grad_q = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    nb0*src0->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_q = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    src0->ne[2],
+                                    nb0*src0->ne[0],
+                                    nb0*src0->ne[0]*src0->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_q = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    src0->ne[2],
+                                    src0->ne[3],
+                                    nb0*src0->ne[0],
+                                    nb0*src0->ne[0]*src0->ne[1],
+                                    nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            grad_q,
+                            inplace);
+                }
+
+                if (src1->grad) {
+                    struct ggml_tensor * grad_k = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
+                    switch(src1->n_dims) {
+                        case 2:
+                            {
+                                grad_k = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    nb0*src1->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_k = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    src1->ne[2],
+                                    nb0*src1->ne[0],
+                                    nb0*src1->ne[0]*src1->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_k = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    src1->ne[2],
+                                    src1->ne[3],
+                                    nb0*src1->ne[0],
+                                    nb0*src1->ne[0]*src1->ne[1],
+                                    nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    src1->grad = ggml_add_impl(ctx,
+                            src1->grad,
+                            grad_k,
+                            inplace);
+                }
+
+                struct ggml_tensor * opt0 = tensor->opt[0];
+
+                if (opt0->grad) {
+                    struct ggml_tensor * grad_v = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
+                                        + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
+                    switch(opt0->n_dims) {
+                        case 2:
+                            {
+                                grad_v = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    nb0*opt0->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_v = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    opt0->ne[2],
+                                    nb0*opt0->ne[0],
+                                    nb0*opt0->ne[0]*opt0->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_v = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    opt0->ne[2],
+                                    opt0->ne[3],
+                                    nb0*opt0->ne[0],
+                                    nb0*opt0->ne[0]*opt0->ne[1],
+                                    nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    opt0->grad = ggml_add_impl(ctx,
+                            opt0->grad,
+                            grad_v,
+                            inplace);
+                }
             } break;
         case GGML_OP_FLASH_FF:
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_MAP_UNARY:
@@ -14353,6 +15936,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_cross_entropy_loss_back(ctx,
+                                    src0,
+                                    src1,
+                                    tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -14729,6 +16328,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 case GGML_OP_SUM_ROWS:
                 case GGML_OP_MEAN:
                 case GGML_OP_REPEAT:
+                case GGML_OP_REPEAT_BACK:
                 case GGML_OP_ABS:
                 case GGML_OP_SGN:
                 case GGML_OP_NEG:
@@ -14749,6 +16349,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         node->n_tasks = n_threads;
                     } break;
                 case GGML_OP_MUL_MAT:
+                case GGML_OP_OUT_PROD:
                     {
                         node->n_tasks = n_threads;
 
@@ -14765,7 +16366,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
                             node->n_tasks = 1; // TODO: this actually is doing nothing
                                                 //       the threads are still spinning
-                            cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
                         }
                         else
 #elif defined(GGML_USE_CLBLAST)
@@ -14832,6 +16432,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_DIAG_MASK_INF:
                 case GGML_OP_SOFT_MAX:
+                case GGML_OP_SOFT_MAX_BACK:
                 case GGML_OP_ROPE:
                 case GGML_OP_ROPE_BACK:
                     {
@@ -14946,6 +16547,27 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                             cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
                         }
 
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_FLASH_ATTN_BACK:
+                    {
+                        node->n_tasks = n_threads;
+
+                        size_t cur = 0;
+
+                        const int64_t    D = node->src0->ne[0];
+                        const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                        const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+                        if (node->src1->type == GGML_TYPE_F32) {
+                            cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                        }
+
+                        if (node->src1->type == GGML_TYPE_F16) {
+                            cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                        }
+
                         work_size = MAX(work_size, cur);
                     } break;
                 case GGML_OP_WIN_PART:
@@ -14955,6 +16577,22 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = 1;
                     } break;
+                case GGML_OP_CROSS_ENTROPY_LOSS:
+                    {
+                        node->n_tasks = n_threads;
+
+                        size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
+
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+                    {
+                        node->n_tasks = n_threads;
+
+                        size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
+
+                        work_size = MAX(work_size, cur);
+                    } break;
                 case GGML_OP_NONE:
                     {
                         node->n_tasks = 1;
@@ -15192,7 +16830,7 @@ static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fou
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;
 
-    fprintf(fout, "%-6s %-12s %8d %8jd %jd %jd %jd %16zu %16zu %16zu %16zu %16p %16s\n",
+    fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
             ggml_type_name(tensor->type),
             ggml_op_name  (tensor->op),
             tensor->n_dims,
@@ -15206,7 +16844,7 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;
 
-    fprintf(fout, "%-6s %-6s %-12s %8d %8jd %8jd %8jd %8jd %16zu %16zu %16zu %16zu %8d %16p %16s\n",
+    fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
             arg,
             ggml_type_name(tensor->type),
             ggml_op_name  (tensor->op),
@@ -15219,8 +16857,8 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
 }
 
 void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
-    assert(cgraph->work      == NULL);
-    assert(cgraph->work_size == 0);
+    //assert(cgraph->work      == NULL);
+    //assert(cgraph->work_size == 0);
 
     uint64_t size_eval = 0;
 
@@ -15235,11 +16873,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
         FILE * fout = stdout;
 
         fprintf(fout, "\n");
-        fprintf(fout, "%-16s %8x\n",  "magic",   GGML_FILE_MAGIC);
-        fprintf(fout, "%-16s %8d\n",  "version", GGML_FILE_VERSION);
-        fprintf(fout, "%-16s %8d\n",  "leafs",   cgraph->n_leafs);
-        fprintf(fout, "%-16s %8d\n",  "nodes",   cgraph->n_nodes);
-        fprintf(fout, "%-16s %8ju\n", "eval",    size_eval);
+        fprintf(fout, "%-16s %8x\n", "magic",        GGML_FILE_MAGIC);
+        fprintf(fout, "%-16s %8d\n", "version",      GGML_FILE_VERSION);
+        fprintf(fout, "%-16s %8d\n", "leafs",        cgraph->n_leafs);
+        fprintf(fout, "%-16s %8d\n", "nodes",        cgraph->n_nodes);
+        fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval);
 
         // header
         fprintf(fout, "\n");
@@ -15441,7 +17079,6 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
     // read file into data
     {
         FILE * fin = fopen(fname, "rb");
-
         if (!fin) {
             fprintf(stderr, "%s: failed to open %s\n", __func__, fname);
             return result;
@@ -15589,6 +17226,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                 op     = *(const uint32_t *) ptr; ptr += sizeof(op);
                 n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims);
 
+                enum ggml_op eop = (enum ggml_op) op;
+
                 int64_t ne[GGML_MAX_DIMS];
                 size_t  nb[GGML_MAX_DIMS];
 
@@ -15603,42 +17242,77 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                     nb[j] = nb_cur;
                 }
 
-                struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
-
-                tensor->op = (enum ggml_op) op;
+                uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur); // TODO: not yet used
 
-                uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur);
+                const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
 
-                memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME;
+                const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);
 
-                for (int j = 0; j < GGML_MAX_DIMS; ++j) {
-                    tensor->nb[j] = nb[j];
-                }
+                struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
 
                 // parse args
-                {
-                    struct ggml_tensor ** args[2 + GGML_MAX_OPT] = {
-                        &tensor->src0,
-                        &tensor->src1,
-                    };
+                for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
+                    const int32_t arg_idx = ptr_arg_idx[j];
 
-                    for (int j = 0; j < GGML_MAX_OPT; ++j) {
-                        args[2 + j] = &tensor->opt[j];
+                    if (arg_idx == -1) {
+                        continue;
                     }
 
-                    for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
-                        const int32_t arg_idx = *(const int32_t *) ptr; ptr += sizeof(arg_idx);
+                    if (arg_idx < GGML_MAX_NODES) {
+                        args[j] = result.leafs[arg_idx];
+                    } else {
+                        args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
+                    }
+                }
 
-                        if (arg_idx == -1) {
-                            continue;
-                        }
+                // create the tensor
+                // "view" operations are handled differently
+                // TODO: handle inplace ops - currently a copy is always made
+
+                struct ggml_tensor * tensor = NULL;
+
+                switch (eop) {
+                    // TODO: implement other view ops
+                    case GGML_OP_RESHAPE:
+                        {
+                            tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]);
+                        } break;
+                    case GGML_OP_VIEW:
+                        {
+                            tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
+
+                            uint64_t offs;
+                            memcpy(&offs, args[2]->data, sizeof(offs));
+
+                            tensor->data = ((char *) tensor->data) + offs;
+                        } break;
+                    case GGML_OP_TRANSPOSE:
+                        {
+                            tensor = ggml_transpose(*ctx_eval, args[0]);
+                        } break;
+                    case GGML_OP_PERMUTE:
+                        {
+                            tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
+                        } break;
+                    default:
+                        {
+                            tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
+
+                            tensor->op = eop;
+                        } break;
+                }
 
-                        if (arg_idx < GGML_MAX_NODES) {
-                            *args[j] = result.leafs[arg_idx];
-                        } else {
-                            *args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
-                        }
-                    }
+                memcpy(tensor->name, ptr_name, GGML_MAX_NAME);
+
+                for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+                    tensor->nb[j] = nb[j];
+                }
+
+                tensor->src0 = args[0];
+                tensor->src1 = args[1];
+
+                for (int j = 0; j < GGML_MAX_OPT; ++j) {
+                    tensor->opt[j] = args[2 + j];
                 }
 
                 result.nodes[i] = tensor;
@@ -15665,7 +17339,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
 
         perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);
 
-        GGML_PRINT(" - %3d: [ %5jd, %5jd, %5jd] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
+        GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
                 i,
                 node->ne[0], node->ne[1], node->ne[2],
                 GGML_OP_NAME[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
@@ -15679,7 +17353,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
     for (int i = 0; i < cgraph->n_leafs; i++) {
         struct ggml_tensor * node = cgraph->leafs[i];
 
-        GGML_PRINT(" - %3d: [ %5jd, %5jd] %8s\n",
+        GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
                 i,
                 node->ne[0], node->ne[1],
                 GGML_OP_NAME[node->op]);
@@ -15762,9 +17436,9 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
         }
 
         if (node->n_dims == 2) {
-            fprintf(fp, "%d [%jd, %jd] | <x>%s", i, node->ne[0], node->ne[1], GGML_OP_SYMBOL[node->op]);
+            fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], GGML_OP_SYMBOL[node->op]);
         } else {
-            fprintf(fp, "%d [%jd, %jd, %jd] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], GGML_OP_SYMBOL[node->op]);
+            fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], GGML_OP_SYMBOL[node->op]);
         }
 
 
@@ -15797,7 +17471,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
             }
         }
         else {
-            fprintf(fp, "CONST %d [%jd, %jd]", i, node->ne[0], node->ne[1]);
+            fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
         }
         fprintf(fp, "\"; ]\n");
     }
@@ -15898,6 +17572,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
 
 static enum ggml_opt_result ggml_opt_adam(
         struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
@@ -15923,25 +17598,29 @@ static enum ggml_opt_result ggml_opt_adam(
         }
     }
 
+    if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
+        int iter = opt->iter;
+        ggml_opt_init(opt->ctx, opt, params, nx);
+        opt->iter = iter;
+    }
+
     // constants
-    const float alpha = params.adam.alpha;
+    const float sched = params.adam.sched;
+    const float decay = params.adam.decay * sched;
+    const float alpha = params.adam.alpha * sched;
     const float beta1 = params.adam.beta1;
     const float beta2 = params.adam.beta2;
     const float eps   = params.adam.eps;
 
-    float * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters
-    float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient
-    float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared
-    float * m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment
-    float * v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment
-    float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat
-    float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat
-
-    float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+    float * x  = opt->adam.x->data;  // view of the parameters
+    float * g1 = opt->adam.g1->data; // gradient
+    float * g2 = opt->adam.g2->data; // gradient squared
+    float * m  = opt->adam.m->data;  // first moment
+    float * v  = opt->adam.v->data;  // second moment
+    float * mh = opt->adam.mh->data; // first moment hat
+    float * vh = opt->adam.vh->data; // second moment hat
 
-    // initialize
-    ggml_vec_set_f32(nx, m, 0.0f);
-    ggml_vec_set_f32(nx, v, 0.0f);
+    float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
 
     // update view
     ggml_opt_get_params(np, ps, x);
@@ -15951,16 +17630,27 @@ static enum ggml_opt_result ggml_opt_adam(
     ggml_set_f32      (f->grad, 1.0f);
     ggml_graph_compute(ctx, gb);
 
-    float fx_prev = ggml_get_f32_1d(f, 0);
+    opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
+    opt->adam.fx_best = opt->adam.fx_prev;
     if (pf) {
-        pf[0] = fx_prev;
+        pf[opt->iter % params.past] = opt->adam.fx_prev;
+    }
+
+    // initialize
+    if (opt->just_initialized) {
+        opt->adam.n_no_improvement = 0;
+        opt->just_initialized = false;
     }
 
-    int n_no_improvement = 0;
-    float fx_best = fx_prev;
+    float * fx_best = &opt->adam.fx_best;
+    float * fx_prev = &opt->adam.fx_prev;
+    int * n_no_improvement = &opt->adam.n_no_improvement;
+
+    int iter0 = opt->iter;
 
     // run the optimizer
     for (int t = 0; t < params.adam.n_iter; ++t) {
+        opt->iter = iter0 + t + 1;
         GGML_PRINT_DEBUG  ("=== iter %d ===\n", t);
 
         GGML_PRINT_DEBUG  ("f      = %10.6f\n", ggml_get_f32_1d(f, 0));
@@ -15994,17 +17684,22 @@ static enum ggml_opt_result ggml_opt_adam(
 
             // m^hat = m_t / (1 - beta1^t)
             // v^hat = v_t / (1 - beta2^t)
-            // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
+            // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
+            // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
+            // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
+            // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
+            // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
             ggml_vec_cpy_f32  (nx, mh, m);
             ggml_vec_cpy_f32  (nx, vh, v);
 
-            ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1)));
-            ggml_vec_scale_f32(nx, vh,  1.0f/(1.0f - powf(beta2, t + 1)));
+            ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
+            ggml_vec_scale_f32(nx, vh,  1.0f/(1.0f - powf(beta2, opt->iter)));
 
             ggml_vec_sqrt_f32 (nx, vh, vh);
             ggml_vec_acc1_f32 (nx, vh, eps);
 
             ggml_vec_div_f32  (nx, mh, mh, vh);
+            ggml_vec_scale_f32(nx, x,  1.0f - decay);
             ggml_vec_sub_f32  (nx, x,  x,  mh);
 
             // update the parameters
@@ -16018,7 +17713,7 @@ static enum ggml_opt_result ggml_opt_adam(
         const float fx = ggml_get_f32_1d(f, 0);
 
         // check convergence
-        if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
+        if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
             GGML_PRINT_DEBUG("converged\n");
 
             return GGML_OPT_OK;
@@ -16027,32 +17722,32 @@ static enum ggml_opt_result ggml_opt_adam(
         // delta-based convergence test
         if (pf != NULL) {
             // need at least params.past iterations to start checking for convergence
-            if (params.past <= t) {
-                const float rate = (pf[t%params.past] - fx)/fx;
+            if (params.past <= iter0 + t) {
+                const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;
 
                 if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
 
-            pf[t%params.past] = fx;
+            pf[(iter0 + t)%params.past] = fx;
         }
 
         // check for improvement
         if (params.max_no_improvement > 0) {
-            if (fx_best > fx) {
-                fx_best = fx;
-                n_no_improvement = 0;
+            if (fx_best[0] > fx) {
+                fx_best[0] = fx;
+                n_no_improvement[0] = 0;
             } else {
-                ++n_no_improvement;
+                ++n_no_improvement[0];
 
-                if (n_no_improvement >= params.max_no_improvement) {
+                if (n_no_improvement[0] >= params.max_no_improvement) {
                     return GGML_OPT_OK;
                 }
             }
         }
 
-        fx_prev = fx;
+        fx_prev[0] = fx;
 
         {
             const int64_t t_end_cpu = ggml_cycles();
@@ -16191,6 +17886,7 @@ static enum ggml_opt_result linesearch_backtracking(
 
 static enum ggml_opt_result ggml_opt_lbfgs(
         struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
@@ -16223,31 +17919,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         }
     }
 
-    float * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters
-    float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters
-    float * g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient
-    float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient
-    float * d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction
+    if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
+        int iter = opt->iter;
+        ggml_opt_init(ctx, opt, params, nx);
+        opt->iter = iter;
+    }
+
+    float * x  = opt->lbfgs.x->data;  // current parameters
+    float * xp = opt->lbfgs.xp->data; // previous parameters
+    float * g  = opt->lbfgs.g->data;  // current gradient
+    float * gp = opt->lbfgs.gp->data; // previous gradient
+    float * d  = opt->lbfgs.d->data;  // search direction
 
-    float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+    float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
 
     float fx    = 0.0f; // cost function value
     float xnorm = 0.0f; // ||x||
     float gnorm = 0.0f; // ||g||
-    float step  = 0.0f;
 
     // initialize x from the graph nodes
     ggml_opt_get_params(np, ps, x);
 
     // the L-BFGS memory
-    struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m);
-
-    for (int i = 0; i < m; ++i) {
-        lm[i].alpha = 0.0f;
-        lm[i].ys    = 0.0f;
-        lm[i].s     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
-        lm[i].y     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
-    }
+    float * lm_alpha = opt->lbfgs.lmal->data;
+    float * lm_ys    = opt->lbfgs.lmys->data;
+    float * lm_s     = opt->lbfgs.lms->data;
+    float * lm_y     = opt->lbfgs.lmy->data;
 
     // evaluate the function value and its gradient
     {
@@ -16262,12 +17959,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         fx = ggml_get_f32_1d(f, 0);
     }
 
-    if (pf) {
-        pf[0] = fx;
-    }
-
-    float fx_best = fx;
-
     // search direction = -gradient
     ggml_vec_neg_f32(nx, d, g);
 
@@ -16284,26 +17975,43 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         return GGML_OPT_OK;
     }
 
-    // initial step
-    ggml_vec_norm_inv_f32(nx, &step, d);
+    if (opt->just_initialized) {
+        if (pf) {
+            pf[0] = fx;
+        }
+        opt->lbfgs.fx_best = fx;
+
+        // initial step
+        ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
+        opt->lbfgs.j                = 0;
+        opt->lbfgs.k                = 1;
+        opt->lbfgs.end              = 0;
+        opt->lbfgs.n_no_improvement = 0;
+        opt->just_initialized       = false;
+    }
+
+    float * fx_best        = &opt->lbfgs.fx_best;
+    float * step           = &opt->lbfgs.step;
+    int * j                = &opt->lbfgs.j;
+    int * k                = &opt->lbfgs.k;
+    int * end              = &opt->lbfgs.end;
+    int * n_no_improvement = &opt->lbfgs.n_no_improvement;
 
-    int j                = 0;
-    int k                = 1;
-    int ls               = 0;
-    int end              = 0;
-    int bound            = 0;
-    int n_no_improvement = 0;
+    int ls     = 0;
+    int bound  = 0;
 
     float ys   = 0.0f;
     float yy   = 0.0f;
     float beta = 0.0f;
 
+    int it = 0;
+
     while (true) {
         // store the current position and gradient vectors
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);
 
-        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps);
+        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
 
         if (ls < 0) {
             // linesearch failed - go back to the previous point and return
@@ -16329,32 +18037,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         // delta-based convergence test
         if (pf != NULL) {
             // need at least params.past iterations to start checking for convergence
-            if (params.past <= k) {
-                const float rate = (pf[k%params.past] - fx)/fx;
+            if (params.past <= k[0]) {
+                const float rate = (pf[k[0]%params.past] - fx)/fx;
 
                 if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
 
-            pf[k%params.past] = fx;
+            pf[k[0]%params.past] = fx;
         }
 
         // check for improvement
         if (params.max_no_improvement > 0) {
-            if (fx < fx_best) {
-                fx_best = fx;
-                n_no_improvement = 0;
+            if (fx < fx_best[0]) {
+                fx_best[0] = fx;
+                n_no_improvement[0] = 0;
             } else {
-                n_no_improvement++;
+                n_no_improvement[0]++;
 
-                if (n_no_improvement >= params.max_no_improvement) {
+                if (n_no_improvement[0] >= params.max_no_improvement) {
                     return GGML_OPT_OK;
                 }
             }
         }
 
-        if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) {
+        if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
             // reached the maximum number of iterations
             return GGML_OPT_DID_NOT_CONVERGE;
         }
@@ -16363,50 +18071,51 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         //   s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
         //   y_{k+1} = g_{k+1} - g_{k}.
         //
-        ggml_vec_sub_f32(nx, lm[end].s, x, xp);
-        ggml_vec_sub_f32(nx, lm[end].y, g, gp);
+        ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
+        ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);
 
         // compute scalars ys and yy:
         //     ys = y^t \cdot s    -> 1 / \rho.
         //     yy = y^t \cdot y.
         //
-        ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s);
-        ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y);
+        ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]);
+        ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
 
-        lm[end].ys = ys;
+        lm_ys[end[0]] = ys;
 
         // find new search direction
         //   ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
 
-        bound = (m <= k) ? m : k;
-        k++;
-        end = (end + 1)%m;
+        bound = (m <= k[0]) ? m : k[0];
+        k[0]++;
+        it++;
+        end[0] = (end[0] + 1)%m;
 
         // initialize search direction with -g
         ggml_vec_neg_f32(nx, d, g);
 
-        j = end;
+        j[0] = end[0];
         for (int i = 0; i < bound; ++i) {
-            j = (j + m - 1) % m;
+            j[0] = (j[0] + m - 1) % m;
             // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
-            ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d);
-            lm[j].alpha /= lm[j].ys;
+            ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
+            lm_alpha[j[0]] /= lm_ys[j[0]];
             // q_{i} = q_{i+1} - \alpha_{i} y_{i}
-            ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha);
+            ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
         }
 
         ggml_vec_scale_f32(nx, d, ys/yy);
 
         for (int i = 0; i < bound; ++i) {
             // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
-            ggml_vec_dot_f32(nx, &beta, lm[j].y, d);
-            beta /= lm[j].ys;
+            ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
+            beta /= lm_ys[j[0]];
             // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
-            ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta);
-            j = (j + 1)%m;
+            ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
+            j[0] = (j[0] + 1)%m;
         }
 
-        step = 1.0;
+        step[0] = 1.0;
     }
 
     return GGML_OPT_DID_NOT_CONVERGE;
@@ -16431,6 +18140,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
 
                     .adam = {
                         .n_iter = 10000,
+                        .sched  = 1.000f,
+                        .decay  = 0.001f,
                         .alpha  = 0.001f,
                         .beta1  = 0.9f,
                         .beta2  = 0.999f,
@@ -16473,6 +18184,71 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
     return result;
 }
 
+GGML_API void ggml_opt_init(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_opt_params params,
+        int64_t nx) {
+    opt->ctx = ctx;
+    opt->params = params;
+    opt->iter = 0;
+    opt->nx = nx;
+    opt->just_initialized = true;
+    switch (opt->params.type) {
+        case GGML_OPT_ADAM:
+            {
+                opt->adam.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.pf = params.past > 0
+                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    : NULL;
+                ggml_set_zero(opt->adam.x);
+                ggml_set_zero(opt->adam.g1);
+                ggml_set_zero(opt->adam.g2);
+                ggml_set_zero(opt->adam.m);
+                ggml_set_zero(opt->adam.v);
+                ggml_set_zero(opt->adam.mh);
+                ggml_set_zero(opt->adam.vh);
+                if (opt->adam.pf) {
+                    ggml_set_zero(opt->adam.pf);
+                }
+            } break;
+        case GGML_OPT_LBFGS:
+            {
+                opt->lbfgs.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.pf = params.past > 0
+                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    : NULL;
+                opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lms  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                opt->lbfgs.lmy  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                ggml_set_zero(opt->lbfgs.x);
+                ggml_set_zero(opt->lbfgs.xp);
+                ggml_set_zero(opt->lbfgs.g);
+                ggml_set_zero(opt->lbfgs.gp);
+                ggml_set_zero(opt->lbfgs.d);
+                ggml_set_zero(opt->lbfgs.pf);
+                if (opt->lbfgs.pf) {
+                    ggml_set_zero(opt->lbfgs.pf);
+                }
+                ggml_set_zero(opt->lbfgs.lmal);
+                ggml_set_zero(opt->lbfgs.lmys);
+                ggml_set_zero(opt->lbfgs.lms);
+                ggml_set_zero(opt->lbfgs.lmy);
+            } break;
+    }
+}
+
 enum ggml_opt_result ggml_opt(
         struct ggml_context * ctx,
         struct ggml_opt_params params,
@@ -16495,33 +18271,65 @@ enum ggml_opt_result ggml_opt(
 
     enum ggml_opt_result result = GGML_OPT_OK;
 
+    struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
+
+    ggml_opt_init(ctx, opt, params, 0);
+    result = ggml_opt_resume(ctx, opt, f);
+
+    if (free_ctx) {
+        ggml_free(ctx);
+    }
+
+    return result;
+}
+
+enum ggml_opt_result ggml_opt_resume(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_tensor * f) {
+
+    // build forward + backward compute graphs
+    struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+    struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+
+    struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
+    struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
+
+    *gf = ggml_build_forward (f);
+    *gb = ggml_build_backward(ctx, gf, true);
+
+    return ggml_opt_resume_g(ctx, opt, f, gf, gb);
+}
+
+enum ggml_opt_result ggml_opt_resume_g(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_tensor * f,
+        struct ggml_cgraph * gf,
+        struct ggml_cgraph * gb) {
+
     // build forward + backward compute graphs
-    struct ggml_cgraph gf = ggml_build_forward (f);
-    struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true);
+    enum ggml_opt_result result = GGML_OPT_OK;
 
-    switch (params.type) {
+    switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
-                result = ggml_opt_adam(ctx, params, f, &gf, &gb);
+                result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
             } break;
         case GGML_OPT_LBFGS:
             {
-                result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb);
+                result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
             } break;
     }
 
-    if (params.print_forward_graph) {
-        ggml_graph_print   (&gf);
-        ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot");
-    }
-
-    if (params.print_backward_graph) {
-        ggml_graph_print   (&gb);
-        ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
+    if (opt->params.print_forward_graph) {
+        ggml_graph_print   (gf);
+        ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
     }
 
-    if (free_ctx) {
-        ggml_free(ctx);
+    if (opt->params.print_backward_graph) {
+        ggml_graph_print   (gb);
+        ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
     }
 
     return result;
@@ -16689,6 +18497,50 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
                 result = ggml_quantize_q8_0(src + start, block, n, n, hist);
             } break;
+#ifdef GGML_USE_K_QUANTS
+        case GGML_TYPE_Q2_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q2_K * block = (block_q2_K*)dst + start / QK_K;
+                result = ggml_quantize_q2_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q3_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q3_K * block = (block_q3_K*)dst + start / QK_K;
+                result = ggml_quantize_q3_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q4_K * block = (block_q4_K*)dst + start / QK_K;
+                result = ggml_quantize_q4_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q5_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q5_K * block = (block_q5_K*)dst + start / QK_K;
+                result = ggml_quantize_q5_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q6_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q6_K * block = (block_q6_K*)dst + start / QK_K;
+                result = ggml_quantize_q6_K(src + start, block, n, n, hist);
+            } break;
+#endif
+        case GGML_TYPE_F16:
+            {
+                int elemsize = sizeof(ggml_fp16_t);
+                ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
+                result = n * elemsize;
+            } break;
+        case GGML_TYPE_F32:
+            {
+                int elemsize = sizeof(float);
+                result = n * elemsize;
+                memcpy((uint8_t *)dst + start * elemsize, src + start, result);
+            } break;
         default:
             assert(false);
     }