]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp (memory allocator + cuda + metal)
authorGeorgi Gerganov <redacted>
Mon, 7 Aug 2023 09:09:58 +0000 (12:09 +0300)
committerGeorgi Gerganov <redacted>
Mon, 7 Aug 2023 11:23:05 +0000 (14:23 +0300)
ggml-ci

13 files changed:
examples/mnist/main-cpu.cpp
include/ggml/ggml.h
scripts/sync-llama.sh
src/ggml-cuda.cu
src/ggml-cuda.h
src/ggml-metal.m
src/ggml-metal.metal
src/ggml.c
tests/CMakeLists.txt
tests/test-grad0.c [deleted file]
tests/test-grad0.cpp [new file with mode: 0644]
tests/test-opt.c [deleted file]
tests/test-opt.cpp [new file with mode: 0644]

index 3e8bfe674f05e54da6b2e70d9b80c58dfaa3af0e..ba0c313677e4dced2e9f8b445c0eb68004274780 100644 (file)
@@ -34,8 +34,7 @@
 int mnist_eval(
         const char * fname_cgraph,
         const int n_threads,
-        std::vector<float> digit
-        ) {
+        std::vector<float> digit) {
     // load the compute graph
     struct ggml_context * ctx_data = NULL;
     struct ggml_context * ctx_eval = NULL;
index c2c2155f8667581917c5000e19fbe97310f024d4..bdbd12800433242dec1e6dac8c96875540dd19e7 100644 (file)
 
 #define GGML_UNUSED(x) (void)(x)
 
+#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
 
 #define GGML_ASSERT(x) \
     do { \
@@ -409,6 +410,12 @@ extern "C" {
         GGML_UNARY_OP_SILU,
     };
 
+    enum ggml_object_type {
+        GGML_OBJECT_TENSOR,
+        GGML_OBJECT_GRAPH,
+        GGML_OBJECT_WORK_BUFFER
+    };
+
     // ggml object
     struct ggml_object {
         size_t offs;
@@ -416,7 +423,9 @@ extern "C" {
 
         struct ggml_object * next;
 
-        char padding[8];
+        enum ggml_object_type type;
+
+        char padding[4];
     };
 
     static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
@@ -437,7 +446,7 @@ extern "C" {
         enum ggml_op op;
 
         // op params - allocated as int32_t for alignment
-        int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(uint32_t)];
+        int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 
         bool is_param;
 
@@ -498,6 +507,8 @@ extern "C" {
         int64_t perf_time_us;
     };
 
+    static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph);
+
     // scratch buffer
     struct ggml_scratch {
         size_t offs;
@@ -1174,7 +1185,18 @@ extern "C" {
             int                   mode,
             int                   n_ctx);
 
-    // custom RoPE, in-place, returns view(a)
+    // custom RoPE
+    GGML_API struct ggml_tensor * ggml_rope_custom(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx,
+            float                 freq_base,
+            float                 freq_scale);
+
+    // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1233,7 +1255,7 @@ extern "C" {
 
     // conv_1d with padding = half
     // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
-    GGML_API struct ggml_tensor* ggml_conv_1d_ph(
+    GGML_API struct ggml_tensor * ggml_conv_1d_ph(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
@@ -1246,7 +1268,7 @@ extern "C" {
         GGML_OP_POOL_COUNT,
     };
 
-    GGML_API struct ggml_tensor* ggml_pool_1d(
+    GGML_API struct ggml_tensor * ggml_pool_1d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             enum ggml_op_pool     op,
@@ -1254,7 +1276,7 @@ extern "C" {
             int                   s0, // stride
             int                   p0); // padding
 
-    GGML_API struct ggml_tensor* ggml_pool_2d(
+    GGML_API struct ggml_tensor * ggml_pool_2d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             enum ggml_op_pool     op,
@@ -1472,11 +1494,17 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * tensor);
 
+
     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
 
     GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
     GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
 
+    // graph allocation in a context
+    GGML_API struct ggml_cgraph * ggml_new_graph        (struct ggml_context * ctx);
+    GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API size_t ggml_graph_overhead(void);
+
     // ggml_graph_plan() has to be called before ggml_graph_compute()
     // when plan.work_size > 0, caller must allocate memory for plan.work_data
     GGML_API struct ggml_cplan ggml_graph_plan   (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
index 0c6c424ede3ceb46dbcc4d9ab6bbf79ad4787c00..22d5fbecd24137dca4e43d38b05f66cc69f222ee 100755 (executable)
@@ -10,7 +10,7 @@ cp -rpv ../llama.cpp/ggml-metal.m     src/ggml-metal.m
 cp -rpv ../llama.cpp/ggml-metal.metal src/ggml-metal.metal
 cp -rpv ../llama.cpp/ggml.h           include/ggml/ggml.h
 
-cp -rpv ../llama.cpp/tests/test-opt.c             tests/test-opt.c
-cp -rpv ../llama.cpp/tests/test-grad0.c           tests/test-grad0.c
+cp -rpv ../llama.cpp/tests/test-opt.cpp           tests/test-opt.cpp
+cp -rpv ../llama.cpp/tests/test-grad0.cpp         tests/test-grad0.cpp
 cp -rpv ../llama.cpp/tests/test-quantize-fns.cpp  tests/test-quantize-fns.cpp
 cp -rpv ../llama.cpp/tests/test-quantize-perf.cpp tests/test-quantize-perf.cpp
index d31fc79c10961de17704b317b7489f7a0c5e1c9e..9d42efb0d0b03a98b11c16f19260a10c8b9009da 100644 (file)
@@ -52,13 +52,41 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     } while (0)
 #endif // CUDART_VERSION >= 11
 
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
 typedef half dfloat; // dequantize float
 typedef half2 dfloat2;
 #else
 typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
-#endif //GGML_CUDA_DMMV_F16
+#endif //GGML_CUDA_F16
+
+static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
+    const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+
+    int x32 = 0;
+    x32 |= x16[0] <<  0;
+    x32 |= x16[1] << 16;
+
+    return x32;
+}
+
+static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
+    const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+
+    int x32 = 0;
+    x32 |= x16[0] <<  0;
+    x32 |= x16[1] << 16;
+
+    return x32;
+}
+
+static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
+    return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+}
+
+static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
+    return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+}
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
 typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
@@ -87,8 +115,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0
 #define QR4_1 2
 #define QI4_1 (QK4_1 / (4 * QR4_1))
 typedef struct {
-    half    d;              // delta
-    half    m;              // min
+    half2   dm;             // dm.x = delta, dm.y = min
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
@@ -107,8 +134,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5
 #define QR5_1 2
 #define QI5_1 (QK5_1 / (4 * QR5_1))
 typedef struct {
-    half d;                 // delta
-    half m;                 // min
+    half2 dm;               // dm.x = delta, dm.y = min
     uint8_t qh[4];          // 5-th bit of quants
     uint8_t qs[QK5_1 / 2];  // nibbles / quants
 } block_q5_1;
@@ -127,13 +153,19 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
 #define QR8_1 1
 #define QI8_1 (QK8_1 / (4 * QR8_1))
 typedef struct {
-    half    d;              // delta
-    half    s;              // unquantized sum
+    half2   ds;             // ds.x = delta, ds.y = sum
     int8_t  qs[QK8_0];      // quants
 } block_q8_1;
 static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
 
-typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs);
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
+typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
+typedef void (*load_tiles_cuda_t)(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row);
+typedef float (*vec_dot_q_mul_mat_cuda_t)(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k);
 
 //================================= k-quants
 
@@ -150,8 +182,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
     uint8_t qs[QK_K/4];      // quants
-    half d;                  // super-block scale for quantized scales
-    half dmin;               // super-block scale for quantized mins
+    half2 dm;                // super-block scale for quantized scales/mins
 } block_q2_K;
 static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
 
@@ -180,8 +211,7 @@ typedef struct {
 static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
 #else
 typedef struct {
-    half d;                    // super-block scale for quantized scales
-    half dmin;                 // super-block scale for quantized mins
+    half2 dm;                  // super-block scale for quantized scales/mins
     uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_K;
@@ -200,11 +230,10 @@ typedef struct {
 static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
 #else
 typedef struct {
-    half d;               // super-block scale for quantized scales
-    half dmin;            // super-block scale for quantized mins
-    uint8_t scales[K_SCALE_SIZE];   // scales and mins, quantized with 6 bits
-    uint8_t qh[QK_K/8];          // quants, high bit
-    uint8_t qs[QK_K/2];          // quants, low 4 bits
+    half2 dm;                     // super-block scale for quantized scales/mins
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+    uint8_t qh[QK_K/8];           // quants, high bit
+    uint8_t qs[QK_K/2];           // quants, low 4 bits
 } block_q5_K;
 static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
 #endif
@@ -233,6 +262,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 
+#ifndef GGML_CUDA_MMQ_Y
+#define GGML_CUDA_MMQ_Y 64
+#endif // GGML_CUDA_MMQ_Y
+
 // dmmv = dequantize_mul_mat_vec
 #ifndef GGML_CUDA_DMMV_X
 #define GGML_CUDA_DMMV_X 32
@@ -367,33 +400,33 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
     v.x = vui & 0xF;
     v.y = vui >> 4;
 
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
     v = __hsub2(v, {8.0f, 8.0f});
     v = __hmul2(v, {d, d});
 #else
     v.x = (v.x - 8.0f) * d;
     v.y = (v.y - 8.0f) * d;
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
 }
 
 static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
-    const dfloat d = x[ib].d;
-    const dfloat m = x[ib].m;
+    const dfloat d = x[ib].dm.x;
+    const dfloat m = x[ib].dm.y;
 
     const int vui = x[ib].qs[iqs];
 
     v.x = vui & 0xF;
     v.y = vui >> 4;
 
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
     v = __hmul2(v, {d, d});
     v = __hadd2(v, {m, m});
 #else
     v.x = (v.x * d) + m;
     v.y = (v.y * d) + m;
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
 }
 
 static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
@@ -410,20 +443,20 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
     v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
     v.y = ((x[ib].qs[iqs] >>  4) | xh_1);
 
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
     v = __hsub2(v, {16.0f, 16.0f});
     v = __hmul2(v, {d, d});
 #else
     v.x = (v.x - 16.0f) * d;
     v.y = (v.y - 16.0f) * d;
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
 }
 
 static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
-    const dfloat d = x[ib].d;
-    const dfloat m = x[ib].m;
+    const dfloat d = x[ib].dm.x;
+    const dfloat m = x[ib].dm.y;
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -434,13 +467,13 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
     v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
     v.y = ((x[ib].qs[iqs] >>  4) | xh_1);
 
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
     v = __hmul2(v, {d, d});
     v = __hadd2(v, {m, m});
 #else
     v.x = (v.x * d) + m;
     v.y = (v.y * d) + m;
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
 }
 
 static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
@@ -451,12 +484,12 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
     v.x = x[ib].qs[iqs + 0];
     v.y = x[ib].qs[iqs + 1];
 
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
     v = __hmul2(v, {d, d});
 #else
     v.x *= d;
     v.y *= d;
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
 }
 
 //================================== k-quants
@@ -475,8 +508,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
     const uint8_t q = x[i].qs[32*n + l];
     float * y = yy + i*QK_K + 128*n;
 
-    float dall = x[i].d;
-    float dmin = x[i].dmin;
+    float dall = x[i].dm.x;
+    float dmin = x[i].dm.y;
     y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
     y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
@@ -486,8 +519,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
     const int il = tid%16;  // 0...15
     const uint8_t q = x[i].qs[il] >> (2*is);
     float * y = yy + i*QK_K + 16*is + il;
-    float dall = x[i].d;
-    float dmin = x[i].dmin;
+    float dall = x[i].dm.x;
+    float dmin = x[i].dm.y;
     y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
     y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
 #endif
@@ -573,8 +606,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
 
     float * y = yy + i*QK_K + 64*il + n*ir;
 
-    const float dall = x[i].d;
-    const float dmin = x[i].dmin;
+    const float dall = x[i].dm.x;
+    const float dmin = x[i].dm.y;
 
     const uint8_t * q = x[i].qs + 32*il + n*ir;
 
@@ -612,8 +645,8 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
 
     float * y = yy + i*QK_K + 64*il + 2*ir;
 
-    const float dall = x[i].d;
-    const float dmin = x[i].dmin;
+    const float dall = x[i].dm.x;
+    const float dmin = x[i].dm.y;
 
     const uint8_t * ql = x[i].qs + 32*il + 2*ir;
     const uint8_t * qh = x[i].qh + 2*ir;
@@ -725,8 +758,8 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
         const float   * y = yy + i * QK_K + y_offset;
         const uint8_t * q = x[i].qs + q_offset;
 
-        const float dall = x[i].d;
-        const float dmin = x[i].dmin;
+        const float dall = x[i].dm.x;
+        const float dmin = x[i].dm.y;
 
         const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
         aux[0] = a[0] & 0x0f0f0f0f;
@@ -768,9 +801,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
         uaux[0] = s[0] & 0x0f0f0f0f;
         uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
 
-        const half2 * dh = (const half2 *)&x[i].d;
-
-        const float2 dall = __half22float2(dh[0]);
+        const float2 dall = __half22float2(x[i].dm);
 
         float sum1 = 0, sum2 = 0;
         for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
@@ -948,8 +979,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
         const float   * y1 = yy + i*QK_K + y_offset;
         const float   * y2 = y1 + 128;
 
-        const float dall = x[i].d;
-        const float dmin = x[i].dmin;
+        const float dall = x[i].dm.x;
+        const float dmin = x[i].dm.y;
 
         const uint16_t * a = (const uint16_t *)x[i].scales;
         aux[0] = a[im+0] & kmask1;
@@ -1081,8 +1112,8 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
         const float   * y1  = yy + i*QK_K + y_offset;
         const float   * y2  = y1 + 128;
 
-        const float dall = x[i].d;
-        const float dmin = x[i].dmin;
+        const float dall = x[i].dm.x;
+        const float dmin = x[i].dm.y;
 
         const uint16_t * a = (const uint16_t *)x[i].scales;
         aux[0] = a[im+0] & kmask1;
@@ -1270,19 +1301,23 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
     v.y = x[ib + iqs + 1];
 }
 
-static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int ndata, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
+    const int ix = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i >= k) {
+    if (ix >= kx_padded) {
         return;
     }
 
+    const int iy = blockDim.y*blockIdx.y + threadIdx.y;
+
+    const int i_padded = iy*kx_padded + ix;
+
     block_q8_1 * y = (block_q8_1 *) vy;
 
-    const int ib = i / QK8_1; // block index
-    const int iqs = i % QK8_1; // quant index
+    const int ib = i_padded / QK8_1; // block index
+    const int iqs = i_padded % QK8_1; // quant index
 
-    const float xi = i < ndata ? x[i] : 0.0f;
+    const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
     float amax = fabsf(xi);
     float sum = xi;
 
@@ -1301,8 +1336,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
         return;
     }
 
-    y[ib].d = d;
-    y[ib].s = sum;
+    y[ib].ds.x = d;
+    y[ib].ds.y = sum;
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -1326,485 +1361,1963 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __
     y[iybs + iqs + y_offset] = v.y;
 }
 
-static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
+// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
+// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
+
+#define VDR_Q4_0_Q8_1_MMVQ 2
+#define VDR_Q4_0_Q8_1_MMQ  4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
+    const int * v, const int * u, const float & d4, const half2 & ds8) {
 
-    int vi;
-    memcpy(&vi,  &bq4_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
-    const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
-    const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]);
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    int sumi = 0;
 
-    const float d = __half2float(bq4_0->d) * __half2float(bq8_1->d);
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
 
-    // subtract 8 from each quantized value
-    const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
-    const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808);
+        // SIMD dot product of quantized values
+        sumi = __dp4a(vi0, u[2*i+0], sumi);
+        sumi = __dp4a(vi1, u[2*i+1], sumi);
+    }
 
-    // SIMD dot product of quantized values
-    int sumi = __dp4a(vi0, ui0, 0);
-    sumi     = __dp4a(vi1, ui1, sumi);
+    const float2 ds8f = __half22float2(ds8);
 
-    return sumi*d;
+    // second part effectively subtracts 8 from each quant value
+    return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+#define VDR_Q4_1_Q8_1_MMVQ 2
+#define VDR_Q4_1_Q8_1_MMQ  4
 
-    const int vi  = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
-    const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
-    const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]);
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
+    const int * v, const int * u, const half2 & dm4, const half2 & ds8) {
 
-    const float d = __half2float(bq4_1->d) * __half2float(bq8_1->d);
-    const float m = bq4_1->m;
-    const float s = bq8_1->s;
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    int sumi = 0;
 
-    const int vi0 = (vi >> 0) & 0x0F0F0F0F;
-    const int vi1 = (vi >> 4) & 0x0F0F0F0F;
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
 
-    // SIMD dot product of quantized values
-    int sumi = __dp4a(vi0, ui0, 0);
-    sumi     = __dp4a(vi1, ui1, sumi);
+        // SIMD dot product of quantized values
+        sumi = __dp4a(vi0, u[2*i+0], sumi);
+        sumi = __dp4a(vi1, u[2*i+1], sumi);
+    }
 
-    return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block
+#ifdef GGML_CUDA_F16
+    const float2 tmp = __half22float2(__hmul2(dm4, ds8));
+    const float d4d8 = tmp.x;
+    const float m4s8 = tmp.y;
+#else
+    const float2 dm4f = __half22float2(dm4);
+    const float2 ds8f = __half22float2(ds8);
+    const float d4d8 = dm4f.x * ds8f.x;
+    const float m4s8 = dm4f.y * ds8f.y;
+#endif // GGML_CUDA_F16
+
+    // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
+    return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#define VDR_Q5_0_Q8_1_MMVQ 2
+#define VDR_Q5_0_Q8_1_MMQ  4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
+    const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) {
+
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
+        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+        sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
+
+        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
+        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12
+        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20
+        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28
+        sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
+    }
+
+    const float2 ds8f = __half22float2(ds8);
 
-    int qs;
-    memcpy(&qs, &bq5_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
-    const int qh0 = bq5_0->qh[iqs/2 + 0] >> 4*(iqs%2);
-    const int qh1 = bq5_0->qh[iqs/2 + 2] >> 4*(iqs%2);
-    const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
-    const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_0)]);
-
-    const float d = __half2float(bq5_0->d) * __half2float(bq8_1->d);
-
-    int vi0 = (qs  >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
-    vi0    |= (qh0 <<  4) & 0x00000010; // 1 ->  5
-    vi0    |= (qh0 << 11) & 0x00001000; // 2 -> 13
-    vi0    |= (qh0 << 18) & 0x00100000; // 3 -> 21
-    vi0    |= (qh0 << 25) & 0x10000000; // 4 -> 29
-    vi0     = __vsub4(vi0,  0x10101010); // subtract 16 from quantized values
-    int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
-
-    int vi1 = (qs  >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
-    vi1    |= (qh1 <<  4) & 0x00000010; // 1 ->  5
-    vi1    |= (qh1 << 11) & 0x00001000; // 2 -> 13
-    vi1    |= (qh1 << 18) & 0x00100000; // 3 -> 21
-    vi1    |= (qh1 << 25) & 0x10000000; // 4 -> 29
-    vi1     = __vsub4(vi1,  0x10101010); // subtract 16 from quantized values
-    sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
-
-    return sumi*d;
+    // second part effectively subtracts 16 from each quant value
+    return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#define VDR_Q5_1_Q8_1_MMVQ 2
+#define VDR_Q5_1_Q8_1_MMQ  4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
+    const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) {
+
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
+        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+        sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
+
+        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
+        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12
+        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20
+        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28
+        sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
+    }
+
+#ifdef GGML_CUDA_F16
+    const float2 tmp = __half22float2(__hmul2(dm5, ds8));
+    const float d5d8 = tmp.x;
+    const float m5s8 = tmp.y;
+#else
+    const float2 dm5f = __half22float2(dm5);
+    const float2 ds8f = __half22float2(ds8);
+    const float d5d8 = dm5f.x * ds8f.x;
+    const float m5s8 = dm5f.y * ds8f.y;
+#endif // GGML_CUDA_F16
+
+    // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
+    return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
 
-    const int qs  = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]);
-    const int qh0 = bq5_1->qh[iqs/2 + 0] >> 4*(iqs%2);
-    const int qh1 = bq5_1->qh[iqs/2 + 2] >> 4*(iqs%2);
-    const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
-    const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]);
-
-    const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d);
-    const float m = bq5_1->m;
-    const float s = bq8_1->s;
-
-    int vi0 = (qs  >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
-    vi0    |= (qh0 <<  4) & 0x00000010; // 1 ->  5
-    vi0    |= (qh0 << 11) & 0x00001000; // 2 -> 13
-    vi0    |= (qh0 << 18) & 0x00100000; // 3 -> 21
-    vi0    |= (qh0 << 25) & 0x10000000; // 4 -> 29
-    int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
-
-    int vi1 = (qs  >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
-    vi1    |= (qh1 <<  4) & 0x00000010; // 1 ->  5
-    vi1    |= (qh1 << 11) & 0x00001000; // 2 -> 13
-    vi1    |= (qh1 << 18) & 0x00100000; // 3 -> 21
-    vi1    |= (qh1 << 25) & 0x10000000; // 4 -> 29
-    sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
-
-    return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
+#define VDR_Q8_0_Q8_1_MMVQ 2
+#define VDR_Q8_0_Q8_1_MMQ 8
 
-    int vi;
-    memcpy(&vi,  &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
-    const int ui = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
+    const int * v, const int * u, const float & d8_0, const float & d8_1) {
 
-    const float d = __half2float(bq8_0->d) * __half2float(bq8_1->d);
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    int sumi = 0;
 
-    // SIMD dot product of quantized values
-    int sumi = __dp4a(vi, ui, 0);
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        // SIMD dot product of quantized values
+        sumi = __dp4a(v[i], u[i], sumi);
+    }
 
-    return sumi*d;
+    return d8_0*d8_1 * sumi;
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl(
+    const int * v, const int * u, const half2 & dm8, const half2 & ds8) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_q2_K * bq2_K = (const block_q2_K *) vbq;
+    int sumi = 0;
 
-    const int bq8_offset = QR2_K * (iqs / QI8_1);
-    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        // SIMD dot product of quantized values
+        sumi = __dp4a(v[i], u[i], sumi);
+    }
 
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
+#ifdef GGML_CUDA_F16
+    const float2 tmp = __half22float2(__hmul2(dm8, ds8));
+    const float d8d8 = tmp.x;
+    const float m8s8 = tmp.y;
+#else
+    const float2 dm8f = __half22float2(dm8);
+    const float2 ds8f = __half22float2(ds8);
+    const float d8d8 = dm8.x * ds8.x;
+    const float m8s8 = dm8.y * ds8.y;
+#endif // GGML_CUDA_F16
+
+    // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
+    return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
 
-    const float    d = bq2_K->d;
-    const float dmin = bq2_K->dmin;
+#define VDR_Q2_K_Q8_1_MMVQ 1
+#define VDR_Q2_K_Q8_1_MMQ  2
 
-    const int v = *((int *) &bq2_K->qs[sizeof(int) * iqs]);
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
+    const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+    const half2 & dm2, const float * __restrict__ d8) {
 
-    for (int i = 0; i < QR2_K; ++i) {
-        const int sc = bq2_K->scales[scale_offset + 2*i];
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
 
-        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        const float d8i = bq8i->d;
+#pragma unroll
+    for (int i = 0; i < QR2_K; ++i) {
+        const int sc = scales[2*i];
 
         const int vi = (v >> (2*i)) & 0x03030303;
-        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
 
-        sumf_d += d8i * (__dp4a(vi,         ui, 0) * (sc & 0xF)); // SIMD dot product
-        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * (sc >>  4)); // multiply constant q2_K part with sum of q8_1 values
+        sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
+
+        // fill int with 4x m
+        int m = sc >> 4;
+        m |= m <<  8;
+        m |= m << 16;
+        sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
     }
 
-    return d*sumf_d - dmin*sumf_m;
+    const float2 dm2f = __half22float2(dm2);
+
+    return dm2f.x*sumf_d - dm2f.y*sumf_m;
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+    const half2 & dm2, const float & d8) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+    int sumi_d = 0;
+    int sumi_m = 0;
 
-    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
-    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+#pragma unroll
+    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
+        int sumi_d_sc = 0;
 
-    float sumf = 0.0f;
+        const int sc = scales[i0 / (QI8_1/2)];
 
-    const float d = bq3_K->d;
+        // fill int with 4x m
+        int m = sc >> 4;
+        m |= m <<  8;
+        m |= m << 16;
+
+#pragma unroll
+        for (int i = i0; i < i0 + QI8_1/2; ++i) {
+            sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
+            sumi_m    = __dp4a(m,    u[i], sumi_m); // multiply sum of q8_1 values with m
+        }
+
+        sumi_d += sumi_d_sc * (sc & 0xF);
+    }
 
-    int vl;
-    memcpy(&vl, &bq3_K->qs[sizeof(int) * iqs], sizeof(int));
+    const float2 dm2f = __half22float2(dm2);
 
-    int vh;
-    memcpy(&vh, &bq3_K->hmask[sizeof(int) * (iqs % (QI3_K/2))], sizeof(int));
-    vh = ~vh; // invert the mask so that a 0/1 results in 4/0 being subtracted
-    vh >>= bq8_offset;
+    return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q3_K_Q8_1_MMVQ 1
+#define VDR_Q3_K_Q8_1_MMQ  2
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
+    const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+    const int & scale_offset, const float & d3, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf = 0.0f;
 
+#pragma unroll
     for (int i = 0; i < QR3_K; ++i) {
         const int isc = scale_offset + 2*i;
 
         const int isc_low = isc % (QK_K/32);
         const int sc_shift_low = 4 * (isc / (QK_K/32));
-        const int sc_low  = (bq3_K->scales[isc_low] >> sc_shift_low) & 0xF;
+        const int sc_low  = (scales[isc_low] >> sc_shift_low) & 0xF;
 
         const int isc_high = isc % (QK_K/64);
         const int sc_shift_high = 2 * (isc / (QK_K/64));
-        const int sc_high = ((bq3_K->scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+        const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
 
         const int sc = (sc_low | sc_high) - 32;
 
-        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
-        const float d8i = bq8i->d;
-
         const int vil = (vl >> (2*i)) & 0x03030303;
 
         const int vih = ((vh >> i) << 2) & 0x04040404;
 
         const int vi = __vsubss4(vil, vih);
 
-        sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
+        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
     }
 
-    return d*sumf;
+    return d3 * sumf;
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
+    const float & d3, const float & d8) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+    int sumi = 0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
+        int sumi_sc = 0;
+
+        for (int i = i0; i < i0 + QI8_1/2; ++i) {
+            sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product
+        }
+
+        sumi += sumi_sc * scales[i0 / (QI8_1/2)];
+    }
+
+    return d3*d8 * sumi;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q4_K_Q8_1_MMVQ 2
+#define VDR_Q4_K_Q8_1_MMQ  8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {
 
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     float sumf_d = 0.0f;
     float sumf_m = 0.0f;
 
-#ifndef GGML_QKK_64
+#pragma unroll
+    for (int i = 0; i < QR4_K; ++i) {
+        const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
+        const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
 
-    // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
-    const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
+        const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
+        const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u
 
-    const float    d = bq4_K->d;
-    const float dmin = bq4_K->dmin;
+        sumf_d += d8[i] * (dot1 * sc[i]);
+        sumf_m += d8[i] * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
+    }
 
-    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
-    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
-    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
-    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
+    const float2 dm4f = __half22float2(dm4);
 
-    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * (iqs%4));
-    const int v1 = q4[0];
-    const int v2 = q4[4];
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
 
-    const uint16_t * scales = (const uint16_t *)bq4_K->scales;
-    uint16_t aux[2];
-    const int j = bq8_offset/2;
-    if (j < 2) {
-        aux[0] = scales[j+0] & 0x3f3f;
-        aux[1] = scales[j+2] & 0x3f3f;
-    } else {
-        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
-        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+// contiguous u/y values
+// also used for q5_K
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i0 = 0; i0 < VDR_Q4_K_Q8_1_MMQ; i0 += (QI8_1/QR4_K)) {
+        int sumi_d = 0;
+
+#pragma unroll
+        for (int i = i0; i < i0 + (QI8_1/QR4_K); ++i) {
+            sumi_d = __dp4a(v[2*i+0], u[2*i+0], sumi_d); // SIMD dot product
+            sumi_d = __dp4a(v[2*i+1], u[2*i+1], sumi_d); // SIMD dot product
+        }
+
+        const float2 ds8f = __half22float2(ds8[i0 / 4]);
+
+        sumf_d += ds8f.x * (sc[i0/4] * sumi_d);
+        sumf_m += ds8f.y *   m[i0/4]; // sum of q8_1 block * q4_K min val
     }
-    const uint8_t * sc = (const uint8_t *)aux;
-    const uint8_t * m  = sc + 2;
 
-    for (int i = 0; i < QR4_K; ++i) {
+    const float2 dm4f = __half22float2(dm4);
+
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
+
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q5_K_Q8_1_MMVQ 2
+#define VDR_Q5_K_Q8_1_MMQ  8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl(
+    const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K; ++i) {
+        const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
+        const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
+
+        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
+        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
+
+        const int v0i = vl0i | vh0i;
+        const int v1i = vl1i | vh1i;
+
+        const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
+        const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u
+
+        sumf_d += d8[i] * (dot1 * sc[i]);
+        sumf_m += d8[i] * (dot2 * m[i]);
+
+    }
+
+    const float2 dm5f = __half22float2(dm5);
+
+    return dm5f.x*sumf_d - dm5f.y*sumf_m;
+
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q6_K_Q8_1_MMVQ 1
+#define VDR_Q6_K_Q8_1_MMQ  8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
+    const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
+    const float & d, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR6_K; ++i) {
+        const int sc = scales[4*i];
+
+        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+        const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
+
+        const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
+
+        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
+    }
+
+    return d*sumf;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
+    const float & d6, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+
+#pragma unroll
+    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
+        int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
+
+#pragma unroll
+        for (int i = i0; i < i0 + 2; ++i) {
+            sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
+            sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product
+
+            sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
+            sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
+        }
+
+        sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
+    }
+
+    return d6 * sumf_d;
+
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
+
+    int v[VDR_Q4_0_Q8_1_MMVQ];
+    int u[2*VDR_Q4_0_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
+        v[i]     = get_int_from_uint8(bq4_0->qs, iqs + i);
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
+    }
+
+    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
+}
+
+static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+    __shared__ int  tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
+    __shared__ float tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_0) + GGML_CUDA_MMQ_Y/QI4_0];
+
+    *x_ql = tile_x_qs;
+    *x_dm = (half2 *) tile_x_d;
+}
+
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+
+    __builtin_assume(i_offset >= 0);
+    __builtin_assume(i_offset <  8);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+
+    const int kbx  = k / QI4_0;
+    const int kqsx = k % QI4_0;
+
+    const block_q4_0 * bx0 = (block_q4_0 *) vx;
+
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
+        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
+    }
+
+//     const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+//     const int kbxd = k % blocks_per_tile_x_row;
+
+// #pragma unroll
+//     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_0) {
+//         FIXME out-of-bounds
+//         const int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
+
+//         if (i >= GGML_CUDA_MMQ_Y) {
+//             return;
+//         }
+
+//         const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+//         x_dm[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd].x = bxi->d;
+//     }
+}
+
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+    __builtin_assume(i >= 0);
+    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
+    __builtin_assume(j >= 0);
+    __builtin_assume(j <  WARP_SIZE);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q4_0_Q8_1_MMQ == 0);
+
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+    const float * x_dmf = (float *) x_dm;
+
+    int u[2*VDR_Q4_0_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
+        u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI4_0];
+    }
+
+    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
+         y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+
+    int v[VDR_Q4_1_Q8_1_MMVQ];
+    int u[2*VDR_Q4_1_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
+        v[i]    = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
+    }
+
+    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
+}
+
+static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+    __shared__ int   tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE) +     + GGML_CUDA_MMQ_Y];
+    __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_1) + GGML_CUDA_MMQ_Y/QI4_1];
+
+    *x_ql = tile_x_qs;
+    *x_dm = tile_x_dm;
+}
+
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+
+    __builtin_assume(i_offset >= 0);
+    __builtin_assume(i_offset <  8);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+
+    const int kbx  = k / QI4_1;
+    const int kqsx = k % QI4_1;
+
+    const block_q4_1 * bx0 = (block_q4_1 *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
+    const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_1) {
+        int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+    __builtin_assume(i >= 0);
+    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
+    __builtin_assume(j >= 0);
+    __builtin_assume(j <  WARP_SIZE);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q4_1_Q8_1_MMQ == 0);
+
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+
+    int u[2*VDR_Q4_1_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
+        u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI4_1];
+    }
+
+    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
+         y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
+
+    int vl[VDR_Q5_0_Q8_1_MMVQ];
+    int vh[VDR_Q5_0_Q8_1_MMVQ];
+    int  u[2*VDR_Q5_0_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
+        vl[i]    = get_int_from_uint8(bq5_0->qs, iqs + i);
+        vh[i]    = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i));
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0);
+    }
+
+    return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
+}
+
+static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+    __shared__ int  tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE)     + GGML_CUDA_MMQ_Y];
+    __shared__ float tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0) + GGML_CUDA_MMQ_Y/QI5_0];
+
+    *x_ql = tile_x_ql;
+    *x_dm = (half2 *) tile_x_d;
+}
+
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+
+    __builtin_assume(i_offset >= 0);
+    __builtin_assume(i_offset <  8);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+
+    const int kbx  = k / QI5_0;
+    const int kqsx = k % QI5_0;
+
+    const block_q5_0 * bx0 = (block_q5_0 *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;
+
+        const int ql = get_int_from_uint8(bxi->qs, kqsx);
+        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));
+
+        int qs0 = (ql >>  0)   & 0x0F0F0F0F;
+        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
+        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12
+        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20
+        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
+        qs0     = __vsubss4(qs0, 0x10101010); // subtract 16
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
+
+        int qs1 = (ql >>  4)   & 0x0F0F0F0F;
+        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
+        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
+        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20
+        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
+        qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
+    const int kbxd = k % blocks_per_tile_x_row;
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_0) {
+        int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+    __builtin_assume(i >= 0);
+    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
+    __builtin_assume(j >= 0);
+    __builtin_assume(j <  WARP_SIZE);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q5_0_Q8_1_MMQ == 0);
+
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+    const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
+
+    int u[2*VDR_Q5_0_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
+        u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI5_0];
+    }
+
+    return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
+        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
+
+    int vl[VDR_Q5_1_Q8_1_MMVQ];
+    int vh[VDR_Q5_1_Q8_1_MMVQ];
+    int  u[2*VDR_Q5_1_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
+        vl[i]   = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);
+        vh[i]   = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
+    }
+
+    return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
+}
+
+static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE)     + GGML_CUDA_MMQ_Y];
+    __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1) + GGML_CUDA_MMQ_Y/QI5_1];
+
+    *x_ql = tile_x_ql;
+    *x_dm = tile_x_dm;
+}
+
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+
+    __builtin_assume(i_offset >= 0);
+    __builtin_assume(i_offset <  8);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+
+    const int kbx  = k / QI5_1;
+    const int kqsx = k % QI5_1;
+
+    const block_q5_1 * bx0 = (block_q5_1 *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
+
+        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
+
+        int qs0 = (ql >>  0) & 0x0F0F0F0F;
+        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
+        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
+        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
+        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
+
+        int qs1 = (ql >>  4) & 0x0F0F0F0F;
+        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
+        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
+        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
+        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
+    const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_1) {
+        int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+    __builtin_assume(i >= 0);
+    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
+    __builtin_assume(j >= 0);
+    __builtin_assume(j <  WARP_SIZE);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q5_1_Q8_1_MMQ == 0);
+
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+    const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
+
+    int u[2*VDR_Q5_1_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
+        u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI5_1];
+    }
+
+    return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
+
+    int v[VDR_Q8_0_Q8_1_MMVQ];
+    int u[VDR_Q8_0_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
+        v[i] = get_int_from_int8(bq8_0->qs, iqs + i);
+        u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+    }
+
+    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds.x);
+}
+
+static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+    __shared__ int  tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
+    __shared__ float tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI8_0) + GGML_CUDA_MMQ_Y/QI8_0];
+
+    *x_ql = tile_x_qs;
+    *x_dm = (half2 *) tile_x_d;
+}
+
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+
+    __builtin_assume(i_offset >= 0);
+    __builtin_assume(i_offset <  8);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+
+    const int kbx  = k / QI8_0;
+    const int kqsx = k % QI8_0;
+    float * x_dmf = (float *) x_dm;
+
+    const block_q8_0 * bx0 = (block_q8_0 *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
+        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbx] = bxi->d;
+    }
+
+//     const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
+//     const int kbxd = k % blocks_per_tile_x_row;
+
+// #pragma unroll
+//     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI8_0) {
+//         FIXME out-of-bounds
+//         const int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
+
+// #if GGML_CUDA_MMQ_Y < 64
+//         if (i >= GGML_CUDA_MMQ_Y) {
+//             return;
+//         }
+// #endif // GGML_CUDA_MMQ_Y < 64
+
+//         const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+//         x_dm[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd].x = bxi->d;
+//     }
+}
+
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+    __builtin_assume(i >= 0);
+    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
+    __builtin_assume(j >= 0);
+    __builtin_assume(j <  WARP_SIZE);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q8_0_Q8_1_MMQ == 0);
+
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
+
+    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
+        (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
+         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q2_K * bq2_K = (const block_q2_K *) vbq;
+
+    const int bq8_offset = QR2_K * (iqs / QI8_1);
+    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+    const uint8_t * scales = bq2_K->scales + scale_offset;
+
+    const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
+    int    u[QR2_K];
+    float d8[QR2_K];
+
+#pragma unroll
+    for (int i = 0; i < QR2_K; ++ i) {
+        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+        d8[i] = bq8_1[bq8_offset + i].ds.x;
+    }
+
+    return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
+}
+
+static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
+    __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI2_K) + GGML_CUDA_MMQ_Y/QI2_K];
+    __shared__ int   tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/4)     + GGML_CUDA_MMQ_Y/4];
+
+    *x_ql = tile_x_ql;
+    *x_dm = tile_x_dm;
+    *x_sc = tile_x_sc;
+}
+
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+
+    __builtin_assume(i_offset >= 0);
+    __builtin_assume(i_offset <  8);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+
+    const int kbx  = k / QI2_K;
+    const int kqsx = k % QI2_K;
+
+    const block_q2_K * bx0 = (block_q2_K *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
+    const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI2_K) {
+        int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 4) {
+        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);
+
+        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+    __builtin_assume(i >= 0);
+    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
+    __builtin_assume(j >= 0);
+    __builtin_assume(j <  WARP_SIZE);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q2_K_Q8_1_MMQ == 0);
+
+    const int kbx = k / QI2_K;
+    const int ky  = (k % QI2_K) * QR2_K;
+    const float * y_df = (const float *) y_ds;
+
+    int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
+
+    const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
+    const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
+
+#pragma unroll
+    for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
+        v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
+    }
+
+    const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
+
+    const int index_y = j * (QR2_K*WARP_SIZE) + QR2_K*k;
+    return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+
+    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
+    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+    const float d = bq3_K->d;
+
+    const int vl = get_int_from_uint8(bq3_K->qs, iqs);
+
+    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+    const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;
+
+    int    u[QR3_K];
+    float d8[QR3_K];
+
+#pragma unroll
+    for (int i = 0; i < QR3_K; ++i) {
+        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+        d8[i] = bq8_1[bq8_offset + i].ds.x;
+    }
+
+    return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
+}
+
+static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
+    __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI3_K) + GGML_CUDA_MMQ_Y/QI3_K];
+    __shared__ int   tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/2)     + GGML_CUDA_MMQ_Y/2];
+    __shared__ int   tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/4)     + GGML_CUDA_MMQ_Y/4];
+
+    *x_ql = tile_x_ql;
+    *x_dm = tile_x_dm;
+    *x_qh = tile_x_qh;
+    *x_sc = tile_x_sc;
+}
+
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+
+    __builtin_assume(i_offset >= 0);
+    __builtin_assume(i_offset <  8);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+
+    const int kbx  = k / QI3_K;
+    const int kqsx = k % QI3_K;
+
+    const block_q3_K * bx0 = (block_q3_K *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
+    const int kbxd = k % blocks_per_tile_x_row;
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI3_K) {
+        int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 2) {
+        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
+
+        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 4) {
+        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
+
+        const int ksc = k % (QI3_K/4);
+
+        const int ksc_low = ksc % (QI3_K/8);
+        const int shift_low = 4 * (ksc / (QI3_K/8));
+        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+
+        const int ksc_high = QI3_K/8;
+        const int shift_high = 2 * ksc;
+        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+
+        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
+
+        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+    __builtin_assume(i >= 0);
+    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
+    __builtin_assume(j >= 0);
+    __builtin_assume(j <  WARP_SIZE);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q3_K_Q8_1_MMQ == 0);
+
+    const int kbx  = k / QI3_K;
+    const int ky  = (k % QI3_K) * QR3_K;
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
+
+    const int8_t * scales = ((int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+
+    int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
+        const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
+        const int shift = 2 * ((ky % 32) / 8);
+        const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
+
+        const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
+        const int vlh = (vh << 2) & 0x04040404;
+
+        v[l] = __vsubss4(vll, vlh);
+    }
+
+    const int index_y = j * (QR3_K*WARP_SIZE) + k*QR3_K;
+    return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+#ifndef GGML_QKK_64
+    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+    int    v[2];
+    int    u[2*QR4_K];
+    float d8[QR4_K];
+
+    // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
+    const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
+
+    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
+    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
+    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
+    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
+
+    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+    v[0] = q4[0];
+    v[1] = q4[4];
+
+    const uint16_t * scales = (const uint16_t *)bq4_K->scales;
+    uint16_t aux[2];
+    const int j = bq8_offset/2;
+    if (j < 2) {
+        aux[0] = scales[j+0] & 0x3f3f;
+        aux[1] = scales[j+2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+    }
+    const uint8_t * sc = (const uint8_t *)aux;
+    const uint8_t * m  = sc + 2;
+
+    for (int i = 0; i < QR4_K; ++i) {
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        d8[i] = bq8i->ds.x;
+
+        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+        u[2*i+0] = q8[0];
+        u[2*i+1] = q8[4];
+    }
+
+    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
+
+#else
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
+    const uint16_t * a = (const uint16_t *)bq4_K->scales;
+    aux16[0] = a[0] & 0x0f0f;
+    aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+    const float dall = bq4_K->d[0];
+    const float dmin = bq4_K->d[1];
+
+    const float d8_1 = bq8_1[0].ds.x;
+    const float d8_2 = bq8_1[1].ds.x;
+
+    const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
+    const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
+    const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
+
+    const int * q4 = (const int *)bq4_K->qs + (iqs/2);
+    const int v1 = q4[0];
+    const int v2 = q4[4];
+
+    const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0));
+    const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
+    const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
+    const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0));
+
+    sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
+    sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
+
+    return dall * sumf_d - dmin * sumf_m;
+
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+
+#endif
+}
+
+static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
+    __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_K) + GGML_CUDA_MMQ_Y/QI4_K];
+    __shared__ int   tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8)     + GGML_CUDA_MMQ_Y/8];
+
+    *x_ql = tile_x_ql;
+    *x_dm = tile_x_dm;
+    *x_sc = tile_x_sc;
+}
+
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+
+    __builtin_assume(i_offset >= 0);
+    __builtin_assume(i_offset <  8);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+
+    const int kbx  = k / QI4_K; // == 0 if QK_K == 256
+    const int kqsx = k % QI4_K; // == k if QK_K == 256
+
+    const block_q4_K * bx0 = (block_q4_K *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_K) {
+        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
+
+        const int * scales = (int *) bxi->scales;
+
+        const int ksc = k % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
+        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+    __builtin_assume(i >= 0);
+    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
+    __builtin_assume(j >= 0);
+    __builtin_assume(j <  WARP_SIZE);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q4_K_Q8_1_MMQ == 0);
+
+    int v[QR4_K*VDR_Q4_K_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < VDR_Q4_K_Q8_1_MMQ; ++l) {
+        v[l + 0]         = (x_ql[i * (WARP_SIZE + 1) + k + l] >> 0) & 0x0F0F0F0F;
+        v[l + (QI4_K/4)] = (x_ql[i * (WARP_SIZE + 1) + k + l] >> 4) & 0x0F0F0F0F;
+    }
+
+    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
+
+    const int index_y = j * (QR4_K*WARP_SIZE) + QR4_K*k;
+    return vec_dot_q4_K_q8_1_impl_mmq(v, &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+#ifndef GGML_QKK_64
+    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+
+    int   vl[2];
+    int   vh[2];
+    int    u[2*QR5_K];
+    float d8[QR5_K];
+
+    const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
+    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+    const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));
+
+    vl[0] = ql[0];
+    vl[1] = ql[4];
+
+    vh[0] = qh[0] >> bq8_offset;
+    vh[1] = qh[4] >> bq8_offset;
+
+    const uint16_t * scales = (const uint16_t *)bq5_K->scales;
+    uint16_t aux[2];
+    const int j = bq8_offset/2;
+    if (j < 2) {
+        aux[0] = scales[j+0] & 0x3f3f;
+        aux[1] = scales[j+2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+    }
+    const uint8_t * sc = (const uint8_t *)aux;
+    const uint8_t * m  = sc + 2;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K; ++i) {
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        d8[i] = bq8i->ds.x;
+
+        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+        u[2*i+0] = q8[0];
+        u[2*i+1] = q8[4];
+    }
+
+    return vec_dot_q5_K_q8_1_impl(vl, vh, u, sc, m, bq5_K->dm, d8);
+
+#else
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+
+    const int8_t * s = bq5_K->scales;
+
+    const float d = bq5_K->d;
+
+    const float d8_1 = bq8_1[0].ds.x;
+    const float d8_2 = bq8_1[1].ds.x;
+
+    const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
+    const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
+    const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
+
+    const int * ql = (const int *)bq5_K->qs + (iqs/2);
+    const int vl1 = ql[0];
+    const int vl2 = ql[4];
+
+    const int step = 4 * (iqs/2); // 0, 4, 8, 12
+    const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6
+    const int in = step%8; // 0, 4, 0, 4
+    const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
+
+    const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
+    const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
+    const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
+    const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
+
+    const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1])
+                       + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);
+
+    return d * sumf_d;
+
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+
+#endif
+}
+
+static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE)     + GGML_CUDA_MMQ_Y];
+    __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_K) + GGML_CUDA_MMQ_Y/QI5_K];
+    __shared__ int   tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8)     + GGML_CUDA_MMQ_Y/8];
+
+    *x_ql = tile_x_ql;
+    *x_dm = tile_x_dm;
+    *x_sc = tile_x_sc;
+}
+
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+
+    __builtin_assume(i_offset >= 0);
+    __builtin_assume(i_offset <  8);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+
+    const int kbx  = k / QI5_K; // == 0 if QK_K == 256
+    const int kqsx = k % QI5_K; // == k if QK_K == 256
+
+    const block_q5_K * bx0 = (block_q5_K *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
+        const int ky = QR5_K*kqsx;
+
+        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
+        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+
+        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
+        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);
+
+        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
+        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_K) {
+        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
+
+        const int * scales = (int *) bxi->scales;
+
+        const int ksc = k % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
+        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+    __builtin_assume(i >= 0);
+    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
+    __builtin_assume(j >= 0);
+    __builtin_assume(j <  WARP_SIZE);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q5_K_Q8_1_MMQ == 0);
+
+    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
+
+    const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
+    const int index_y = j * (QR5_K*WARP_SIZE)     + QR5_K*k;
+    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
 
-        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        const float d8i = bq8i->d;
-        const int * q8 = (const int *)bq8i->qs + (iqs%4);
-        const int ui1 = q8[0];
-        const int ui2 = q8[4];
+    const block_q6_K * bq6_K = (const block_q6_K *) vbq;
 
-        const int vi1 = (v1 >> (4*i)) & 0x0F0F0F0F;
-        const int vi2 = (v2 >> (4*i)) & 0x0F0F0F0F;
+    const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
+    const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
+    const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
 
-        const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
-        const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
+    const int vl = get_int_from_uint8(bq6_K->ql, iqs);
+    const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;
 
-        sumf_d += d8i * (dot1 * sc[i]);
-        sumf_m += d8i * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
-    }
+    const int8_t * scales = bq6_K->scales + scale_offset;
 
-    return d*sumf_d - dmin*sumf_m;
+    int    u[QR6_K];
+    float d8[QR6_K];
 
-#else
+#pragma unroll
+    for (int i = 0; i < QR6_K; ++i) {
+        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
+        d8[i] = bq8_1[bq8_offset + 2*i].ds.x;
+    }
 
-    uint16_t aux16[2];
-    const uint8_t * s = (const uint8_t *)aux16;
+    return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
+}
 
-    const uint16_t * a = (const uint16_t *)bq4_K->scales;
-    aux16[0] = a[0] & 0x0f0f;
-    aux16[1] = (a[0] >> 4) & 0x0f0f;
+static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    const float dall = bq4_K->d[0];
-    const float dmin = bq4_K->d[1];
+    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE)     + GGML_CUDA_MMQ_Y];
+    __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI6_K) + GGML_CUDA_MMQ_Y/QI6_K];
+    __shared__ int   tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8)     + GGML_CUDA_MMQ_Y/8];
 
-    const float d8_1 = bq8_1[0].d;
-    const float d8_2 = bq8_1[1].d;
+    *x_ql = tile_x_ql;
+    *x_dm = tile_x_dm;
+    *x_sc = tile_x_sc;
+}
 
-    const int ui1 = *((const int *)bq8_1[0].qs + iqs);
-    const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
-    const int ui3 = *((const int *)bq8_1[1].qs + iqs);
-    const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+template <bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    const int * q4 = (const int *)bq4_K->qs + iqs;
-    const int v1 = q4[0];
-    const int v2 = q4[4];
+    __builtin_assume(i_offset >= 0);
+    __builtin_assume(i_offset <  8);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
 
-    const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0));
-    const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
-    const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
-    const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0));
+    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
+    const int kqsx = k % QI6_K; // == k if QK_K == 256
 
-    sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
-    sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
+    const block_q6_K * bx0 = (block_q6_K *) vx;
 
-    return dall * sumf_d - dmin * sumf_m;
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+        int i = i0 + i_offset;
 
-#endif
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
-}
+        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
+        const int ky = QR6_K*kqsx;
 
-static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+        const int ql = get_int_from_uint8(bxi->ql, kqsx);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
 
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
+        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
+        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;
 
-#ifndef GGML_QKK_64
+        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
+        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);
 
-    const int bq8_offset = QR5_K * (iqs / (QI8_1/2));
-    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
-    const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));
+        x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+    }
 
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
+    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
+    float * x_dmf = (float *) x_dm;
 
-    const float    d = bq5_K->d;
-    const float dmin = bq5_K->dmin;
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI6_K) {
+        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
 
-    const int vl1 = ql[0];
-    const int vl2 = ql[4];
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
-    const int vh1 = qh[0] >> bq8_offset;
-    const int vh2 = qh[4] >> bq8_offset;
+        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
 
-    const uint16_t * scales = (const uint16_t *)bq5_K->scales;
-    uint16_t aux[2];
-    const int j = bq8_offset/2;
-    if (j < 2) {
-        aux[0] = scales[j+0] & 0x3f3f;
-        aux[1] = scales[j+2] & 0x3f3f;
-    } else {
-        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
-        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
     }
-    const uint8_t * sc = (const uint8_t *)aux;
-    const uint8_t * m  = sc + 2;
 
-    for (int i = 0; i < QR5_K; ++i) {
+#pragma unroll
+    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
 
-        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        const float d8i = bq8i->d;
-        const int * q8 = (const int *)bq8i->qs + (iqs%4);
-        const int ui1 = q8[0];
-        const int ui2 = q8[4];
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
-        const int vil1 = (vl1 >> (4*i)) & 0x0F0F0F0F;
-        const int vil2 = (vl2 >> (4*i)) & 0x0F0F0F0F;
+        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;
 
-        const int vih1 = ((vh1 >> i) << 4) & 0x10101010;
-        const int vih2 = ((vh2 >> i) << 4) & 0x10101010;
+        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
+    }
+}
 
-        const int vi1 = vil1 | vih1;
-        const int vi2 = vil2 | vih2;
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
 
-        const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
-        const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
+    __builtin_assume(i >= 0);
+    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
+    __builtin_assume(j >= 0);
+    __builtin_assume(j <  WARP_SIZE);
+    __builtin_assume(k >= 0);
+    __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q6_K_Q8_1_MMQ == 0);
 
-        sumf_d += d8i * (dot1 * sc[i]);
-        sumf_m += d8i * (dot2 * m[i]);
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
 
-    }
+    const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
 
-    return d*sumf_d - dmin*sumf_m;
+    const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
+    const int index_y = j * (QR6_K*WARP_SIZE)     + QR6_K*k;
+    return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
+}
 
-#else
+template <int qk, int qr, int qi, bool need_sum, typename block_q_t,
+              allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
+static __global__ void mul_mat_q(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-    const int8_t * s = bq5_K->scales;
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    const float d = bq5_K->d;
+    const int blocks_per_row_x = ncols_x / qk;
+    const int blocks_per_col_y = nrows_y / QK8_1;
+    const int blocks_per_warp = WARP_SIZE / qi;
 
-    const float d8_1 = bq8_1[0].d;
-    const float d8_2 = bq8_1[1].d;
+    const int & ncols_dst = ncols_y;
 
-    const int ui1 = *((const int *)bq8_1[0].qs + iqs);
-    const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
-    const int ui3 = *((const int *)bq8_1[1].qs + iqs);
-    const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+    const int tid_x = threadIdx.x;
+    const int tid_y = threadIdx.y;
 
-    const int * ql = (const int *)bq5_K->qs + iqs;
-    const int vl1 = ql[0];
-    const int vl2 = ql[4];
+    const int row_dst_0 = blockIdx.x*GGML_CUDA_MMQ_Y;
+    const int & row_x_0 = row_dst_0;
+    const int row_dst = row_dst_0 + tid_x;
 
-    const int step = 4 * iqs; // 0, 4, 8, 12
-    const int im = step/8; // = 0 for iqs = 0, 1, = 1 for iqs = 2, 3
-    const int in = step%8; // 0, 4, 0, 4
-    const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
+    const int col_dst_0 = blockIdx.y*WARP_SIZE;
+    const int & col_y_0 = col_dst_0;
 
-    const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
-    const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
-    const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
-    const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
+    int   * tile_x_ql = nullptr;
+    half2 * tile_x_dm = nullptr;
+    int   * tile_x_qh = nullptr;
+    int   * tile_x_sc = nullptr;
 
-    const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1])
-                       + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);
+    allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
 
-    return d * sumf_d;
+    const int blocks_per_tile_y_col = qr*WARP_SIZE/QI8_1;
 
-#endif
+    __shared__ int    tile_y_qs[(WARP_SIZE) * (qr*WARP_SIZE)];
+    __shared__ half2  tile_y_ds[(WARP_SIZE) * blocks_per_tile_y_col];
 
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
-}
+    float sum[GGML_CUDA_MMQ_Y/WARP_SIZE][4] = {0.0f};
 
-static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
 
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_q6_K * bq6_K = (const block_q6_K *) vbq;
+        load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
+                   tid_y, nrows_x-row_x_0-1, tid_x, blocks_per_row_x);
 
-    const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
-    const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
-    const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
+        for (int ir = 0; ir < qr; ++ir) {
+            const int kqs = ir*WARP_SIZE + tid_x;
+            const int kbxd = kqs / QI8_1;
 
-    float sumf = 0.0f;
+            for (int i = 0; i < WARP_SIZE; i += 8) {
+                const int col_y_eff = min(col_y_0 + tid_y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
 
-    const float d = bq6_K->d;
+                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
 
-    int vl;
-    memcpy(&vl, &bq6_K->ql[sizeof(int) * iqs], sizeof(int));
+                tile_y_qs[(tid_y + i) * (qr*WARP_SIZE) + kqs] = get_int_from_int8_aligned(by0->qs, tid_x % QI8_1);
+            }
+        }
 
-    int vh;
-    memcpy(&vh, &bq6_K->qh[sizeof(int) * ((QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4))], sizeof(int));
+        for (int ids0 = 0; ids0 < WARP_SIZE; ids0 += 8 * (WARP_SIZE/blocks_per_tile_y_col)) {
+            const int ids = (ids0 + tid_y * (WARP_SIZE/blocks_per_tile_y_col) + tid_x / blocks_per_tile_y_col) % WARP_SIZE;
+            const int kby = tid_x % blocks_per_tile_y_col;
+            const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
 
-    for (int i = 0; i < QR6_K; ++i) {
-        const int sc = bq6_K->scales[scale_offset + 4*i];
+            // if the sum is not needed it's faster to transform the scale to f32 ahead of time
+            const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby].ds;
+            half2       * dsi_dst = &tile_y_ds[ids * (qr*WARP_SIZE/QI8_1) + kby];
+            if (need_sum) {
+                *dsi_dst = *dsi_src;
+            } else {
+                float * dfi_dst = (float *) dsi_dst;
+                *dfi_dst = (*dsi_src).x;
+            }
+        }
 
-        const block_q8_1 * bq8i = bq8_1 + bq8_offset + 2*i;
-        const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % (QI8_1))]);
-        const float d8i = bq8i->d;
+        __syncthreads();
 
-        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+#if __CUDA_ARCH__ >= 700 // Unrolling the loop is slower on Pascal
+#pragma unroll
+#endif // __CUDA_ARCH__ >= 700
+        for (int k = 0; k < WARP_SIZE; k += vdr) {
+#pragma unroll
+            for (int j = 0; j < WARP_SIZE; j += 8) {
+#pragma unroll
+                for (int i = 0; i < GGML_CUDA_MMQ_Y; i += WARP_SIZE) {
+                    sum[i/WARP_SIZE][j/8] += vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
+                                                     tid_x + i, tid_y + j, k);
+                }
+            }
+        }
 
-        const int vih = ((vh >> (vh_shift + 4*i)) << 4) & 0x30303030;
+        __syncthreads();
+    }
 
-        const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
 
-        sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
+    if (row_dst >= nrows_dst) {
+        return;
     }
 
-    return d*sumf;
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+    for (int j = 0; j < WARP_SIZE; j += 8) {
+        const int col_dst = col_dst_0 + j + tid_y;
+
+        if (col_dst >= ncols_dst) {
+            return;
+        }
+
+        for (int i = 0; i < GGML_CUDA_MMQ_Y; i += WARP_SIZE) {
+            dst[col_dst*nrows_dst + row_dst + i] = sum[i/WARP_SIZE][j/8];
+        }
+    }
 }
 
-template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
+template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
 static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
 
@@ -1813,7 +3326,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
 
 // partial sum for each thread
     float tmp = 0.0f;
@@ -1822,11 +3335,11 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
     for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i + threadIdx.x / qi; // x block index
+        const int ibx = row*blocks_per_row + i + threadIdx.x / (qi/vdr); // x block index
 
-        const int iby = (i + threadIdx.x / qi) * qk/QK8_1; // y block index that aligns with ibx
+        const int iby = (i + threadIdx.x / (qi/vdr)) * (qk/QK8_1); // y block index that aligns with ibx
 
-        const int iqs  = threadIdx.x % qi; // x block quant index when casting the quants to int
+        const int iqs  = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
 
         tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
     }
@@ -1859,11 +3372,11 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
     const int y_offset = qr == 1 ? 1 : qk/2;
 
 // partial sum for each thread
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
     half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
 #else
     float tmp = 0.0f;
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
 
     for (int i = 0; i < ncols; i += iter_stride) {
         const int col = i + vals_per_iter*tid;
@@ -1883,7 +3396,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
 
             // matrix multiplication
             // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
             tmp += __hmul2(v, {
                 y[iybs + iqs + j/qr + 0],
                 y[iybs + iqs + j/qr + y_offset]
@@ -1891,7 +3404,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
 #else
             tmp += v.x * y[iybs + iqs + j/qr + 0];
             tmp += v.y * y[iybs + iqs + j/qr + y_offset];
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
         }
     }
 
@@ -1902,11 +3415,11 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
     }
 
     if (tid == 0) {
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
         dst[row] = tmp.x + tmp.y;
 #else
         dst[row] = tmp;
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
     }
 }
 
@@ -2046,7 +3559,8 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 }
 
 // rope == RoPE == rotary positional embedding
-static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
+static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
+                                const float p_delta, const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
 
     if (col >= ncols) {
@@ -2056,7 +3570,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
     const int i = row*ncols + col;
 
-    const float theta = p*powf(theta_scale, col/2);
+    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
 
@@ -2203,9 +3717,11 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
     rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
 }
 
-static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-    quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, ndata, k);
+static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
+    const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, ky, 1);
+    const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
 }
 
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -2366,7 +3882,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, vec_dot_q4_0_q8_1>
+    mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -2375,7 +3891,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, vec_dot_q4_1_q8_1>
+    mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -2384,7 +3900,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, vec_dot_q5_0_q8_1>
+    mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -2393,7 +3909,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, vec_dot_q5_1_q8_1>
+    mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -2402,7 +3918,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, vec_dot_q8_0_q8_1>
+    mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -2411,7 +3927,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_K, block_q2_K, vec_dot_q2_K_q8_1>
+    mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -2420,7 +3936,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, vec_dot_q3_K_q8_1>
+    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -2429,10 +3945,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    // Note: we use QI4_K/2 instead of QI4_K to make the dot product template require 4 groups of quants to be processed per
-    //       kernel call instead of 2. This results in a better perfmance because the cost of computing the k-quant scales
-    //       is better amortized.
-    mul_mat_vec_q<QK_K, QI4_K/2, block_q4_K, vec_dot_q4_K_q8_1>
+    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -2441,10 +3954,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    // Note: we use QI5_K/2 instead of QI5_K to make the dot product template require 4 groups of quants to be processed per
-    //       kernel call instead of 2. This results in a better perfmance because the cost of computing the k-quant scales
-    //       is better amortized.
-    mul_mat_vec_q<QK_K, QI5_K/2, block_q5_K, vec_dot_q5_K_q8_1>
+    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -2453,7 +3963,7 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI6_K, block_q6_K, vec_dot_q6_K_q8_1>
+    mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -2500,6 +4010,186 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     }
 }
 
+static void ggml_mul_mat_q4_0_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<false>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<true>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+static void ggml_mul_mat_q4_1_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<false>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<true>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+static void ggml_mul_mat_q5_0_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<false>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<true>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+static void ggml_mul_mat_q5_1_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<false>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<true>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+static void ggml_mul_mat_q8_0_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<false>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<true>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+static void ggml_mul_mat_q2_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<false>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<true>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+static void ggml_mul_mat_q3_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<false>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<true>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+static void ggml_mul_mat_q4_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<false>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<true>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+static void ggml_mul_mat_q5_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<false>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<true>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+static void ggml_mul_mat_q6_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+
+    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
+        mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<false>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<true>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
 static void ggml_mul_mat_p021_f16_f32_cuda(
     const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
     const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
@@ -2544,12 +4234,13 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
 }
 
-static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) {
+static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
+                          const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 2 == 0);
     const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
+    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
 
 static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
@@ -2676,10 +4367,9 @@ static size_t g_scratch_offset = 0;
 
 static int g_device_count = -1;
 static int g_main_device = 0;
-#ifndef GGML_CUDA_FORCE_DMMV
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
-#endif
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
+static bool g_mul_mat_q = false;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
@@ -2701,9 +4391,7 @@ void ggml_init_cublas() {
             g_tensor_split[id] = total_vram;
             total_vram += prop.totalGlobalMem;
 
-#ifndef GGML_CUDA_FORCE_DMMV
             g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
-#endif
         }
         for (int id = 0; id < g_device_count; ++id) {
             g_tensor_split[id] /= total_vram;
@@ -2965,6 +4653,83 @@ inline void ggml_cuda_op_rms_norm(
     (void) i1;
 }
 
+inline void ggml_cuda_op_mul_mat_q(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddq_i != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    GGML_ASSERT(ne10 % QK8_1 == 0);
+
+    const int64_t ne0 = dst->ne[0];
+
+    const int64_t i01_diff = i01_high - i01_low;
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
+    const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
+
+    const int64_t padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ?
+        ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
+    size_t as;
+    void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne11*sizeof(block_q8_1)/QK8_1, &as);
+    quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, cudaStream_main);
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            ggml_mul_mat_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q4_1:
+            ggml_mul_mat_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q5_0:
+            ggml_mul_mat_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q5_1:
+            ggml_mul_mat_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q8_0:
+            ggml_mul_mat_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q2_K:
+            ggml_mul_mat_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q3_K:
+            ggml_mul_mat_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q4_K:
+            ggml_mul_mat_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q5_K:
+            ggml_mul_mat_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q6_K:
+            ggml_mul_mat_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+
+    ggml_cuda_pool_free(src1_q8_1, as);
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_mul_mat_vec(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2979,6 +4744,7 @@ inline void ggml_cuda_op_mul_mat_vec(
 
 #ifdef GGML_CUDA_FORCE_DMMV
     const bool use_mul_mat_vec_q = false;
+    (void) g_compute_capabilities[0];
 #else
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
@@ -3006,7 +4772,7 @@ inline void ggml_cuda_op_mul_mat_vec(
             ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
         size_t as;
         void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
-        quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main);
+        quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, 1, padded_row_size, cudaStream_main);
 
         switch (src0->type) {
             case GGML_TYPE_Q4_0:
@@ -3047,7 +4813,7 @@ inline void ggml_cuda_op_mul_mat_vec(
         ggml_cuda_pool_free(src1_q8_1, as);
     } else {
         // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
         size_t ash;
         dfloat * src1_dfloat = nullptr; // dfloat == half
 
@@ -3063,7 +4829,7 @@ inline void ggml_cuda_op_mul_mat_vec(
         }
 #else
         dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
 
         switch (src0->type) {
             case GGML_TYPE_Q4_0:
@@ -3104,11 +4870,11 @@ inline void ggml_cuda_op_mul_mat_vec(
                 break;
         }
 
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
         if (src1_convert_f16) {
             ggml_cuda_pool_free(src1_dfloat, ash);
         }
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
     }
 
     (void) src1;
@@ -3168,6 +4934,7 @@ inline void ggml_cuda_op_rope(
     GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
     const int64_t i01_diff = i01_high - i01_low;
 
     const int n_past = ((int32_t *) dst->op_params)[0];
@@ -3181,17 +4948,18 @@ inline void ggml_cuda_op_rope(
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
 
-    bool is_glm = mode & 4;
+    const bool is_glm = mode & 4;
 
     // compute
     if (is_glm) {
+        const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
         const float id_p = min(p, n_ctx - 2.f);
         const float block_p = max(p - (n_ctx - 2.f), 0.f);
         rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
     } else {
-        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
+        const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
+        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
     }
 
     (void) src1;
@@ -3363,7 +5131,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         int64_t row_low, row_high;
         if (split) {
             row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
-            row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
+            row_low -= row_low % GGML_CUDA_MMQ_Y;
+
+            if (id == g_device_count - 1) {
+                row_high = nrows0;
+            } else {
+                row_high = nrows0*g_tensor_split[id + 1];
+                row_high -= row_high % GGML_CUDA_MMQ_Y;
+            }
         } else {
             row_low = 0;
             row_high = nrows0*i02_divisor;
@@ -3529,13 +5304,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                     if (split) {
                         // src0 = weight matrix is saved as a transposed matrix for better memory layout.
                         // dst is NOT transposed.
-                        // The outputs of cuBLAS matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
+                        // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
                         // Instead they need to be copied to the correct slice in ne0 = dst row index.
                         // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
-                        for (int64_t j = 0; j < ne1; ++j) {
-                            float * dhf_dst_i = (float *) ((char *) dst_off_device + (j*ne0 + i01_low)*sizeof(float) + i02*nb2 + i03*nb3);
-                            CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i + j*i01_diff, i01_diff*sizeof(float), kind, cudaStream_main));
-                        }
+                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i01_low*sizeof(float) + i02*nb2 + i03*nb3);
+                        CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_ddf_i, i01_diff*sizeof(float),
+                                                     i01_diff*sizeof(float), ne1, kind, cudaStream_main));
                     } else {
                         float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                         CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
@@ -3576,7 +5350,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     if (split && g_device_count > 1) {
         CUDA_CHECK(cudaSetDevice(g_main_device));
         for (int id = 0; id < g_device_count; ++id) {
-            if (id != g_main_device) {
+            if (id != g_main_device && src0_extra->events[id]) {
                 CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
             }
         }
@@ -3718,7 +5492,19 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
         if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
             ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
         } else {
-            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
+            int min_compute_capability = INT_MAX;
+            for (int id = 0; id < g_device_count; ++id) {
+                if (min_compute_capability > g_compute_capabilities[id]
+                        && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
+                    min_compute_capability = g_compute_capabilities[id];
+                }
+            }
+
+            if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
+                ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
+            } else {
+                ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
+            }
         }
     } else {
         GGML_ASSERT(false);
@@ -3795,7 +5581,10 @@ void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml
 
 void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, false); // FIXME flatten changes results
+
+    const int mode = ((int32_t *) dst->op_params)[2];
+    const bool is_glm = mode & 4;
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
 }
 
 void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -3828,7 +5617,14 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
             row_high = nrows;
         } else if (backend == GGML_BACKEND_GPU_SPLIT) {
             row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
-            row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1];
+            row_low -= row_low % GGML_CUDA_MMQ_Y;
+
+            if (id == g_device_count - 1) {
+                row_high = nrows;
+            } else {
+                row_high = nrows*g_tensor_split[id + 1];
+                row_high -= row_high % GGML_CUDA_MMQ_Y;
+            }
         } else {
             GGML_ASSERT(false);
         }
@@ -4002,6 +5798,10 @@ void ggml_cuda_set_main_device(int main_device) {
     }
 }
 
+void ggml_cuda_set_mul_mat_q(bool mul_mat_q) {
+    g_mul_mat_q = mul_mat_q;
+}
+
 void ggml_cuda_set_scratch_size(size_t scratch_size) {
     g_scratch_size = scratch_size;
 }
index 3c1e8deb6a6ddbfbb43e22bd01cc64e560fe313d..72d7afa463d741498af0063f0ccf14e5b9028bf9 100644 (file)
@@ -27,6 +27,7 @@ void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
 void   ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
 void   ggml_cuda_set_main_device(int main_device);
+void   ggml_cuda_set_mul_mat_q(bool mul_mat_q);
 void   ggml_cuda_set_scratch_size(size_t scratch_size);
 void   ggml_cuda_free_scratch(void);
 bool   ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
index 74a6bff40411784f2b13ca4c1a7bf607bfc400c4..b47a98e214b613fb0f022f8606047e6771201ddd 100644 (file)
@@ -7,6 +7,11 @@
 #import <Metal/Metal.h>
 #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
 
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
 #ifdef GGML_METAL_NDEBUG
 #define metal_printf(...)
 #else
@@ -15,6 +20,8 @@
 
 #define UNUSED(x) (void)(x)
 
+#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
+
 struct ggml_metal_buffer {
     const char * name;
 
@@ -36,7 +43,7 @@ struct ggml_metal_context {
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
-    int concur_list[GGML_MAX_NODES];
+    int concur_list[GGML_MAX_CONCUR];
     int concur_list_len;
 
     // custom kernels
@@ -370,15 +377,15 @@ void ggml_metal_graph_find_concurrency(
         struct ggml_metal_context * ctx,
         struct ggml_cgraph * gf) {
     int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
-    int nodes_unused[GGML_MAX_NODES];
+    int nodes_unused[GGML_MAX_CONCUR];
 
-    for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
-    for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+    for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
+    for (int i = 0; i < gf->n_nodes;     i++) { nodes_unused[i]     = 1; }
     ctx->concur_list_len = 0;
 
-    int n_left = gf->n_nodes;
-    int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
-    int level_pos = 0;  // at ctx->concur_list, the last layer (level) ends at level_pos
+    int n_left    = gf->n_nodes;
+    int n_start   = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+    int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
 
     while (n_left > 0) {
         // number of nodes at a layer (that can be issued concurrently)
@@ -386,28 +393,40 @@ void ggml_metal_graph_find_concurrency(
         for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
             if (nodes_unused[i]) {
                 // if the requirements for gf->nodes[i] are satisfied
-                int exe_flag=1;
+                int exe_flag = 1;
+
                 // scan all srcs
                 for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
                     struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
                     if (src_cur) {
                         // if is leaf nodes it's satisfied.
-                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+                        // TODO: ggml_is_leaf()
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
+                            continue;
+                        }
 
                         // otherwise this src should be the output from previous nodes.
                         int is_found = 0;
+
                         // scan 2*search_depth back because we inserted barrier.
-                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
-                            if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+                        //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                        for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
+                            if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
+                                is_found = 1;
+                                break;
+                            }
+                        }
+                        if (is_found == 0) {
+                            exe_flag = 0;
+                            break;
                         }
-                        if (is_found == 0) {exe_flag = 0; break;}
                     }
                 }
                 if (exe_flag) {
                     // check if nodes[i]'s data will be overwritten by a node before nodes[i].
                     // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
                     int64_t data_start = (int64_t) gf->nodes[i]->data;
-                    int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+                    int64_t length     = (int64_t) ggml_nbytes(gf->nodes[i]);
                     for (int j = n_start; j < i; j++) {
                         if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
                                             && gf->nodes[j]->op != GGML_OP_VIEW \
@@ -416,9 +435,9 @@ void ggml_metal_graph_find_concurrency(
                             if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
                                 ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
                                 continue;
-                            } else {
-                                exe_flag = 0;
                             }
+
+                            exe_flag = 0;
                         }
                     }
                 }
@@ -435,11 +454,13 @@ void ggml_metal_graph_find_concurrency(
         ctx->concur_list[level_pos + concurrency] = -1;
         ctx->concur_list_len++;
         // jump all sorted nodes at nodes_bak
-        while (!nodes_unused[n_start]) {n_start++;}
+        while (!nodes_unused[n_start]) {
+            n_start++;
+        }
         level_pos += concurrency + 1;
     }
 
-    if (ctx->concur_list_len > GGML_MAX_NODES) {
+    if (ctx->concur_list_len > GGML_MAX_CONCUR) {
         fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
     }
 }
@@ -453,7 +474,7 @@ void ggml_metal_graph_compute(
     // else fallback to serial dispatch
     MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
 
-    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
 
     const int n_nodes  = has_concur ? ctx->concur_list_len      : gf->n_nodes;
     edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
@@ -718,7 +739,8 @@ void ggml_metal_graph_compute(
                             // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
 
                             GGML_ASSERT(ne00 == ne10);
-                            GGML_ASSERT(ne02 == ne12);
+                            // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+                            GGML_ASSERT(ne03 == ne13);
 
                             if (ggml_is_contiguous(src0) &&
                                 ggml_is_contiguous(src1) &&
@@ -746,11 +768,11 @@ void ggml_metal_graph_compute(
                                     initWithDevice:ctx->device transposeLeft:false transposeRight:true
                                         resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
 
-                                // we need to do ne02 multiplications
+                                // we need to do ne12 multiplications
                                 // TODO: is there a way to do this in parallel - currently very slow ..
                                 // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
-                                for (int64_t i02 = 0; i02 < ne02; ++i02) {
-                                    size_t offs_src0_cur = offs_src0 + i02*nb02;
+                                for (int64_t i02 = 0; i02 < ne12; ++i02) {
+                                    size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
                                     size_t offs_src1_cur = offs_src1 + i02*nb12;
                                     size_t offs_dst_cur  = offs_dst  + i02*nb2;
 
@@ -772,8 +794,6 @@ void ggml_metal_graph_compute(
                                 switch (src0t) {
                                     case GGML_TYPE_F16:
                                         {
-                                            GGML_ASSERT(ne02 == ne12);
-
                                             nth0 = 64;
                                             nth1 = 1;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
@@ -853,16 +873,18 @@ void ggml_metal_graph_compute(
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                                 [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
                                 [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
-                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
-                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
-                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
-                                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
-                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
-                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
-                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
-                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
-                                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
+                                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
+                                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
+                                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
+                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
+                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
+                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
+                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
+                                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                                     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
index 696b33ce75cf4fa8d92850ef5e762c32190d6410..8d26b5ec2dfa4f649b7a6d478f73238a0c700da6 100644 (file)
@@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32(
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
+        constant   int64_t & ne02,
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
+        constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
@@ -529,7 +531,7 @@ kernel void kernel_mul_mat_f16_f32(
     const int64_t r1 = tgpig.y;
     const int64_t im = tgpig.z;
 
-    device const half  * x = (device const half  *) (src0 + r0*nb01 + im*nb02);
+    device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
     sum[tpitg.x] = 0.0f;
@@ -552,6 +554,7 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+
 kernel void kernel_alibi_f32(
         device const float * src0,
         device       float * dst,
index 36105634ef7a72613dd7c3a3db9275c8cdcabe5f..9c4b49db86be9856945261f9c4965fb47fce7ac0 100644 (file)
@@ -195,8 +195,8 @@ typedef void * thread_ret_t;
 #define GGML_ALIGNED_MALLOC(size)  _aligned_malloc(size, GGML_MEM_ALIGN)
 #define GGML_ALIGNED_FREE(ptr)     _aligned_free(ptr)
 #else
-inline static void* ggml_aligned_malloc(size_t size) {
-    void* aligned_memory = NULL;
+inline static void * ggml_aligned_malloc(size_t size) {
+    void * aligned_memory = NULL;
 #ifdef GGML_USE_METAL
     int result = posix_memalign(&aligned_memory, getpagesize(), size);
 #else
@@ -4071,8 +4071,8 @@ bool ggml_is_numa(void) {
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_print_object(const struct ggml_object * obj) {
-    GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n",
-            obj->offs, obj->size, (const void *) obj->next);
+    GGML_PRINT(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n",
+            obj->type, obj->offs, obj->size, (const void *) obj->next);
 }
 
 void ggml_print_objects(const struct ggml_context * ctx) {
@@ -4212,7 +4212,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
 }
 
 size_t ggml_tensor_overhead(void) {
-    return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16;
+    return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE;
 }
 
 bool ggml_is_transposed(const struct ggml_tensor * tensor) {
@@ -4383,7 +4383,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         return NULL;
     }
 
-    const size_t mem_size = (params.mem_size + GGML_MEM_ALIGN - 1) & ~(GGML_MEM_ALIGN - 1);
+    const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN);
 
     *ctx = (struct ggml_context) {
         /*.mem_size           =*/ mem_size,
@@ -4472,12 +4472,14 @@ size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
     struct ggml_object * obj = ctx->objects_begin;
 
     while (obj != NULL) {
-        struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs);
+        if (obj->type == GGML_OBJECT_TENSOR) {
+            struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs);
 
-        const size_t size = ggml_nbytes(tensor);
+            const size_t size = ggml_nbytes(tensor);
 
-        if (max_size < size) {
-            max_size = size;
+            if (max_size < size) {
+                max_size = size;
+            }
         }
 
         obj = obj->next;
@@ -4509,12 +4511,7 @@ static void ggml_scratch_load(struct ggml_context * ctx) {
 
 ////////////////////////////////////////////////////////////////////////////////
 
-static struct ggml_tensor * ggml_new_tensor_impl(
-        struct ggml_context * ctx,
-        enum   ggml_type type,
-        int    n_dims,
-        const int64_t* ne,
-        void*  data) {
+static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) {
     // always insert objects at the end of the context's memory pool
     struct ggml_object * obj_cur = ctx->objects_end;
 
@@ -4522,77 +4519,81 @@ static struct ggml_tensor * ggml_new_tensor_impl(
     const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
     const size_t cur_end  = cur_offs + cur_size;
 
-    size_t size_needed = 0;
-
-    if (data == NULL && !ctx->no_alloc) {
-        size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
-        for (int i = 1; i < n_dims; i++) {
-            size_needed *= ne[i];
-        }
-        // align to GGML_MEM_ALIGN
-        size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN;
-    }
+    // align to GGML_MEM_ALIGN
+    size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN);
 
     char * const mem_buffer = ctx->mem_buffer;
     struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
 
-    if (ctx->scratch.data == NULL || data != NULL) {
-        size_needed += GGML_TENSOR_SIZE;
+    if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
+        GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
+                __func__, cur_end + size_needed, ctx->mem_size);
+        assert(false);
+        return NULL;
+    }
 
-        if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
-            GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
-                    __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
-            assert(false);
-            return NULL;
-        }
+    *obj_new = (struct ggml_object) {
+        .offs = cur_end + GGML_OBJECT_SIZE,
+        .size = size_needed,
+        .next = NULL,
+        .type = type,
+    };
 
-        *obj_new = (struct ggml_object) {
-            .offs = cur_end + GGML_OBJECT_SIZE,
-            .size = size_needed,
-            .next = NULL,
-        };
+    ggml_assert_aligned(mem_buffer + obj_new->offs);
+
+    if (obj_cur != NULL) {
+        obj_cur->next = obj_new;
     } else {
-        if (ctx->scratch.offs + size_needed > ctx->scratch.size) {
-            GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
-                    __func__, ctx->scratch.offs + size_needed, ctx->scratch.size);
-            assert(false);
-            return NULL;
+        // this is the first object in this context
+        ctx->objects_begin = obj_new;
+    }
+
+    ctx->objects_end = obj_new;
+
+    //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);
+
+    return obj_new;
+}
+
+static struct ggml_tensor * ggml_new_tensor_impl(
+        struct ggml_context * ctx,
+        enum   ggml_type      type,
+        int                   n_dims,
+        const int64_t       * ne,
+        void                * data) {
+
+    assert(n_dims >= 1 && n_dims <= GGML_MAX_DIMS);
+
+    size_t data_size = 0;
+
+    if (data == NULL && !ctx->no_alloc) {
+        data_size += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
+        for (int i = 1; i < n_dims; i++) {
+            data_size *= ne[i];
         }
+    }
 
-        if (cur_end + GGML_TENSOR_SIZE + GGML_OBJECT_SIZE > ctx->mem_size) {
-            GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
-                    __func__, cur_end + GGML_TENSOR_SIZE + GGML_OBJECT_SIZE, ctx->mem_size);
+    if (ctx->scratch.data != NULL && data == NULL) {
+        // allocate tensor data in the scratch buffer
+        if (ctx->scratch.offs + data_size > ctx->scratch.size) {
+            GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
+                    __func__, ctx->scratch.offs + data_size, ctx->scratch.size);
             assert(false);
             return NULL;
         }
 
         data = (char * const) ctx->scratch.data + ctx->scratch.offs;
 
-        *obj_new = (struct ggml_object) {
-            .offs = cur_end + GGML_OBJECT_SIZE,
-            .size = GGML_TENSOR_SIZE,
-            .next = NULL,
-        };
-
-        //printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed);
+        ctx->scratch.offs += data_size;
 
-        ctx->scratch.offs += size_needed;
+        data_size = 0;
     }
 
-    if (obj_cur != NULL) {
-        obj_cur->next = obj_new;
-    } else {
-        // this is the first object in this context
-        ctx->objects_begin = obj_new;
-    }
-
-    ctx->objects_end = obj_new;
-
-    //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);
+    struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TENSOR, GGML_TENSOR_SIZE + data_size);
 
-    struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offs);
+    // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here
 
-    ggml_assert_aligned(result);
+    struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
 
     *result = (struct ggml_tensor) {
         /*.type         =*/ type,
@@ -4601,7 +4602,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.ne           =*/ { 1, 1, 1, 1 },
         /*.nb           =*/ { 0, 0, 0, 0 },
         /*.op           =*/ GGML_OP_NONE,
-        /*.op_params    =*/ {0},
+        /*.op_params    =*/ { 0 },
         /*.is_param     =*/ false,
         /*.grad         =*/ NULL,
         /*.src          =*/ { NULL },
@@ -4632,12 +4633,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
     return result;
 }
 
-static void ggml_get_op_params(const struct ggml_tensor * tensor, void * params, size_t params_size) {
-    assert(params_size <= GGML_MAX_OP_PARAMS);
-    memcpy(params, tensor->op_params, params_size);
-}
-
 static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) {
+    GGML_ASSERT(tensor != NULL); // silence -Warray-bounds warnings
     assert(params_size <= GGML_MAX_OP_PARAMS);
     memcpy(tensor->op_params, params, params_size);
 }
@@ -4654,22 +4651,22 @@ static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int3
 
 struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
-        enum   ggml_type type,
-        int    n_dims,
-        const int64_t * ne) {
+        enum   ggml_type      type,
+        int                   n_dims,
+        const int64_t       * ne) {
     return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
 }
 
 struct ggml_tensor * ggml_new_tensor_1d(
         struct ggml_context * ctx,
-        enum   ggml_type type,
+        enum   ggml_type      type,
         int64_t ne0) {
     return ggml_new_tensor(ctx, type, 1, &ne0);
 }
 
 struct ggml_tensor * ggml_new_tensor_2d(
         struct ggml_context * ctx,
-        enum   ggml_type type,
+        enum   ggml_type      type,
         int64_t ne0,
         int64_t ne1) {
     const int64_t ne[2] = { ne0, ne1 };
@@ -4678,7 +4675,7 @@ struct ggml_tensor * ggml_new_tensor_2d(
 
 struct ggml_tensor * ggml_new_tensor_3d(
         struct ggml_context * ctx,
-        enum   ggml_type type,
+        enum   ggml_type      type,
         int64_t ne0,
         int64_t ne1,
         int64_t ne2) {
@@ -4988,11 +4985,6 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
     return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
 }
 
-static void ggml_set_unary_op(struct ggml_tensor * tensor, enum ggml_unary_op op) {
-    GGML_ASSERT(tensor->op = GGML_OP_UNARY);
-    ggml_set_op_params_i32(tensor, 0, (int32_t) op);
-}
-
 const char * ggml_get_name(const struct ggml_tensor * tensor) {
     return tensor->name;
 }
@@ -5031,9 +5023,11 @@ struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * nam
     char * const mem_buffer = ctx->mem_buffer;
 
     while (obj != NULL) {
-        struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs);
-        if (strcmp(cur->name, name) == 0) {
-            return cur;
+        if (obj->type == GGML_OBJECT_TENSOR) {
+            struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs);
+            if (strcmp(cur->name, name) == 0) {
+                return cur;
+            }
         }
 
         obj = obj->next;
@@ -6247,6 +6241,27 @@ struct ggml_tensor * ggml_reshape_4d(
 
 // ggml_view_1d
 
+static struct ggml_tensor * ggml_view_tensor_offset(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_dims,
+        const int64_t       * ne,
+        size_t                offset) {
+    // don't calculate an offset from an unallocated tensor
+    void * data = NULL;
+    if (a->data != NULL) {
+        data = (char *) a->data + offset;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, data);
+
+    ggml_format_name(result, "%s (view)", a->name);
+
+    ggml_set_op_params(result, &offset, sizeof(offset));
+
+    return result;
+}
+
 struct ggml_tensor * ggml_view_1d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -6259,10 +6274,7 @@ struct ggml_tensor * ggml_view_1d(
         is_node = true;
     }
 
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
-    ggml_format_name(result, "%s (view)", a->name);
-
-    ggml_set_op_params(result, &offset, sizeof(offset));
+    struct ggml_tensor * result = ggml_view_tensor_offset(ctx, a, 1, &ne0, offset);
 
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6289,10 +6301,7 @@ struct ggml_tensor * ggml_view_2d(
 
     const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
 
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
-    ggml_format_name(result, "%s (view)", a->name);
-
-    ggml_set_op_params(result, &offset, sizeof(offset));
+    struct ggml_tensor * result = ggml_view_tensor_offset(ctx, a, 2, ne, offset);
 
     result->nb[1] = nb1;
     result->nb[2] = result->nb[1]*ne1;
@@ -6325,10 +6334,7 @@ struct ggml_tensor * ggml_view_3d(
 
     const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, 1 };
 
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
-    ggml_format_name(result, "%s (view)", a->name);
-
-    ggml_set_op_params(result, &offset, sizeof(offset));
+    struct ggml_tensor * result = ggml_view_tensor_offset(ctx, a, 3, ne, offset);
 
     result->nb[1] = nb1;
     result->nb[2] = nb2;
@@ -6363,10 +6369,7 @@ struct ggml_tensor * ggml_view_4d(
 
     const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, ne3 };
 
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset);
-    ggml_format_name(result, "%s (view)", a->name);
-
-    ggml_set_op_params(result, &offset, sizeof(offset));
+    struct ggml_tensor * result = ggml_view_tensor_offset(ctx, a, 4, ne, offset);
 
     result->nb[1] = nb1;
     result->nb[2] = nb2;
@@ -6437,7 +6440,7 @@ struct ggml_tensor * ggml_permute(
     result->src[0] = a;
 
     int32_t params[] = { axis0, axis1, axis2, axis3 };
-    ggml_set_op_params(result, &params, sizeof(params));
+    ggml_set_op_params(result, params, sizeof(params));
 
     return result;
 }
@@ -6563,7 +6566,7 @@ static struct ggml_tensor * ggml_diag_mask_inf_impl(
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     int32_t params[] = { n_past, inplace ? 1 : 0 };
-    ggml_set_op_params(result, &params, sizeof(params));
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_DIAG_MASK_INF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6603,7 +6606,7 @@ static struct ggml_tensor * ggml_diag_mask_zero_impl(
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     int32_t params[] = { n_past, inplace ? 1 : 0 };
-    ggml_set_op_params(result, &params, sizeof(params));
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_DIAG_MASK_ZERO;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6719,9 +6722,9 @@ static struct ggml_tensor * ggml_rope_impl(
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     int32_t params[6] = { n_past, n_dims, mode, n_ctx };
-    memcpy(params + 4, &freq_base, sizeof(float));
+    memcpy(params + 4, &freq_base,  sizeof(float));
     memcpy(params + 5, &freq_scale, sizeof(float));
-    ggml_set_op_params(result, &params, sizeof(params));
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6750,6 +6753,18 @@ struct ggml_tensor * ggml_rope_inplace(
     return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
 }
 
+struct ggml_tensor * ggml_rope_custom(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode,
+        int                   n_ctx,
+        float                 freq_base,
+        float                 freq_scale) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, false);
+}
+
 struct ggml_tensor * ggml_rope_custom_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -6783,7 +6798,7 @@ struct ggml_tensor * ggml_rope_back(
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
     int32_t params[] = { n_past, n_dims, mode, n_ctx };
-    ggml_set_op_params(result, &params, sizeof(params));
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6814,7 +6829,7 @@ struct ggml_tensor * ggml_alibi(
 
     int32_t op_params[3] = { n_past, n_head };
     memcpy(op_params + 2, &bias_max, sizeof(float));
-    ggml_set_op_params(result, &op_params, sizeof(op_params));
+    ggml_set_op_params(result, op_params, sizeof(op_params));
 
     result->op   = GGML_OP_ALIBI;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6841,7 +6856,7 @@ struct ggml_tensor * ggml_clamp(
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
 
     float params[] = { min, max };
-    ggml_set_op_params(result, &params, sizeof(params));
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_CLAMP;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6876,11 +6891,10 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
         a->ne[2], 1, 1,
     };
-
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
     int32_t params[] = { s0, p0, d0 };
-    ggml_set_op_params(result, &params, sizeof(params));
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op = GGML_OP_CONV_1D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6893,9 +6907,9 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
 // ggml_conv_2d
 
 struct ggml_tensor * ggml_conv_2d(
-    struct ggml_context* ctx,
-    struct ggml_tensor * a,
-    struct ggml_tensor * b,
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    struct ggml_tensor  * b,
     int                  s0,
     int                  s1,
     int                  p0,
@@ -6916,11 +6930,10 @@ struct ggml_tensor * ggml_conv_2d(
         ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1),
         a->ne[3], b->ne[3],
     };
-
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { s0, s1, p0, p1, d0, d1 };
-    ggml_set_op_params(result, &params, sizeof(params));
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op = GGML_OP_CONV_2D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6970,11 +6983,10 @@ struct ggml_tensor * ggml_pool_1d(
         ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
         a->ne[1],
     };
-
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
     int32_t params[] = { op, k0, s0, p0 };
-    ggml_set_op_params(result, &params, sizeof(params));
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op = GGML_OP_POOL_1D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7008,11 +7020,10 @@ struct ggml_tensor * ggml_pool_2d(
         ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
         a->ne[2],
     };
-
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
 
     int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
-    ggml_set_op_params(result, &params, sizeof(params));
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op = GGML_OP_POOL_2D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7180,7 +7191,7 @@ struct ggml_tensor * ggml_win_part(
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { npx, npy, w };
-    ggml_set_op_params(result, &params, sizeof(params));
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_WIN_PART;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7210,7 +7221,7 @@ struct ggml_tensor * ggml_win_unpart(
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
 
     int32_t params[] = { w };
-    ggml_set_op_params(result, &params, sizeof(params));
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_WIN_UNPART;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7234,7 +7245,7 @@ static struct ggml_tensor * ggml_unary_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_set_unary_op(result, op);
+    ggml_set_op_params_i32(result, 0, (int32_t) op);
 
     result->op   = GGML_OP_UNARY;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -9457,8 +9468,8 @@ static void ggml_compute_forward_sum_rows_f32(
     for (int64_t i3 = 0; i3 < ne03; i3++) {
         for (int64_t i2 = 0; i2 < ne02; i2++) {
             for (int64_t i1 = 0; i1 < ne01; i1++) {
-                float* src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
-                float* dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
+                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
+                float * dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
                 float row_sum = 0;
                 ggml_vec_sum_f32(ne00, &row_sum, src_row);
                 dst_row[0] = row_sum;
@@ -13068,7 +13079,7 @@ static void ggml_compute_forward_pool_1d(
         const struct ggml_tensor * src0,
               struct ggml_tensor * dst) {
 
-    const int32_t* opts = (const int32_t*)dst->op_params;
+    const int32_t * opts = (const int32_t *)dst->op_params;
     enum ggml_op_pool op = opts[0];
     const int k0 = opts[1];
     const int s0 = opts[2];
@@ -16032,6 +16043,35 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg
     return result;
 }
 
+struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
+    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, GGML_GRAPH_SIZE);
+    struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
+
+    *cgraph = (struct ggml_cgraph) {
+        /*.n_nodes      =*/ 0,
+        /*.n_leafs      =*/ 0,
+        /*.nodes        =*/ { NULL },
+        /*.grads        =*/ { NULL },
+        /*.leafs        =*/ { NULL },
+        /*.hash_table   =*/ { NULL },
+        /*.perf_runs    =*/ 0,
+        /*.perf_cycles  =*/ 0,
+        /*.perf_time_us =*/ 0,
+    };
+
+    return cgraph;
+}
+
+struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor) {
+    struct ggml_cgraph * cgraph = ggml_new_graph(ctx);
+    ggml_build_forward_impl(cgraph, tensor, false);
+    return cgraph;
+}
+
+size_t ggml_graph_overhead(void) {
+    return GGML_OBJECT_SIZE + GGML_PAD(GGML_GRAPH_SIZE, GGML_MEM_ALIGN);
+}
+
 //
 // thread data
 //
@@ -16774,10 +16814,9 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
 void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
     struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
 
-    struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
-    GGML_ASSERT(buf);
+    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
 
-    cplan.work_data = buf->data;
+    cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
 
     ggml_graph_compute(cgraph, &cplan);
 }
index 6a8ce8fe7fd47583546db680bfa774cd53858506..4b4d8573ba1eaa7bec1c538810dc7f4e403c0e72 100644 (file)
@@ -156,7 +156,7 @@ endif()
 # test-grad0
 
 set(TEST_TARGET test-grad0)
-add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
 target_link_libraries(${TEST_TARGET} PRIVATE ggml)
 add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
 set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
@@ -165,7 +165,7 @@ set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_
 # test-opt
 
 set(TEST_TARGET test-opt)
-add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
 target_link_libraries(${TEST_TARGET} PRIVATE ggml)
 add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
 set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
diff --git a/tests/test-grad0.c b/tests/test-grad0.c
deleted file mode 100644 (file)
index 6d31221..0000000
+++ /dev/null
@@ -1,1525 +0,0 @@
-#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
-#include "ggml.h"
-
-#include <math.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <assert.h>
-
-#if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
-#endif
-
-#if defined(__GNUC__)
-#pragma GCC diagnostic ignored "-Wdouble-promotion"
-#endif
-
-#define MAX_NARGS 3
-
-#undef MIN
-#undef MAX
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-
-#define GGML_SILU_FP16
-
-//
-// logging
-//
-
-#if (GGML_DEBUG >= 1)
-#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG(...)
-#endif
-
-#if (GGML_DEBUG >= 5)
-#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_5(...)
-#endif
-
-#if (GGML_DEBUG >= 10)
-#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_10(...)
-#endif
-
-#define GGML_PRINT(...) printf(__VA_ARGS__)
-
-float frand(void) {
-    return (float)rand()/(float)RAND_MAX;
-}
-
-int irand(int n) {
-    if (n == 0) return 0;
-    return rand()%n;
-}
-
-void get_random_dims(int64_t * dims, int ndims) {
-    dims[0] = dims[1] = dims[2] = dims[3] = 1;
-
-    for (int i = 0; i < ndims; i++) {
-        dims[i] = 1 + irand(4);
-    }
-}
-
-struct ggml_tensor * get_random_tensor_f32(
-        struct ggml_context * ctx0,
-        int ndims,
-        int64_t ne[],
-        float fmin,
-        float fmax) {
-    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
-
-    switch (ndims) {
-        case 1:
-            for (int i0 = 0; i0 < ne[0]; i0++) {
-                ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
-            }
-            break;
-        case 2:
-            for (int i1 = 0; i1 < ne[1]; i1++) {
-                for (int i0 = 0; i0 < ne[0]; i0++) {
-                    ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                }
-            }
-            break;
-        case 3:
-            for (int i2 = 0; i2 < ne[2]; i2++) {
-                for (int i1 = 0; i1 < ne[1]; i1++) {
-                    for (int i0 = 0; i0 < ne[0]; i0++) {
-                        ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                    }
-                }
-            }
-            break;
-        case 4:
-            for (int i3 = 0; i3 < ne[3]; i3++) {
-                for (int i2 = 0; i2 < ne[2]; i2++) {
-                    for (int i1 = 0; i1 < ne[1]; i1++) {
-                        for (int i0 = 0; i0 < ne[0]; i0++) {
-                            ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                        }
-                    }
-                }
-            }
-            break;
-        default:
-            assert(false);
-    };
-
-    return result;
-}
-
-struct ggml_tensor * get_random_tensor_f16(
-        struct ggml_context * ctx0,
-        int ndims,
-        int64_t ne[],
-        float fmin,
-        float fmax) {
-    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F16, ndims, ne);
-
-    switch (ndims) {
-        case 1:
-            for (int i0 = 0; i0 < ne[0]; i0++) {
-                ((ggml_fp16_t *)result->data)[i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
-            }
-            break;
-        case 2:
-            for (int i1 = 0; i1 < ne[1]; i1++) {
-                for (int i0 = 0; i0 < ne[0]; i0++) {
-                    ((ggml_fp16_t *)result->data)[i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
-                }
-            }
-            break;
-        case 3:
-            for (int i2 = 0; i2 < ne[2]; i2++) {
-                for (int i1 = 0; i1 < ne[1]; i1++) {
-                    for (int i0 = 0; i0 < ne[0]; i0++) {
-                        ((ggml_fp16_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
-                    }
-                }
-            }
-            break;
-        case 4:
-            for (int i3 = 0; i3 < ne[3]; i3++) {
-                for (int i2 = 0; i2 < ne[2]; i2++) {
-                    for (int i1 = 0; i1 < ne[1]; i1++) {
-                        for (int i0 = 0; i0 < ne[0]; i0++) {
-                            ((ggml_fp16_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
-                        }
-                    }
-                }
-            }
-            break;
-        default:
-            assert(false);
-    };
-
-    return result;
-}
-
-struct ggml_tensor * get_random_tensor_i32(
-        struct ggml_context * ctx0,
-        int ndims,
-        int64_t ne[],
-        int32_t imin,
-        int32_t imax) {
-    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_I32, ndims, ne);
-
-    switch (ndims) {
-        case 1:
-            for (int i0 = 0; i0 < ne[0]; i0++) {
-                ((int32_t *)result->data)[i0] = irand(imax - imin) + imin;
-            }
-            break;
-        case 2:
-            for (int i1 = 0; i1 < ne[1]; i1++) {
-                for (int i0 = 0; i0 < ne[0]; i0++) {
-                    ((int32_t *)result->data)[i1*ne[0] + i0] = irand(imax - imin) + imin;
-                }
-            }
-            break;
-        case 3:
-            for (int i2 = 0; i2 < ne[2]; i2++) {
-                for (int i1 = 0; i1 < ne[1]; i1++) {
-                    for (int i0 = 0; i0 < ne[0]; i0++) {
-                        ((int32_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin;
-                    }
-                }
-            }
-            break;
-        case 4:
-            for (int i3 = 0; i3 < ne[3]; i3++) {
-                for (int i2 = 0; i2 < ne[2]; i2++) {
-                    for (int i1 = 0; i1 < ne[1]; i1++) {
-                        for (int i0 = 0; i0 < ne[0]; i0++) {
-                            ((int32_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin;
-                        }
-                    }
-                }
-            }
-            break;
-        default:
-            assert(false);
-    };
-
-    return result;
-}
-
-void print_elements(const char* label, const struct ggml_tensor * t) {
-    if (!t) {
-        printf("%s: %s = null\n", __func__, label);
-        return;
-    }
-    const int nelements = ggml_nelements(t);
-    printf("%s: %s = [", __func__, label);
-    for (int k = 0; k < nelements; ++k) {
-        if (k > 0) { printf(", "); }
-        printf("%.5f", ggml_get_f32_1d(t, k));
-    }
-    printf("] shape: [");
-    for (int k = 0; k < t->n_dims; ++k) {
-        if (k > 0) { printf(", "); }
-        printf("%d", (int)t->ne[k]);
-    }
-    printf("]\n");
-
-}
-
-bool check_gradient(
-        const char * op_name,
-        struct ggml_context * ctx0,
-        struct ggml_tensor * x[],
-        struct ggml_tensor * f,
-        int ndims,
-        int nargs,
-        float eps,
-        float max_error_abs,
-        float max_error_rel) {
-
-    static int n_threads = -1;
-    if (n_threads < 0) {
-        n_threads = GGML_DEFAULT_N_THREADS;
-
-        const char *env = getenv("GGML_N_THREADS");
-        if (env) {
-            n_threads = atoi(env);
-        }
-
-        printf("GGML_N_THREADS = %d\n", n_threads);
-    }
-
-    struct ggml_cgraph gf = ggml_build_forward (f);
-    struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
-
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
-
-    ggml_graph_reset  (&gf);
-    ggml_set_f32      (f->grad, 1.0f);
-
-    ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
-
-    // ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
-    // ggml_graph_dump_dot(&gb, &gf,  "test-grad0-backward.dot");
-
-    for (int i = 0; i < nargs; ++i) {
-        const int nelements = ggml_nelements(x[i]);
-        for (int k = 0; k < nelements; ++k) {
-            // compute gradient using finite differences
-            const float x0 = ggml_get_f32_1d(x[i], k);
-            const float xm = x0 - eps;
-            const float xp = x0 + eps;
-            ggml_set_f32_1d(x[i], k, xp);
-
-            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
-
-            const float f0 = ggml_get_f32_1d(f, 0);
-
-            ggml_set_f32_1d(x[i], k, xm);
-
-            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
-
-            const float f1 = ggml_get_f32_1d(f, 0);
-            const float g0 = (f0 - f1)/(2.0f*eps);
-
-            ggml_set_f32_1d(x[i], k, x0);
-
-            // compute gradient using backward graph
-            ggml_graph_reset  (&gf);
-            ggml_set_f32      (f->grad, 1.0f);
-
-            ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
-
-            const float g1 = ggml_get_f32_1d(x[i]->grad, k);
-
-            const float error_abs = fabsf(g0 - g1);
-            const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
-
-            if (error_abs > max_error_abs || error_rel > max_error_rel) {
-                printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
-                            op_name, ndims, i, k, x0, xm, xp, f0, f1, g0, g1, eps, error_abs, error_rel);
-                //assert(false);
-                return false;
-            }
-        }
-    }
-
-    return true;
-}
-
-// TODO: clean-up this ..
-bool check_mat_mul(
-        const struct ggml_tensor * y,
-        const struct ggml_tensor * x0,
-        const struct ggml_tensor * x1) {
-    float * dst  = (float *) y->data;
-    float * src0 = (float *) x0->data;
-    float * src1 = (float *) x1->data;
-
-    const int nc = x0->ne[1];
-    const int nr = x1->ne[1];
-    const int nk = x0->ne[0];
-
-    GGML_PRINT_DEBUG("check_mat_mul: nc=%d, nr=%d, nk=%d\n", nc, nr, nk);
-
-    GGML_PRINT_DEBUG("x0:\n");
-    for (int j = 0; j < x0->ne[1]; ++j) {
-        for (int i = 0; i < x0->ne[0]; ++i) {
-            GGML_PRINT_DEBUG("%6.3f ", src0[j*nk + i]);
-        }
-        GGML_PRINT_DEBUG("\n");
-    }
-    GGML_PRINT_DEBUG("\n");
-
-    GGML_PRINT_DEBUG("x1:\n");
-    for (int j = 0; j < x1->ne[1]; ++j) {
-        for (int i = 0; i < x1->ne[0]; ++i) {
-            GGML_PRINT_DEBUG("%6.3f ", src1[j*nk + i]);
-        }
-        GGML_PRINT_DEBUG("\n");
-    }
-    GGML_PRINT_DEBUG("\n");
-
-    GGML_PRINT_DEBUG("y: n_dims = %d, (%lld, %lld)\n", y->n_dims, y->ne[0], y->ne[1]);
-    for (int j = 0; j < y->ne[1]; ++j) {
-        for (int i = 0; i < y->ne[0]; ++i) {
-            GGML_PRINT_DEBUG("%6.3f ", dst[j*nr + i]);
-        }
-        GGML_PRINT_DEBUG("\n");
-    }
-
-    for (int i = 0; i < nr; ++i) {
-        for (int j = 0; j < nc; ++j) {
-            float sum = 0.0f;
-
-            for (int k = 0; k < nk; ++k) {
-                sum += src0[j*nk + k]*src1[i*nk + k];
-            }
-
-            if (fabsf(dst[i*nc + j] - sum) > 1e-5f) {
-                fprintf(stderr, "check_mat_mul: dst[%d] = %f, sum = %f\n", i*nc + j, dst[i*nc + j], sum);
-                assert(false);
-                return false;
-            }
-        }
-    }
-
-    return true;
-}
-
-#define NUM_PERMUTATIONS (4*3*2*1)
-
-int main(int argc, const char ** argv) {
-    struct ggml_init_params params = {
-        .mem_size   = 128*1024*1024,
-        .mem_buffer = NULL,
-        .no_alloc   = false,
-    };
-
-    int64_t ne[4];
-
-    int all_permutations[4 * NUM_PERMUTATIONS];
-    {
-        int count = 0;
-        for (int ax0=0; ax0<4; ++ax0) {
-            for (int ax1=0; ax1<4; ++ax1) {
-                if (ax1 == ax0) continue;
-                for (int ax2=0; ax2<4; ++ax2) {
-                    if (ax2 == ax0) continue;
-                    if (ax2 == ax1) continue;
-                    for (int ax3=0; ax3<4; ++ax3) {
-                        if (ax3 == ax0) continue;
-                        if (ax3 == ax1) continue;
-                        if (ax3 == ax2) continue;
-                        assert(count < NUM_PERMUTATIONS);
-                        all_permutations[count*4+0] = ax0;
-                        all_permutations[count*4+1] = ax1;
-                        all_permutations[count*4+2] = ax2;
-                        all_permutations[count*4+3] = ax3;
-                        ++count;
-                    }
-                }
-            }
-        }
-    }
-
-
-    // original loop: 1000
-    int niter = 4;
-    const char *env = getenv("GGML_NLOOP");
-    if (env != NULL) {
-        niter = atoi(env);
-    }
-    if (argc > 1) {
-        niter = atoi(argv[1]);
-    }
-    for (int iter = 0; iter < niter; ++iter) {
-        printf("test-grad0: iter:%d/%d\n", iter, niter);
-        struct ggml_context * ctx0 = ggml_init(params);
-
-        get_random_dims(ne, 4);
-
-        struct ggml_tensor * x[MAX_NARGS];
-
-        // add f32
-        {
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
-
-                check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f);
-            }
-        }
-
-        // add f16
-        {
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
-
-                check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f);
-            }
-        }
-
-        // sub
-        {
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sub(ctx0, x[0], x[1]));
-
-                check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
-            }
-        }
-
-        // mul
-        {
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_mul(ctx0, x[0], x[1]));
-
-                check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // div
-        {
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 0.5f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1]));
-
-                check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f);
-            }
-        }
-
-        // sqr
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, x[0]));
-
-                check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // sqrt
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0]));
-
-                check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f);
-            }
-        }
-
-        // log
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_log(ctx0, x[0]));
-
-                check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f);
-            }
-        }
-
-        // sum
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, x[0]);
-
-                check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
-            }
-        }
-
-
-        // sum_rows
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sum_rows(ctx0, x[0])));
-
-                check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
-            }
-        }
-
-        // mean, not yet fully implemented
-        if(0)
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0]));
-
-                check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
-            }
-        }
-
-        // argmax
-        if (0)
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0]));
-
-                check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
-            }
-        }
-
-        // repeat
-        {
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-
-            ne2[0] = ne[0] * ne2[0];
-            ne2[1] = ne[1] * ne2[1];
-            ne2[2] = 1;
-            ne2[3] = 1;
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1]))));
-
-                check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
-            }
-        }
-
-        // repeat back
-        {
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-
-            ne2[0] = ne[0] * ne2[0];
-            ne2[1] = ne[1] * ne2[1];
-            ne2[2] = 1;
-            ne2[3] = 1;
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0]))));
-
-                check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
-            }
-        }
-
-        // abs (finite differences do not work)
-        //{
-        //    const int nargs = 1;
-
-        //    for (int ndims = 1; ndims <= 2; ++ndims) {
-        //        for (int i = 0; i < nargs; ++i) {
-        //            x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-        //            ggml_set_param(ctx0, x[i]);
-        //        }
-
-        //        struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0]));
-
-        //        check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f);
-        //    }
-        //}
-
-        // sgn
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0]));
-
-                check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
-            }
-        }
-
-        // neg
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0]));
-
-                check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
-            }
-        }
-
-        // step
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0]));
-
-                check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
-            }
-        }
-
-        // tanh, not yet fully implemented
-        if(0)
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0]));
-
-                check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
-            }
-        }
-
-        // mul_mat
-        {
-            const int nargs = 2;
-
-            for (int ndims = 2; ndims <= 2; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                {
-                    int64_t ne2[4];
-                    get_random_dims(ne2, 4);
-                    ne2[0] = ne[0];
-                    x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                }
-
-                ggml_set_param(ctx0, x[0]);
-                ggml_set_param(ctx0, x[1]);
-
-                struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
-                struct ggml_tensor * f = ggml_sum(ctx0, m);
-
-                GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
-
-                check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-                check_mat_mul(m, x[1], x[0]);
-            }
-        }
-
-        // elu, not yet fully implemented
-        if(0)
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0]));
-
-                check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
-            }
-        }
-
-        // relu
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0]));
-
-                check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // gelu, not yet fully implemented
-        if(0)
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0]));
-
-                check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
-            }
-        }
-
-        // silu
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_silu(ctx0, x[0]));
-
-#ifdef GGML_SILU_FP16
-                // due to GGML_SILU_FP16 the finite difference method will be slightly wrong -> increase error bounds.
-                check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY);
-#else
-                check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-#endif
-            }
-        }
-
-        // rms_norm
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
-
-                check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
-            }
-        }
-
-        // scale
-        {
-            const int nargs = 2;
-
-            int64_t ne2[4];
-            ne2[0] = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-
-                ggml_set_param(ctx0, x[0]);
-                ggml_set_param(ctx0, x[1]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], x[1]));
-
-                check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // cpy f32
-        {
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-                // x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
-
-                check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // cpy f16
-        {
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-                // x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
-
-                check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
-            }
-        }
-
-        // reshape (1d->nd)
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                int64_t ne2[4];
-                ne2[0] = 1;
-                ne2[1] = 1;
-                ne2[2] = 1;
-                ne2[3] = 1;
-                for (int i = 0; i < ndims; ++i) {
-                    ne2[0] *= ne[i];
-                }
-                x[0] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
-                check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // reshape (nd->1d)
-        {
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                int64_t ne2[4];
-                ne2[0] = 1;
-                ne2[1] = 1;
-                ne2[2] = 1;
-                ne2[3] = 1;
-                for (int i = 0; i < ndims; ++i) {
-                    ne2[0] *= ne[i];
-                }
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
-                check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // acc 1d
-        {
-            int64_t ne2[4] = { 1, 1, 1, 1 };
-
-            const int nargs = 2;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 1);
-                while ((ne2[0] > ne[0]) || (ne2[0] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 1);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
-                const int offset = irand(max_offset) * ggml_element_size(x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
-
-                check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // acc 2d
-        {
-            int64_t ne2[4]         = { 1, 1, 1, 1 };
-            int64_t max_offsets[4] = { 0, 0, 0, 0 };
-            int64_t offsets[4]     = { 0, 0, 0, 0 };
-
-            const int nargs = 2;
-            for (int ndims = 2; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 2);
-                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[0]*ne2[1] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 2);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
-                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
-                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
-                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
-                const int offset = offsets[0] + offsets[1];
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
-
-                check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // acc 3d
-        {
-            int64_t ne2[4]         = { 1, 1, 1, 1 };
-            int64_t max_offsets[4] = { 0, 0, 0, 0 };
-            int64_t offsets[4]     = { 0, 0, 0, 0 };
-
-            const int nargs = 2;
-            for (int ndims = 3; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 3);
-                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[2] > ne[2]) || (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 3);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 3, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
-                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
-                max_offsets[2] = MAX(0, x[0]->ne[2] - x[1]->ne[2]);
-                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
-                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
-                offsets[2] = irand(max_offsets[2]) * x[0]->nb[2];
-                const int offset = offsets[0] + offsets[1] + offsets[2];
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
-
-                check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // acc 4d
-        {
-            int64_t ne2[4]         = { 1, 1, 1, 1 };
-            int64_t max_offsets[4] = { 0, 0, 0, 0 };
-            int64_t offsets[4]     = { 0, 0, 0, 0 };
-
-            const int nargs = 2;
-            for (int ndims = 4; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 4);
-                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[2] > ne[2]) || (ne2[3] > ne[3]) || (ne2[0]*ne2[1]*ne2[2]*ne2[3] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 4);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
-                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
-                max_offsets[2] = MAX(0, x[0]->ne[2] - x[1]->ne[2]);
-                max_offsets[3] = MAX(0, x[0]->ne[3] - x[1]->ne[3]);
-                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
-                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
-                offsets[2] = irand(max_offsets[2]) * x[0]->nb[2];
-                offsets[3] = irand(max_offsets[3]) * x[0]->nb[3];
-                const int offset = offsets[0] + offsets[1] + offsets[2] + offsets[3];
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
-
-                check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // set_1d
-        {
-            int64_t ne2[4];
-
-            const int nargs = 2;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 1);
-                while ((ne2[0] > ne[0]) || (ne2[0] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 1);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
-                const int offset = irand(max_offset) * ggml_element_size(x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_1d(ctx0, x[0], x[1], offset));
-
-                check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // set_2d
-        {
-            int64_t ne2[4];
-            int64_t max_offsets[4] = { 0, 0, 0, 0 };
-            int64_t offsets[4]     = { 0, 0, 0, 0 };
-
-            const int nargs = 1;
-            for (int ndims = 2; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 2);
-                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[0]*ne2[1] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 2);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
-                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
-                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
-                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
-                const int offset = offsets[0] + offsets[1];
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_2d(ctx0, x[0], x[1], x[1]->nb[1], offset));
-
-                check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // view_1d
-        {
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int k0 = irand(ggml_nelements(x[0]));
-                const int k1 = irand(ggml_nelements(x[0]));
-                const int i0 = MIN(k0, k1);
-                const int i1 = MAX(k0, k1);
-
-                const int offset = i0 * sizeof(float);
-                const int nelem  = i1 - i0;
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_1d(ctx0, x[0], nelem, offset));
-
-                check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // view_2d
-        {
-            int64_t ne2[4];
-            int64_t nb2[4];
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-
-                get_random_dims(ne2, 2);
-                while (ne2[0]*ne2[1] > ggml_nelements(x[0])) {
-                    get_random_dims(ne2, 2);
-                }
-                const int count = ne2[0]*ne2[1];
-
-                nb2[0] = sizeof(float);
-                nb2[1] = nb2[0]*ne2[0];
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int max_offset = ggml_nelements(x[0]) - count;
-                const int offset = irand(max_offset+1) * sizeof(float);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_2d(ctx0, x[0], ne2[0], ne2[1], nb2[1], offset));
-
-                check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // view_3d
-        {
-            int64_t ne2[4] = {1,1,1,1};
-            int64_t nb2[4] = {0,0,0,0};
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-
-                get_random_dims(ne2, 3);
-                while (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0])) {
-                    get_random_dims(ne2, 3);
-                }
-                const int count = ne2[0]*ne2[1]*ne2[2];
-
-                nb2[0] = sizeof(float);
-                nb2[1] = nb2[0]*ne2[0];
-                nb2[2] = nb2[1]*ne2[1];
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int max_offset = ggml_nelements(x[0]) - count;
-                const int offset = irand(max_offset+1) * sizeof(float);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_3d(ctx0, x[0], ne2[0], ne2[1], ne2[2], nb2[1], nb2[2], offset));
-
-                check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // permute
-        {
-            int64_t ne2[4];
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 4; ++ndims)
-            {
-                // ggml_permute will set axes of dimensions below n_dims to 1.
-                // to make ggml_permute work correctly on all axes,
-                // the input tensor needs maximal n_dim of 4.
-                for (int i=0; i<ndims; ++i) {
-                    ne2[i] = ne[i];
-                }
-                for (int i=ndims; i<4; ++i) {
-                    ne2[i] = 1;
-                }
-                x[0] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int p = irand(NUM_PERMUTATIONS);
-                const int ax0 = all_permutations[p*4+0];
-                const int ax1 = all_permutations[p*4+1];
-                const int ax2 = all_permutations[p*4+2];
-                const int ax3 = all_permutations[p*4+3];
-
-                // sum requires contiguous tensor rows
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, x[0], ax0, ax1, ax2, ax3)));
-
-                check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // transpose
-        {
-            int64_t ne2[4];
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 4; ++ndims)
-            {
-                // ggml_transpose will set axes of dimensions below n_dims to 1.
-                // to make ggml_transpose work correctly on all axes,
-                // the input tensor needs maximal n_dim of 4.
-                for (int i=0; i<ndims; ++i) {
-                    ne2[i] = ne[i];
-                }
-                for (int i=ndims; i<4; ++i) {
-                    ne2[i] = 1;
-                }
-                x[0] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
-
-                ggml_set_param(ctx0, x[0]);
-
-                // sum requires contiguous tensor rows
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, x[0])));
-
-                check_gradient("transpose", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // get_rows
-        {
-            int64_t ne2[4] = {ne[0], ne[1], 1, 1};
-            int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1};
-            const int nargs = 1;
-            const int ndims = 2;
-            x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-            x[1] = get_random_tensor_i32(ctx0, 1, ne3, 0, ne2[1]);
-
-            ggml_set_param(ctx0, x[0]);
-
-            struct ggml_tensor * f = ggml_sum(ctx0, ggml_get_rows(ctx0, x[0], x[1]));
-
-            check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-        }
-
-        // diag_mask_inf
-        {
-            const int nargs = 1;
-            const int ndims = 2;
-
-            x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-            ggml_set_param(ctx0, x[0]);
-
-            int n_past = irand(ne[0]);
-
-            struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_inf(ctx0, x[0], n_past));
-
-            check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-        }
-
-        // diag_mask_zero
-        {
-            const int nargs = 1;
-            const int ndims = 2;
-
-            x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-            ggml_set_param(ctx0, x[0]);
-
-            int n_past = irand(ne[0]);
-
-            struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_zero(ctx0, x[0], n_past));
-
-            check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-        }
-
-        // softmax
-        {
-            const int nargs = 1;
-
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-
-            for (int ndims = 1; ndims <= 3; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_soft_max(ctx0, x[0]));
-
-                check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-            }
-        }
-
-        // cross_entropy_loss
-        {
-            const int nargs = 1;
-
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-
-            for (int ndims = 1; ndims <= 3; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cross_entropy_loss(ctx0, x[0], x[1]));
-
-                check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-1f, 1e-2f, INFINITY);
-                // finite differences regularly fails!
-            }
-        }
-
-        // rope f32
-        {
-            const int nargs = 1;
-
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-            ne2[0] += ne2[0] % 2;
-            int n_rot = ne2[0];
-
-            for (int ndims = 3; ndims <= 4; ++ndims) {
-                for (int mode = 0; mode < 4; ++mode) {
-                    for (int n_past = 1; n_past < ne2[2]; ++n_past) {
-                        x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-
-                        ggml_set_param(ctx0, x[0]);
-
-                        const bool skip_past = (mode & 1);
-                        if (skip_past) {
-                            // we have no past, so this would have to work on uninitialized memory.
-                            // we only test the gradients here;
-                            // skip_past should have no influence on gradient computation.
-                            // so when other modes work, we assume that this does as well.
-                            continue;
-                        }
-
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
-
-                        GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
-                        check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
-                    }
-                }
-            }
-        }
-
-        // rope f16
-        {
-            const int nargs = 1;
-
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-            ne2[0] += ne2[0] % 2;
-            int n_rot = ne2[0];
-
-            for (int ndims = 3; ndims <= 4; ++ndims) {
-                for (int mode = 0; mode < 4; ++mode) {
-                    for (int n_past = 1; n_past < ne2[2]; ++n_past) {
-                        x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f);
-
-                        ggml_set_param(ctx0, x[0]);
-
-                        const bool skip_past = (mode & 1);
-                        if (skip_past) {
-                            // we have no past, so this would have to work on uninitialized memory.
-                            // we only test the gradients here;
-                            // skip_past should have no influence on gradient computation.
-                            // so when other modes work, we assume that this does as well.
-                            continue;
-                        }
-
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
-
-                        GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
-                        check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
-                    }
-                }
-            }
-        }
-
-        // flash_attn f32
-        {
-            const int nargs = 3;
-
-            int64_t ne2[4];
-
-            get_random_dims(ne2, 4);
-            int64_t D = ne2[0];
-            int64_t N = ne2[1];
-            int64_t M = ne2[2] + N;
-            int64_t B = ne2[3];
-
-            for (int masked = 0; masked <= 1; ++masked) {
-                for (int ndims = 2; ndims <= 4; ++ndims) {
-                    int64_t neq[4] = { D, N, B, ne[3] };
-                    int64_t nek[4] = { D, M, B, ne[3] };
-                    int64_t nev[4] = { M, D, B, ne[3] };
-                    if (ndims == 2) {
-                        neq[2] = 1; neq[3] = 1;
-                        nek[2] = 1; nek[3] = 1;
-                        nev[2] = 1; nev[3] = 1;
-                    } else if (ndims == 3) {
-                        neq[3] = 1;
-                        nek[3] = 1;
-                        nev[3] = 1;
-                    }
-                    x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
-                    x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
-                    x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
-                    ggml_set_param(ctx0, x[0]);
-                    ggml_set_param(ctx0, x[1]);
-                    ggml_set_param(ctx0, x[2]);
-
-                    struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
-
-                    check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
-                }
-            }
-        }
-
-        // flash_attn f16, not yet fully implemented
-        if(0)
-        {
-            const int nargs = 3;
-
-            int64_t ne2[4];
-
-            get_random_dims(ne2, 4);
-            int64_t D = ne2[0];
-            int64_t N = ne2[1];
-            int64_t M = ne2[2] + N;
-            int64_t B = ne2[3];
-
-            for (int masked = 0; masked <= 1; ++masked) {
-                for (int ndims = 2; ndims <= 4; ++ndims) {
-                    int64_t neq[4] = { D, N, B, ne[3] };
-                    int64_t nek[4] = { D, M, B, ne[3] };
-                    int64_t nev[4] = { M, D, B, ne[3] };
-                    if (ndims == 2) {
-                        neq[2] = 1; neq[3] = 1;
-                        nek[2] = 1; nek[3] = 1;
-                        nev[2] = 1; nev[3] = 1;
-                    } else if (ndims == 3) {
-                        neq[3] = 1;
-                        nek[3] = 1;
-                        nev[3] = 1;
-                    }
-                    x[0] = get_random_tensor_f16(ctx0, ndims, neq, -0.1250f, 0.1250f);
-                    x[1] = get_random_tensor_f16(ctx0, ndims, nek, -0.1250f, 0.1250f);
-                    x[2] = get_random_tensor_f16(ctx0, ndims, nev, -0.1250f, 0.1250f);
-                    ggml_set_param(ctx0, x[0]);
-                    ggml_set_param(ctx0, x[1]);
-                    ggml_set_param(ctx0, x[2]);
-
-                    struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
-
-                    check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
-                }
-            }
-        }
-        ggml_free(ctx0);
-    }
-
-    return 0;
-}
diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp
new file mode 100644 (file)
index 0000000..75a698d
--- /dev/null
@@ -0,0 +1,1525 @@
+#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
+#include "ggml.h"
+
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cassert>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Wdouble-promotion"
+#endif
+
+#define MAX_NARGS 3
+
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+#define GGML_SILU_FP16
+
+//
+// logging
+//
+
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#if (GGML_DEBUG >= 5)
+#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_5(...)
+#endif
+
+#if (GGML_DEBUG >= 10)
+#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_10(...)
+#endif
+
+#define GGML_PRINT(...) printf(__VA_ARGS__)
+
+static float frand(void) {
+    return (float)rand()/(float)RAND_MAX;
+}
+
+static int irand(int n) {
+    if (n == 0) return 0;
+    return rand()%n;
+}
+
+static void get_random_dims(int64_t * dims, int ndims) {
+    dims[0] = dims[1] = dims[2] = dims[3] = 1;
+
+    for (int i = 0; i < ndims; i++) {
+        dims[i] = 1 + irand(4);
+    }
+}
+
+static struct ggml_tensor * get_random_tensor_f32(
+        struct ggml_context * ctx0,
+        int ndims,
+        int64_t ne[],
+        float fmin,
+        float fmax) {
+    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
+
+    switch (ndims) {
+        case 1:
+            for (int i0 = 0; i0 < ne[0]; i0++) {
+                ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
+            }
+            break;
+        case 2:
+            for (int i1 = 0; i1 < ne[1]; i1++) {
+                for (int i0 = 0; i0 < ne[0]; i0++) {
+                    ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                }
+            }
+            break;
+        case 3:
+            for (int i2 = 0; i2 < ne[2]; i2++) {
+                for (int i1 = 0; i1 < ne[1]; i1++) {
+                    for (int i0 = 0; i0 < ne[0]; i0++) {
+                        ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                    }
+                }
+            }
+            break;
+        case 4:
+            for (int i3 = 0; i3 < ne[3]; i3++) {
+                for (int i2 = 0; i2 < ne[2]; i2++) {
+                    for (int i1 = 0; i1 < ne[1]; i1++) {
+                        for (int i0 = 0; i0 < ne[0]; i0++) {
+                            ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            assert(false);
+    };
+
+    return result;
+}
+
+static struct ggml_tensor * get_random_tensor_f16(
+        struct ggml_context * ctx0,
+        int ndims,
+        int64_t ne[],
+        float fmin,
+        float fmax) {
+    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F16, ndims, ne);
+
+    switch (ndims) {
+        case 1:
+            for (int i0 = 0; i0 < ne[0]; i0++) {
+                ((ggml_fp16_t *)result->data)[i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
+            }
+            break;
+        case 2:
+            for (int i1 = 0; i1 < ne[1]; i1++) {
+                for (int i0 = 0; i0 < ne[0]; i0++) {
+                    ((ggml_fp16_t *)result->data)[i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
+                }
+            }
+            break;
+        case 3:
+            for (int i2 = 0; i2 < ne[2]; i2++) {
+                for (int i1 = 0; i1 < ne[1]; i1++) {
+                    for (int i0 = 0; i0 < ne[0]; i0++) {
+                        ((ggml_fp16_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
+                    }
+                }
+            }
+            break;
+        case 4:
+            for (int i3 = 0; i3 < ne[3]; i3++) {
+                for (int i2 = 0; i2 < ne[2]; i2++) {
+                    for (int i1 = 0; i1 < ne[1]; i1++) {
+                        for (int i0 = 0; i0 < ne[0]; i0++) {
+                            ((ggml_fp16_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            assert(false);
+    };
+
+    return result;
+}
+
+static struct ggml_tensor * get_random_tensor_i32(
+        struct ggml_context * ctx0,
+        int ndims,
+        int64_t ne[],
+        int32_t imin,
+        int32_t imax) {
+    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_I32, ndims, ne);
+
+    switch (ndims) {
+        case 1:
+            for (int i0 = 0; i0 < ne[0]; i0++) {
+                ((int32_t *)result->data)[i0] = irand(imax - imin) + imin;
+            }
+            break;
+        case 2:
+            for (int i1 = 0; i1 < ne[1]; i1++) {
+                for (int i0 = 0; i0 < ne[0]; i0++) {
+                    ((int32_t *)result->data)[i1*ne[0] + i0] = irand(imax - imin) + imin;
+                }
+            }
+            break;
+        case 3:
+            for (int i2 = 0; i2 < ne[2]; i2++) {
+                for (int i1 = 0; i1 < ne[1]; i1++) {
+                    for (int i0 = 0; i0 < ne[0]; i0++) {
+                        ((int32_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin;
+                    }
+                }
+            }
+            break;
+        case 4:
+            for (int i3 = 0; i3 < ne[3]; i3++) {
+                for (int i2 = 0; i2 < ne[2]; i2++) {
+                    for (int i1 = 0; i1 < ne[1]; i1++) {
+                        for (int i0 = 0; i0 < ne[0]; i0++) {
+                            ((int32_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin;
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            assert(false);
+    };
+
+    return result;
+}
+
+static void print_elements(const char* label, const struct ggml_tensor * t) {
+    if (!t) {
+        printf("%s: %s = null\n", __func__, label);
+        return;
+    }
+    const int nelements = ggml_nelements(t);
+    printf("%s: %s = [", __func__, label);
+    for (int k = 0; k < nelements; ++k) {
+        if (k > 0) { printf(", "); }
+        printf("%.5f", ggml_get_f32_1d(t, k));
+    }
+    printf("] shape: [");
+    for (int k = 0; k < t->n_dims; ++k) {
+        if (k > 0) { printf(", "); }
+        printf("%d", (int)t->ne[k]);
+    }
+    printf("]\n");
+
+}
+
+static bool check_gradient(
+        const char * op_name,
+        struct ggml_context * ctx0,
+        struct ggml_tensor * x[],
+        struct ggml_tensor * f,
+        int ndims,
+        int nargs,
+        float eps,
+        float max_error_abs,
+        float max_error_rel) {
+
+    static int n_threads = -1;
+    if (n_threads < 0) {
+        n_threads = GGML_DEFAULT_N_THREADS;
+
+        const char *env = getenv("GGML_N_THREADS");
+        if (env) {
+            n_threads = atoi(env);
+        }
+
+        printf("GGML_N_THREADS = %d\n", n_threads);
+    }
+
+    struct ggml_cgraph gf = ggml_build_forward (f);
+    struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+
+    ggml_graph_reset  (&gf);
+    ggml_set_f32      (f->grad, 1.0f);
+
+    ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+
+    // ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
+    // ggml_graph_dump_dot(&gb, &gf,  "test-grad0-backward.dot");
+
+    for (int i = 0; i < nargs; ++i) {
+        const int nelements = ggml_nelements(x[i]);
+        for (int k = 0; k < nelements; ++k) {
+            // compute gradient using finite differences
+            const float x0 = ggml_get_f32_1d(x[i], k);
+            const float xm = x0 - eps;
+            const float xp = x0 + eps;
+            ggml_set_f32_1d(x[i], k, xp);
+
+            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+
+            const float f0 = ggml_get_f32_1d(f, 0);
+
+            ggml_set_f32_1d(x[i], k, xm);
+
+            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+
+            const float f1 = ggml_get_f32_1d(f, 0);
+            const float g0 = (f0 - f1)/(2.0f*eps);
+
+            ggml_set_f32_1d(x[i], k, x0);
+
+            // compute gradient using backward graph
+            ggml_graph_reset  (&gf);
+            ggml_set_f32      (f->grad, 1.0f);
+
+            ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+
+            const float g1 = ggml_get_f32_1d(x[i]->grad, k);
+
+            const float error_abs = fabsf(g0 - g1);
+            const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
+
+            if (error_abs > max_error_abs || error_rel > max_error_rel) {
+                printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
+                            op_name, ndims, i, k, x0, xm, xp, f0, f1, g0, g1, eps, error_abs, error_rel);
+                //assert(false);
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
+// TODO: clean-up this ..
+static bool check_mat_mul(
+        const struct ggml_tensor * y,
+        const struct ggml_tensor * x0,
+        const struct ggml_tensor * x1) {
+    float * dst  = (float *) y->data;
+    float * src0 = (float *) x0->data;
+    float * src1 = (float *) x1->data;
+
+    const int nc = x0->ne[1];
+    const int nr = x1->ne[1];
+    const int nk = x0->ne[0];
+
+    GGML_PRINT_DEBUG("check_mat_mul: nc=%d, nr=%d, nk=%d\n", nc, nr, nk);
+
+    GGML_PRINT_DEBUG("x0:\n");
+    for (int j = 0; j < x0->ne[1]; ++j) {
+        for (int i = 0; i < x0->ne[0]; ++i) {
+            GGML_PRINT_DEBUG("%6.3f ", src0[j*nk + i]);
+        }
+        GGML_PRINT_DEBUG("\n");
+    }
+    GGML_PRINT_DEBUG("\n");
+
+    GGML_PRINT_DEBUG("x1:\n");
+    for (int j = 0; j < x1->ne[1]; ++j) {
+        for (int i = 0; i < x1->ne[0]; ++i) {
+            GGML_PRINT_DEBUG("%6.3f ", src1[j*nk + i]);
+        }
+        GGML_PRINT_DEBUG("\n");
+    }
+    GGML_PRINT_DEBUG("\n");
+
+    GGML_PRINT_DEBUG("y: n_dims = %d, (%lld, %lld)\n", y->n_dims, y->ne[0], y->ne[1]);
+    for (int j = 0; j < y->ne[1]; ++j) {
+        for (int i = 0; i < y->ne[0]; ++i) {
+            GGML_PRINT_DEBUG("%6.3f ", dst[j*nr + i]);
+        }
+        GGML_PRINT_DEBUG("\n");
+    }
+
+    for (int i = 0; i < nr; ++i) {
+        for (int j = 0; j < nc; ++j) {
+            float sum = 0.0f;
+
+            for (int k = 0; k < nk; ++k) {
+                sum += src0[j*nk + k]*src1[i*nk + k];
+            }
+
+            if (fabsf(dst[i*nc + j] - sum) > 1e-5f) {
+                fprintf(stderr, "check_mat_mul: dst[%d] = %f, sum = %f\n", i*nc + j, dst[i*nc + j], sum);
+                assert(false);
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
+#define NUM_PERMUTATIONS (4*3*2*1)
+
+int main(int argc, const char ** argv) {
+    struct ggml_init_params params = {
+        /* .mem_size   = */ 128*1024*1024,
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ false,
+    };
+
+    int64_t ne[4];
+
+    int all_permutations[4 * NUM_PERMUTATIONS];
+    {
+        int count = 0;
+        for (int ax0=0; ax0<4; ++ax0) {
+            for (int ax1=0; ax1<4; ++ax1) {
+                if (ax1 == ax0) continue;
+                for (int ax2=0; ax2<4; ++ax2) {
+                    if (ax2 == ax0) continue;
+                    if (ax2 == ax1) continue;
+                    for (int ax3=0; ax3<4; ++ax3) {
+                        if (ax3 == ax0) continue;
+                        if (ax3 == ax1) continue;
+                        if (ax3 == ax2) continue;
+                        assert(count < NUM_PERMUTATIONS);
+                        all_permutations[count*4+0] = ax0;
+                        all_permutations[count*4+1] = ax1;
+                        all_permutations[count*4+2] = ax2;
+                        all_permutations[count*4+3] = ax3;
+                        ++count;
+                    }
+                }
+            }
+        }
+    }
+
+
+    // original loop: 1000
+    int niter = 4;
+    const char *env = getenv("GGML_NLOOP");
+    if (env != NULL) {
+        niter = atoi(env);
+    }
+    if (argc > 1) {
+        niter = atoi(argv[1]);
+    }
+    for (int iter = 0; iter < niter; ++iter) {
+        printf("test-grad0: iter:%d/%d\n", iter, niter);
+        struct ggml_context * ctx0 = ggml_init(params);
+
+        get_random_dims(ne, 4);
+
+        struct ggml_tensor * x[MAX_NARGS];
+
+        // add f32
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
+
+                check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f);
+            }
+        }
+
+        // add f16
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
+
+                check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f);
+            }
+        }
+
+        // sub
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sub(ctx0, x[0], x[1]));
+
+                check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // mul
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_mul(ctx0, x[0], x[1]));
+
+                check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // div
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 0.5f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1]));
+
+                check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f);
+            }
+        }
+
+        // sqr
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, x[0]));
+
+                check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // sqrt
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0]));
+
+                check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f);
+            }
+        }
+
+        // log
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_log(ctx0, x[0]));
+
+                check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f);
+            }
+        }
+
+        // sum
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, x[0]);
+
+                check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+
+        // sum_rows
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sum_rows(ctx0, x[0])));
+
+                check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
+            }
+        }
+
+        // mean, not yet fully implemented
+        if(0)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0]));
+
+                check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // argmax
+        if (0)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0]));
+
+                check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // repeat
+        {
+            int64_t ne2[4];
+            get_random_dims(ne2, 4);
+
+            ne2[0] = ne[0] * ne2[0];
+            ne2[1] = ne[1] * ne2[1];
+            ne2[2] = 1;
+            ne2[3] = 1;
+
+            const int nargs = 1;
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1]))));
+
+                check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
+            }
+        }
+
+        // repeat back
+        {
+            int64_t ne2[4];
+            get_random_dims(ne2, 4);
+
+            ne2[0] = ne[0] * ne2[0];
+            ne2[1] = ne[1] * ne2[1];
+            ne2[2] = 1;
+            ne2[3] = 1;
+
+            const int nargs = 1;
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0]))));
+
+                check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
+            }
+        }
+
+        // abs (finite differences do not work)
+        //{
+        //    const int nargs = 1;
+
+        //    for (int ndims = 1; ndims <= 2; ++ndims) {
+        //        for (int i = 0; i < nargs; ++i) {
+        //            x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+        //            ggml_set_param(ctx0, x[i]);
+        //        }
+
+        //        struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0]));
+
+        //        check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f);
+        //    }
+        //}
+
+        // sgn
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0]));
+
+                check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // neg
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0]));
+
+                check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // step
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0]));
+
+                check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // tanh, not yet fully implemented
+        if(0)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0]));
+
+                check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // mul_mat
+        {
+            const int nargs = 2;
+
+            for (int ndims = 2; ndims <= 2; ++ndims) {
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                {
+                    int64_t ne2[4];
+                    get_random_dims(ne2, 4);
+                    ne2[0] = ne[0];
+                    x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+                }
+
+                ggml_set_param(ctx0, x[0]);
+                ggml_set_param(ctx0, x[1]);
+
+                struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
+                struct ggml_tensor * f = ggml_sum(ctx0, m);
+
+                GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
+
+                check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_mat_mul(m, x[1], x[0]);
+            }
+        }
+
+        // elu, not yet fully implemented
+        if(0)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0]));
+
+                check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // relu
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0]));
+
+                check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // gelu, not yet fully implemented
+        if(0)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0]));
+
+                check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // silu
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_silu(ctx0, x[0]));
+
+#ifdef GGML_SILU_FP16
+                // due to GGML_SILU_FP16 the finite difference method will be slightly wrong -> increase error bounds.
+                check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY);
+#else
+                check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+#endif
+            }
+        }
+
+        // rms_norm
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
+
+                check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
+            }
+        }
+
+        // scale
+        {
+            const int nargs = 2;
+
+            int64_t ne2[4];
+            ne2[0] = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+
+                ggml_set_param(ctx0, x[0]);
+                ggml_set_param(ctx0, x[1]);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], x[1]));
+
+                check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // cpy f32
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+                // x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
+
+                check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // cpy f16
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+                // x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
+
+                check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
+            }
+        }
+
+        // reshape (1d->nd)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                int64_t ne2[4];
+                ne2[0] = 1;
+                ne2[1] = 1;
+                ne2[2] = 1;
+                ne2[3] = 1;
+                for (int i = 0; i < ndims; ++i) {
+                    ne2[0] *= ne[i];
+                }
+                x[0] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
+                check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // reshape (nd->1d)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                int64_t ne2[4];
+                ne2[0] = 1;
+                ne2[1] = 1;
+                ne2[2] = 1;
+                ne2[3] = 1;
+                for (int i = 0; i < ndims; ++i) {
+                    ne2[0] *= ne[i];
+                }
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
+                check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // acc 1d
+        {
+            int64_t ne2[4] = { 1, 1, 1, 1 };
+
+            const int nargs = 2;
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                get_random_dims(ne2, 1);
+                while ((ne2[0] > ne[0]) || (ne2[0] > ggml_nelements(x[0]))) {
+                    get_random_dims(ne2, 1);
+                }
+
+                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[1]);
+
+                const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
+                const int offset = irand(max_offset) * ggml_element_size(x[0]);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
+
+                check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // acc 2d
+        {
+            int64_t ne2[4]         = { 1, 1, 1, 1 };
+            int64_t max_offsets[4] = { 0, 0, 0, 0 };
+            int64_t offsets[4]     = { 0, 0, 0, 0 };
+
+            const int nargs = 2;
+            for (int ndims = 2; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                get_random_dims(ne2, 2);
+                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[0]*ne2[1] > ggml_nelements(x[0]))) {
+                    get_random_dims(ne2, 2);
+                }
+
+                x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[1]);
+
+                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
+                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
+                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
+                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
+                const int offset = offsets[0] + offsets[1];
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
+
+                check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // acc 3d
+        {
+            int64_t ne2[4]         = { 1, 1, 1, 1 };
+            int64_t max_offsets[4] = { 0, 0, 0, 0 };
+            int64_t offsets[4]     = { 0, 0, 0, 0 };
+
+            const int nargs = 2;
+            for (int ndims = 3; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                get_random_dims(ne2, 3);
+                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[2] > ne[2]) || (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0]))) {
+                    get_random_dims(ne2, 3);
+                }
+
+                x[1] = get_random_tensor_f32(ctx0, 3, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[1]);
+
+                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
+                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
+                max_offsets[2] = MAX(0, x[0]->ne[2] - x[1]->ne[2]);
+                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
+                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
+                offsets[2] = irand(max_offsets[2]) * x[0]->nb[2];
+                const int offset = offsets[0] + offsets[1] + offsets[2];
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
+
+                check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // acc 4d
+        {
+            int64_t ne2[4]         = { 1, 1, 1, 1 };
+            int64_t max_offsets[4] = { 0, 0, 0, 0 };
+            int64_t offsets[4]     = { 0, 0, 0, 0 };
+
+            const int nargs = 2;
+            for (int ndims = 4; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                get_random_dims(ne2, 4);
+                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[2] > ne[2]) || (ne2[3] > ne[3]) || (ne2[0]*ne2[1]*ne2[2]*ne2[3] > ggml_nelements(x[0]))) {
+                    get_random_dims(ne2, 4);
+                }
+
+                x[1] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[1]);
+
+                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
+                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
+                max_offsets[2] = MAX(0, x[0]->ne[2] - x[1]->ne[2]);
+                max_offsets[3] = MAX(0, x[0]->ne[3] - x[1]->ne[3]);
+                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
+                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
+                offsets[2] = irand(max_offsets[2]) * x[0]->nb[2];
+                offsets[3] = irand(max_offsets[3]) * x[0]->nb[3];
+                const int offset = offsets[0] + offsets[1] + offsets[2] + offsets[3];
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
+
+                check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // set_1d
+        {
+            int64_t ne2[4];
+
+            const int nargs = 2;
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                get_random_dims(ne2, 1);
+                while ((ne2[0] > ne[0]) || (ne2[0] > ggml_nelements(x[0]))) {
+                    get_random_dims(ne2, 1);
+                }
+
+                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[1]);
+
+                const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
+                const int offset = irand(max_offset) * ggml_element_size(x[0]);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_1d(ctx0, x[0], x[1], offset));
+
+                check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // set_2d
+        {
+            int64_t ne2[4];
+            int64_t max_offsets[4] = { 0, 0, 0, 0 };
+            int64_t offsets[4]     = { 0, 0, 0, 0 };
+
+            const int nargs = 1;
+            for (int ndims = 2; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                get_random_dims(ne2, 2);
+                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[0]*ne2[1] > ggml_nelements(x[0]))) {
+                    get_random_dims(ne2, 2);
+                }
+
+                x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[1]);
+
+                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
+                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
+                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
+                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
+                const int offset = offsets[0] + offsets[1];
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_2d(ctx0, x[0], x[1], x[1]->nb[1], offset));
+
+                check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // view_1d
+        {
+            const int nargs = 1;
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+
+                ggml_set_param(ctx0, x[0]);
+
+                const int k0 = irand(ggml_nelements(x[0]));
+                const int k1 = irand(ggml_nelements(x[0]));
+                const int i0 = MIN(k0, k1);
+                const int i1 = MAX(k0, k1);
+
+                const int offset = i0 * sizeof(float);
+                const int nelem  = i1 - i0;
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_1d(ctx0, x[0], nelem, offset));
+
+                check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // view_2d
+        {
+            int64_t ne2[4];
+            int64_t nb2[4];
+
+            const int nargs = 1;
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+
+                get_random_dims(ne2, 2);
+                while (ne2[0]*ne2[1] > ggml_nelements(x[0])) {
+                    get_random_dims(ne2, 2);
+                }
+                const int count = ne2[0]*ne2[1];
+
+                nb2[0] = sizeof(float);
+                nb2[1] = nb2[0]*ne2[0];
+
+                ggml_set_param(ctx0, x[0]);
+
+                const int max_offset = ggml_nelements(x[0]) - count;
+                const int offset = irand(max_offset+1) * sizeof(float);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_2d(ctx0, x[0], ne2[0], ne2[1], nb2[1], offset));
+
+                check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // view_3d
+        {
+            int64_t ne2[4] = {1,1,1,1};
+            int64_t nb2[4] = {0,0,0,0};
+
+            const int nargs = 1;
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+
+                get_random_dims(ne2, 3);
+                while (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0])) {
+                    get_random_dims(ne2, 3);
+                }
+                const int count = ne2[0]*ne2[1]*ne2[2];
+
+                nb2[0] = sizeof(float);
+                nb2[1] = nb2[0]*ne2[0];
+                nb2[2] = nb2[1]*ne2[1];
+
+                ggml_set_param(ctx0, x[0]);
+
+                const int max_offset = ggml_nelements(x[0]) - count;
+                const int offset = irand(max_offset+1) * sizeof(float);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_3d(ctx0, x[0], ne2[0], ne2[1], ne2[2], nb2[1], nb2[2], offset));
+
+                check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // permute
+        {
+            int64_t ne2[4];
+
+            const int nargs = 1;
+            for (int ndims = 1; ndims <= 4; ++ndims)
+            {
+                // ggml_permute will set axes of dimensions below n_dims to 1.
+                // to make ggml_permute work correctly on all axes,
+                // the input tensor needs maximal n_dim of 4.
+                for (int i=0; i<ndims; ++i) {
+                    ne2[i] = ne[i];
+                }
+                for (int i=ndims; i<4; ++i) {
+                    ne2[i] = 1;
+                }
+                x[0] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
+
+                ggml_set_param(ctx0, x[0]);
+
+                const int p = irand(NUM_PERMUTATIONS);
+                const int ax0 = all_permutations[p*4+0];
+                const int ax1 = all_permutations[p*4+1];
+                const int ax2 = all_permutations[p*4+2];
+                const int ax3 = all_permutations[p*4+3];
+
+                // sum requires contiguous tensor rows
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, x[0], ax0, ax1, ax2, ax3)));
+
+                check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // transpose
+        {
+            int64_t ne2[4];
+
+            const int nargs = 1;
+            for (int ndims = 1; ndims <= 4; ++ndims)
+            {
+                // ggml_transpose will set axes of dimensions below n_dims to 1.
+                // to make ggml_transpose work correctly on all axes,
+                // the input tensor needs maximal n_dim of 4.
+                for (int i=0; i<ndims; ++i) {
+                    ne2[i] = ne[i];
+                }
+                for (int i=ndims; i<4; ++i) {
+                    ne2[i] = 1;
+                }
+                x[0] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
+
+                ggml_set_param(ctx0, x[0]);
+
+                // sum requires contiguous tensor rows
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, x[0])));
+
+                check_gradient("transpose", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // get_rows
+        {
+            int64_t ne2[4] = {ne[0], ne[1], 1, 1};
+            int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1};
+            const int nargs = 1;
+            const int ndims = 2;
+            x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+            x[1] = get_random_tensor_i32(ctx0, 1, ne3, 0, ne2[1]);
+
+            ggml_set_param(ctx0, x[0]);
+
+            struct ggml_tensor * f = ggml_sum(ctx0, ggml_get_rows(ctx0, x[0], x[1]));
+
+            check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+        }
+
+        // diag_mask_inf
+        {
+            const int nargs = 1;
+            const int ndims = 2;
+
+            x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+            ggml_set_param(ctx0, x[0]);
+
+            int n_past = irand(ne[0]);
+
+            struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_inf(ctx0, x[0], n_past));
+
+            check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+        }
+
+        // diag_mask_zero
+        {
+            const int nargs = 1;
+            const int ndims = 2;
+
+            x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+            ggml_set_param(ctx0, x[0]);
+
+            int n_past = irand(ne[0]);
+
+            struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_zero(ctx0, x[0], n_past));
+
+            check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+        }
+
+        // softmax
+        {
+            const int nargs = 1;
+
+            int64_t ne2[4];
+            get_random_dims(ne2, 4);
+
+            for (int ndims = 1; ndims <= 3; ++ndims) {
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_soft_max(ctx0, x[0]));
+
+                check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // cross_entropy_loss
+        {
+            const int nargs = 1;
+
+            int64_t ne2[4];
+            get_random_dims(ne2, 4);
+
+            for (int ndims = 1; ndims <= 3; ++ndims) {
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cross_entropy_loss(ctx0, x[0], x[1]));
+
+                check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-1f, 1e-2f, INFINITY);
+                // finite differences regularly fails!
+            }
+        }
+
+        // rope f32
+        {
+            const int nargs = 1;
+
+            int64_t ne2[4];
+            get_random_dims(ne2, 4);
+            ne2[0] += ne2[0] % 2;
+            int n_rot = ne2[0];
+
+            for (int ndims = 3; ndims <= 4; ++ndims) {
+                for (int mode = 0; mode < 4; ++mode) {
+                    for (int n_past = 1; n_past < ne2[2]; ++n_past) {
+                        x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+
+                        ggml_set_param(ctx0, x[0]);
+
+                        const bool skip_past = (mode & 1);
+                        if (skip_past) {
+                            // we have no past, so this would have to work on uninitialized memory.
+                            // we only test the gradients here;
+                            // skip_past should have no influence on gradient computation.
+                            // so when other modes work, we assume that this does as well.
+                            continue;
+                        }
+
+                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
+
+                        GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
+                        check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
+                    }
+                }
+            }
+        }
+
+        // rope f16
+        {
+            const int nargs = 1;
+
+            int64_t ne2[4];
+            get_random_dims(ne2, 4);
+            ne2[0] += ne2[0] % 2;
+            int n_rot = ne2[0];
+
+            for (int ndims = 3; ndims <= 4; ++ndims) {
+                for (int mode = 0; mode < 4; ++mode) {
+                    for (int n_past = 1; n_past < ne2[2]; ++n_past) {
+                        x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f);
+
+                        ggml_set_param(ctx0, x[0]);
+
+                        const bool skip_past = (mode & 1);
+                        if (skip_past) {
+                            // we have no past, so this would have to work on uninitialized memory.
+                            // we only test the gradients here;
+                            // skip_past should have no influence on gradient computation.
+                            // so when other modes work, we assume that this does as well.
+                            continue;
+                        }
+
+                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
+
+                        GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
+                        check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
+                    }
+                }
+            }
+        }
+
+        // flash_attn f32
+        {
+            const int nargs = 3;
+
+            int64_t ne2[4];
+
+            get_random_dims(ne2, 4);
+            int64_t D = ne2[0];
+            int64_t N = ne2[1];
+            int64_t M = ne2[2] + N;
+            int64_t B = ne2[3];
+
+            for (int masked = 0; masked <= 1; ++masked) {
+                for (int ndims = 2; ndims <= 4; ++ndims) {
+                    int64_t neq[4] = { D, N, B, ne[3] };
+                    int64_t nek[4] = { D, M, B, ne[3] };
+                    int64_t nev[4] = { M, D, B, ne[3] };
+                    if (ndims == 2) {
+                        neq[2] = 1; neq[3] = 1;
+                        nek[2] = 1; nek[3] = 1;
+                        nev[2] = 1; nev[3] = 1;
+                    } else if (ndims == 3) {
+                        neq[3] = 1;
+                        nek[3] = 1;
+                        nev[3] = 1;
+                    }
+                    x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
+                    x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
+                    x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
+                    ggml_set_param(ctx0, x[0]);
+                    ggml_set_param(ctx0, x[1]);
+                    ggml_set_param(ctx0, x[2]);
+
+                    struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+
+                    check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
+                }
+            }
+        }
+
+        // flash_attn f16, not yet fully implemented
+        if(0)
+        {
+            const int nargs = 3;
+
+            int64_t ne2[4];
+
+            get_random_dims(ne2, 4);
+            int64_t D = ne2[0];
+            int64_t N = ne2[1];
+            int64_t M = ne2[2] + N;
+            int64_t B = ne2[3];
+
+            for (int masked = 0; masked <= 1; ++masked) {
+                for (int ndims = 2; ndims <= 4; ++ndims) {
+                    int64_t neq[4] = { D, N, B, ne[3] };
+                    int64_t nek[4] = { D, M, B, ne[3] };
+                    int64_t nev[4] = { M, D, B, ne[3] };
+                    if (ndims == 2) {
+                        neq[2] = 1; neq[3] = 1;
+                        nek[2] = 1; nek[3] = 1;
+                        nev[2] = 1; nev[3] = 1;
+                    } else if (ndims == 3) {
+                        neq[3] = 1;
+                        nek[3] = 1;
+                        nev[3] = 1;
+                    }
+                    x[0] = get_random_tensor_f16(ctx0, ndims, neq, -0.1250f, 0.1250f);
+                    x[1] = get_random_tensor_f16(ctx0, ndims, nek, -0.1250f, 0.1250f);
+                    x[2] = get_random_tensor_f16(ctx0, ndims, nev, -0.1250f, 0.1250f);
+                    ggml_set_param(ctx0, x[0]);
+                    ggml_set_param(ctx0, x[1]);
+                    ggml_set_param(ctx0, x[2]);
+
+                    struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+
+                    check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
+                }
+            }
+        }
+        ggml_free(ctx0);
+    }
+
+    return 0;
+}
diff --git a/tests/test-opt.c b/tests/test-opt.c
deleted file mode 100644 (file)
index 4eef62b..0000000
+++ /dev/null
@@ -1,211 +0,0 @@
-#include "ggml.h"
-
-#include <math.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <assert.h>
-
-#define MAX_NARGS 2
-
-#if defined(__GNUC__)
-#pragma GCC diagnostic ignored "-Wdouble-promotion"
-#endif
-
-//
-// logging
-//
-#define GGML_DEBUG 0
-#if (GGML_DEBUG >= 1)
-#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG(...)
-#endif
-
-#if (GGML_DEBUG >= 5)
-#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_5(...)
-#endif
-
-#if (GGML_DEBUG >= 10)
-#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_10(...)
-#endif
-
-#define GGML_PRINT(...) printf(__VA_ARGS__)
-
-
-float frand(void) {
-    return (float)rand()/(float)RAND_MAX;
-}
-
-int irand(int n) {
-    return rand()%n;
-}
-
-void get_random_dims(int64_t * dims, int ndims) {
-    dims[0] = dims[1] = dims[2] = dims[3] = 1;
-
-    for (int i = 0; i < ndims; i++) {
-        dims[i] = 1 + irand(4);
-    }
-}
-
-void get_random_dims_minmax(int64_t * dims, int ndims, int min, int max) {
-    dims[0] = dims[1] = dims[2] = dims[3] = 1;
-
-    for (int i = 0; i < ndims; i++) {
-        dims[i] = min + irand(max-min);
-    }
-}
-
-
-struct ggml_tensor * get_random_tensor(
-        struct ggml_context * ctx0,
-        int ndims,
-        int64_t ne[],
-        float fmin,
-        float fmax) {
-    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
-
-    switch (ndims) {
-        case 1:
-            for (int i0 = 0; i0 < ne[0]; i0++) {
-                ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
-            }
-            break;
-        case 2:
-            for (int i1 = 0; i1 < ne[1]; i1++) {
-                for (int i0 = 0; i0 < ne[0]; i0++) {
-                    ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                }
-            }
-            break;
-        case 3:
-            for (int i2 = 0; i2 < ne[2]; i2++) {
-                for (int i1 = 0; i1 < ne[1]; i1++) {
-                    for (int i0 = 0; i0 < ne[0]; i0++) {
-                        ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                    }
-                }
-            }
-            break;
-        case 4:
-            for (int i3 = 0; i3 < ne[3]; i3++) {
-                for (int i2 = 0; i2 < ne[2]; i2++) {
-                    for (int i1 = 0; i1 < ne[1]; i1++) {
-                        for (int i0 = 0; i0 < ne[0]; i0++) {
-                            ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                        }
-                    }
-                }
-            }
-            break;
-        default:
-            assert(false);
-    };
-
-    return result;
-}
-
-float get_element(const struct ggml_tensor * t, int idx) {
-    return ((float *)t->data)[idx];
-}
-
-void set_element(struct ggml_tensor * t, int idx, float value) {
-    ((float *)t->data)[idx] = value;
-}
-
-int main(void) {
-    struct ggml_init_params params = {
-        .mem_size   = 1024*1024*1024,
-        .mem_buffer = NULL,
-        .no_alloc   = false,
-    };
-    struct ggml_context * ctx = ggml_init(params);
-
-    int64_t ne1[4] = {4, 128, 1, 1};
-    int64_t ne2[4] = {4, 256, 1, 1};;
-    int64_t ne3[4] = {128, 256, 1, 1};
-
-    struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1);
-    struct ggml_tensor * b = get_random_tensor(ctx, 2, ne2, -1, +1);
-    ggml_set_param(ctx, a);
-    ggml_set_param(ctx, b);
-
-    struct ggml_tensor * c = get_random_tensor(ctx, 2, ne3, -1, +1);
-
-    struct ggml_tensor * ab = ggml_mul_mat(ctx, a, b);
-    struct ggml_tensor * d  = ggml_sub(ctx, c, ab);
-    struct ggml_tensor * e  = ggml_sum(ctx, ggml_sqr(ctx, d));
-
-    struct ggml_cgraph ge = ggml_build_forward(e);
-    ggml_graph_reset(&ge);
-
-    ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
-
-    const float fe = ggml_get_f32_1d(e, 0);
-    printf("%s: e = %.4f\n", __func__, fe);
-
-    struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM);
-
-    ggml_opt(ctx, opt_params, e);
-
-    ggml_graph_reset(&ge);
-
-    ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
-
-    const float fe_opt = ggml_get_f32_1d(e, 0);
-    printf("%s: original  e = %.4f\n", __func__, fe);
-    printf("%s: optimized e = %.4f\n", __func__, fe_opt);
-
-    const bool success = (fe_opt <= fe);
-    assert(success);
-
-    ggml_free(ctx);
-    return success ? 0 : -1;
-}
-// int64_t ne1[4] = {4, 128, 1, 1};
-// int64_t ne2[4] = {4, 256, 1, 1};;
-// int64_t ne3[4] = {128, 256, 1, 1};
-// main: original  e = 25890.9375
-// main: optimized e = 10094.7031
-
-// int64_t ne1[4] = {8, 128, 1, 1};
-// int64_t ne2[4] = {8, 256, 1, 1};;
-// int64_t ne3[4] = {128, 256, 1, 1};
-// main: original  e = 39429.5078
-// main: optimized e = 9275.8936
-
-// int64_t ne1[4] = {16, 128, 1, 1};
-// int64_t ne2[4] = {16, 256, 1, 1};;
-// int64_t ne3[4] = {128, 256, 1, 1};
-// main: original  e = 68371.1328
-// main: optimized e = 7854.4502
-
-
-// int64_t ne1[4] = {32, 128, 1, 1};
-// int64_t ne2[4] = {32, 256, 1, 1};;
-// int64_t ne3[4] = {128, 256, 1, 1};
-// main: original  e = 126061.1953
-// main: optimized e = 5451.0166
-
-// int64_t ne1[4] = {4, 1024, 1, 1};
-// int64_t ne2[4] = {4, 2048, 1, 1};;
-// int64_t ne3[4] = {1024, 2048, 1, 1};
-// main: original  e = 1620817.8750
-// main: optimized e = 698387.6875
-
-// another run on M1
-// int64_t ne1[4] = {4, 1024, 1, 1};
-// int64_t ne2[4] = {4, 2048, 1, 1};;
-// int64_t ne3[4] = {1024, 2048, 1, 1};
-// main: original  e = 1629595.6250
-// main: optimized e = 698169.1250
-
-// int64_t ne1[4] = {32, 1024, 1, 1};
-// int64_t ne2[4] = {32, 2048, 1, 1};;
-// int64_t ne3[4] = {1024, 2048, 1, 1};
-// main: original  e = 8146770.5000
-// main: optimized e = 651119.1250
diff --git a/tests/test-opt.cpp b/tests/test-opt.cpp
new file mode 100644 (file)
index 0000000..8ab2402
--- /dev/null
@@ -0,0 +1,212 @@
+#include "ggml.h"
+
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cassert>
+
+#define MAX_NARGS 2
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Wdouble-promotion"
+#endif
+
+//
+// logging
+//
+#define GGML_DEBUG 0
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#if (GGML_DEBUG >= 5)
+#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_5(...)
+#endif
+
+#if (GGML_DEBUG >= 10)
+#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_10(...)
+#endif
+
+#define GGML_PRINT(...) printf(__VA_ARGS__)
+
+
+float frand(void) {
+    return (float)rand()/(float)RAND_MAX;
+}
+
+int irand(int n) {
+    return rand()%n;
+}
+
+void get_random_dims(int64_t * dims, int ndims) {
+    dims[0] = dims[1] = dims[2] = dims[3] = 1;
+
+    for (int i = 0; i < ndims; i++) {
+        dims[i] = 1 + irand(4);
+    }
+}
+
+void get_random_dims_minmax(int64_t * dims, int ndims, int min, int max) {
+    dims[0] = dims[1] = dims[2] = dims[3] = 1;
+
+    for (int i = 0; i < ndims; i++) {
+        dims[i] = min + irand(max-min);
+    }
+}
+
+
+struct ggml_tensor * get_random_tensor(
+        struct ggml_context * ctx0,
+        int ndims,
+        int64_t ne[],
+        float fmin,
+        float fmax) {
+    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
+
+    switch (ndims) {
+        case 1:
+            for (int i0 = 0; i0 < ne[0]; i0++) {
+                ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
+            }
+            break;
+        case 2:
+            for (int i1 = 0; i1 < ne[1]; i1++) {
+                for (int i0 = 0; i0 < ne[0]; i0++) {
+                    ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                }
+            }
+            break;
+        case 3:
+            for (int i2 = 0; i2 < ne[2]; i2++) {
+                for (int i1 = 0; i1 < ne[1]; i1++) {
+                    for (int i0 = 0; i0 < ne[0]; i0++) {
+                        ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                    }
+                }
+            }
+            break;
+        case 4:
+            for (int i3 = 0; i3 < ne[3]; i3++) {
+                for (int i2 = 0; i2 < ne[2]; i2++) {
+                    for (int i1 = 0; i1 < ne[1]; i1++) {
+                        for (int i0 = 0; i0 < ne[0]; i0++) {
+                            ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            assert(false);
+    };
+
+    return result;
+}
+
+float get_element(const struct ggml_tensor * t, int idx) {
+    return ((float *)t->data)[idx];
+}
+
+void set_element(struct ggml_tensor * t, int idx, float value) {
+    ((float *)t->data)[idx] = value;
+}
+
+int main(void) {
+    struct ggml_init_params params = {
+        /* .mem_size   = */ 1024*1024*1024,
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ false,
+    };
+
+    struct ggml_context * ctx = ggml_init(params);
+
+    int64_t ne1[4] = {4, 128, 1, 1};
+    int64_t ne2[4] = {4, 256, 1, 1};;
+    int64_t ne3[4] = {128, 256, 1, 1};
+
+    struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1);
+    struct ggml_tensor * b = get_random_tensor(ctx, 2, ne2, -1, +1);
+    ggml_set_param(ctx, a);
+    ggml_set_param(ctx, b);
+
+    struct ggml_tensor * c = get_random_tensor(ctx, 2, ne3, -1, +1);
+
+    struct ggml_tensor * ab = ggml_mul_mat(ctx, a, b);
+    struct ggml_tensor * d  = ggml_sub(ctx, c, ab);
+    struct ggml_tensor * e  = ggml_sum(ctx, ggml_sqr(ctx, d));
+
+    struct ggml_cgraph ge = ggml_build_forward(e);
+    ggml_graph_reset(&ge);
+
+    ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
+
+    const float fe = ggml_get_f32_1d(e, 0);
+    printf("%s: e = %.4f\n", __func__, fe);
+
+    struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM);
+
+    ggml_opt(ctx, opt_params, e);
+
+    ggml_graph_reset(&ge);
+
+    ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
+
+    const float fe_opt = ggml_get_f32_1d(e, 0);
+    printf("%s: original  e = %.4f\n", __func__, fe);
+    printf("%s: optimized e = %.4f\n", __func__, fe_opt);
+
+    const bool success = (fe_opt <= fe);
+    assert(success);
+
+    ggml_free(ctx);
+    return success ? 0 : -1;
+}
+// int64_t ne1[4] = {4, 128, 1, 1};
+// int64_t ne2[4] = {4, 256, 1, 1};;
+// int64_t ne3[4] = {128, 256, 1, 1};
+// main: original  e = 25890.9375
+// main: optimized e = 10094.7031
+
+// int64_t ne1[4] = {8, 128, 1, 1};
+// int64_t ne2[4] = {8, 256, 1, 1};;
+// int64_t ne3[4] = {128, 256, 1, 1};
+// main: original  e = 39429.5078
+// main: optimized e = 9275.8936
+
+// int64_t ne1[4] = {16, 128, 1, 1};
+// int64_t ne2[4] = {16, 256, 1, 1};;
+// int64_t ne3[4] = {128, 256, 1, 1};
+// main: original  e = 68371.1328
+// main: optimized e = 7854.4502
+
+
+// int64_t ne1[4] = {32, 128, 1, 1};
+// int64_t ne2[4] = {32, 256, 1, 1};;
+// int64_t ne3[4] = {128, 256, 1, 1};
+// main: original  e = 126061.1953
+// main: optimized e = 5451.0166
+
+// int64_t ne1[4] = {4, 1024, 1, 1};
+// int64_t ne2[4] = {4, 2048, 1, 1};;
+// int64_t ne3[4] = {1024, 2048, 1, 1};
+// main: original  e = 1620817.8750
+// main: optimized e = 698387.6875
+
+// another run on M1
+// int64_t ne1[4] = {4, 1024, 1, 1};
+// int64_t ne2[4] = {4, 2048, 1, 1};;
+// int64_t ne3[4] = {1024, 2048, 1, 1};
+// main: original  e = 1629595.6250
+// main: optimized e = 698169.1250
+
+// int64_t ne1[4] = {32, 1024, 1, 1};
+// int64_t ne2[4] = {32, 2048, 1, 1};;
+// int64_t ne3[4] = {1024, 2048, 1, 1};
+// main: original  e = 8146770.5000
+// main: optimized e = 651119.1250