]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : llama.cpp (training, refactoring) (#548)
authorGeorgi Gerganov <redacted>
Wed, 4 Oct 2023 12:53:05 +0000 (15:53 +0300)
committerGitHub <redacted>
Wed, 4 Oct 2023 12:53:05 +0000 (15:53 +0300)
* sync : llama.cpp (training, refactoring)

* examples : fix ggml_rope

* ggml : better optimizer cancel handling

ggml-ci

* ggml : fix UBs

ggml-ci

* ggml : add TODO for refactoring the opt cancellation

23 files changed:
CMakeLists.txt
examples/dolly-v2/main.cpp
examples/gpt-j/main.cpp
examples/gpt-neox/main.cpp
include/ggml/ggml-alloc.h
include/ggml/ggml.h
src/ggml-alloc.c
src/ggml-cuda.cu
src/ggml-cuda.h
src/ggml-metal.h
src/ggml-metal.m
src/ggml-metal.metal
src/ggml-opencl.cpp
src/ggml.c
tests/test-blas0.c
tests/test-grad0.cpp
tests/test-mul-mat1.c
tests/test-opt.cpp
tests/test-quantize-fns.cpp
tests/test-quantize-perf.cpp
tests/test-svd0.c
tests/test-vec2.c
tests/test-xpos.c

index bf1491bf31ea931a2ea0dc37ea2078c50c541bf8..db1179d035bdb36c33668f3906cc94be262872f0 100644 (file)
@@ -13,8 +13,20 @@ else()
     set(GGML_STANDALONE OFF)
 endif()
 
+if (EMSCRIPTEN)
+    set(BUILD_SHARED_LIBS_DEFAULT OFF)
+else()
+    if (MINGW)
+        set(BUILD_SHARED_LIBS_DEFAULT OFF)
+    else()
+        set(BUILD_SHARED_LIBS_DEFAULT ON)
+    endif()
+endif()
+
 # options
 
+option(BUILD_SHARED_LIBS            "ggml: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
+
 option(GGML_ALL_WARNINGS            "ggml: enable all compiler warnings"                   ON)
 option(GGML_ALL_WARNINGS_3RD_PARTY  "ggml: enable all compiler warnings in 3rd party libs" OFF)
 
index 70200eb3a1f2c2ef5f49f0eda72f9933c695a971..1a10880bb551336a49389714e47865212215deb5 100644 (file)
@@ -499,6 +499,13 @@ bool dollyv2_eval(
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = { };
 
+    // KQ_pos - contains the positions
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    int * data = (int *) KQ_pos->data;
+    for (int i = 0; i < N; ++i) {
+        data[i] = n_past + i;
+    }
+
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
 
@@ -536,8 +543,8 @@ bool dollyv2_eval(
             struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 2*sizeof(float)*n_embd/n_head));
 
             // using mode = 2 for GPT-NeoX mode
-            Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0);
-            Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0);
+            Qcur = ggml_rope_inplace(ctx0, Qcur, KQ_pos, n_rot, 2, 0);
+            Kcur = ggml_rope_inplace(ctx0, Kcur, KQ_pos, n_rot, 2, 0);
 
             // store key and value to memory
             {
index e5ab5badf4dd8a8968220d0ea150809a51df2f61..c2c4c1fef9067c66adec5b95d038f6c1c3c31729 100644 (file)
@@ -427,6 +427,13 @@ bool gptj_eval(
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
 
+    // KQ_pos - contains the positions
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    int * data = (int *) KQ_pos->data;
+    for (int i = 0; i < N; ++i) {
+        data[i] = n_past + i;
+    }
+
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
 
@@ -452,8 +459,8 @@ bool gptj_eval(
 
         // self-attention
         {
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
-            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
+            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
 
             // store key and value to memory
             {
index 457a58387166f5fb6f592e36a23dce5047e3afdc..43e16ce3e9ed8087c61615ea6ba23e7924a063a3 100644 (file)
@@ -479,6 +479,13 @@ bool gpt_neox_eval(
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
 
+    // KQ_pos - contains the positions
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    int * data = (int *) KQ_pos->data;
+    for (int i = 0; i < N; ++i) {
+        data[i] = n_past + i;
+    }
+
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
 
@@ -518,8 +525,8 @@ bool gpt_neox_eval(
             struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 2*sizeof(float)*n_embd/n_head));
 
             // using mode = 2 for GPT-NeoX mode
-            Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0);
-            Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0);
+            Qcur = ggml_rope_inplace(ctx0, Qcur, KQ_pos, n_rot, 2, 0);
+            Kcur = ggml_rope_inplace(ctx0, Kcur, KQ_pos, n_rot, 2, 0);
 
             // store key and value to memory
             {
index 9559da75871a608fb29f08edebf860674295e043..0c224f174f396eb187ddc5a48f955f7c0ad4418e 100644 (file)
@@ -19,6 +19,7 @@ GGML_API bool   ggml_allocr_is_measure(struct ggml_allocr * alloc);
 GGML_API void   ggml_allocr_reset(struct ggml_allocr * alloc);
 GGML_API void   ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor);
 GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);
+GGML_API size_t ggml_allocr_max_size(struct ggml_allocr * alloc);
 
 
 #ifdef  __cplusplus
index 751a60b6b24a872e7dcddea1081f5b7f8f8e3b75..a9d4e33d9885c7c427764ecfc4801ec1bb6a7108 100644 (file)
 #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
 
 #define GGML_MAX_DIMS          4
-#define GGML_MAX_NODES         4096
-#define GGML_MAX_PARAMS        256
+#define GGML_MAX_NODES         16384
+#define GGML_MAX_PARAMS        1024
 #define GGML_MAX_CONTEXTS      64
 #define GGML_MAX_SRC           6
 #define GGML_MAX_NAME          64
         } \
     } while (0)
 
+#ifndef NDEBUG
+#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached")
+#elif defined(__GNUC__)
+#define GGML_UNREACHABLE() __builtin_unreachable()
+#else
+#define GGML_UNREACHABLE() ((void) 0)
+#endif
+
 // used to copy the number of elements and stride in bytes of tensors into local variables.
 // main purpose is to reduce code duplication and improve readability.
 //
@@ -449,6 +457,12 @@ extern "C" {
         GGML_OBJECT_WORK_BUFFER
     };
 
+    enum ggml_log_level {
+        GGML_LOG_LEVEL_ERROR = 2,
+        GGML_LOG_LEVEL_WARN = 3,
+        GGML_LOG_LEVEL_INFO = 4
+    };
+
     // ggml object
     struct ggml_object {
         size_t offs;
@@ -471,8 +485,8 @@ extern "C" {
         int     n_dims;
         int64_t ne[GGML_MAX_DIMS]; // number of elements
         size_t  nb[GGML_MAX_DIMS]; // stride in bytes:
-                                   // nb[0] = sizeof(type)
-                                   // nb[1] = nb[0]   * ne[0] + padding
+                                   // nb[0] = ggml_type_size(type)
+                                   // nb[1] = nb[0]   * (ne[0] / ggml_blck_size(type)) + padding
                                    // nb[i] = nb[i-1] * ne[i-1]
 
         // compute data
@@ -524,7 +538,15 @@ extern "C" {
     // next prime after GGML_MAX_NODES
     // #define GGML_GRAPH_HASHTABLE_SIZE 4099
     // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
-    #define GGML_GRAPH_HASHTABLE_SIZE 8273
+    // #define GGML_GRAPH_HASHTABLE_SIZE 8273
+    // #define GGML_GRAPH_HASHTABLE_SIZE 16411
+    #define GGML_GRAPH_HASHTABLE_SIZE 32771
+
+    enum ggml_cgraph_eval_order {
+        GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
+        GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
+        GGML_CGRAPH_EVAL_ORDER_COUNT
+    };
 
     // computation graph
     struct ggml_cgraph {
@@ -537,6 +559,8 @@ extern "C" {
 
         void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
 
+        enum ggml_cgraph_eval_order order;
+
         // performance
         int     perf_runs;
         int64_t perf_cycles;
@@ -684,12 +708,21 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
     GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
 
+    // Converts a flat index into coordinates
+    GGML_API void    ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
+
     GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
     GGML_API void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
 
+    GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+    GGML_API void    ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
+
     GGML_API float   ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
     GGML_API void    ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
 
+    GGML_API float   ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+    GGML_API void    ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
+
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
 
@@ -723,6 +756,12 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_add_cast(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            enum   ggml_type      type);
+
     GGML_API struct ggml_tensor * ggml_add1(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -832,6 +871,7 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // sums repetitions in a into shape of b
     GGML_API struct ggml_tensor * ggml_repeat_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1053,7 +1093,6 @@ extern "C" {
             size_t                nb1,
             size_t                offset);
 
-
     // a -> b, return view(b)
     GGML_API struct ggml_tensor * ggml_cpy(
             struct ggml_context * ctx,
@@ -1076,6 +1115,33 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // make contiguous, with new shape
+    GGML_API struct ggml_tensor * ggml_cont_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0);
+
+    GGML_API struct ggml_tensor * ggml_cont_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1);
+
+    GGML_API struct ggml_tensor * ggml_cont_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2);
+
+    GGML_API struct ggml_tensor * ggml_cont_4d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3);
+
     // return view(a), b specifies the new shape
     // TODO: when we start computing gradient, make a copy instead of view
     GGML_API struct ggml_tensor * ggml_reshape(
@@ -1223,14 +1289,15 @@ extern "C" {
             struct ggml_tensor  * b);
 
     // rotary position embedding
-    // if mode & 1 == 1, skip n_past elements
+    // if mode & 1 == 1, skip n_past elements (DEPRECATED)
     // if mode & 2 == 1, GPT-NeoX style
     // if mode & 4 == 1, ChatGLM style
-    // TODO: avoid creating a new tensor every time
+    //
+    // b is an int32 vector with size a->ne[2], it contains the positions
     GGML_API struct ggml_tensor * ggml_rope(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_past,
+            struct ggml_tensor  * b,
             int                   n_dims,
             int                   mode,
             int                   n_ctx);
@@ -1239,7 +1306,7 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_rope_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_past,
+            struct ggml_tensor  * b,
             int                   n_dims,
             int                   mode,
             int                   n_ctx);
@@ -1248,7 +1315,7 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_rope_custom(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_past,
+            struct ggml_tensor  * b,
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
@@ -1259,7 +1326,7 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_past,
+            struct ggml_tensor  * b,
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
@@ -1270,7 +1337,7 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_past,
+            struct ggml_tensor  * b,
             int                   n_dims,
             float                 base,
             bool                  down);
@@ -1280,7 +1347,7 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_rope_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_past,
+            struct ggml_tensor  * b,
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
@@ -1668,6 +1735,16 @@ extern "C" {
     // dump the graph into a file using the dot format
     GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
 
+    // build gradient checkpointing backward graph gb for gf using provided checkpoints
+    // gb_tmp will contain original backward graph with rewritten backward process nodes,
+    // but without the second forward pass nodes.
+    GGML_API void ggml_build_backward_gradient_checkpointing(
+            struct ggml_context   * ctx,
+            struct ggml_cgraph    * gf,
+            struct ggml_cgraph    * gb,
+            struct ggml_cgraph    * gb_tmp,
+            struct ggml_tensor  * * checkpoints,
+            int                     n_checkpoints);
     //
     // optimization
     //
@@ -1694,6 +1771,7 @@ extern "C" {
         GGML_OPT_NO_CONTEXT,
         GGML_OPT_INVALID_WOLFE,
         GGML_OPT_FAIL,
+        GGML_OPT_CANCEL,
 
         GGML_LINESEARCH_FAIL = -128,
         GGML_LINESEARCH_MINIMUM_STEP,
@@ -1702,7 +1780,8 @@ extern "C" {
         GGML_LINESEARCH_INVALID_PARAMETERS,
     };
 
-    typedef void (*ggml_opt_callback)(void * data, float * sched);
+    typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
+    typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
 
     // optimization parameters
     //
@@ -1733,6 +1812,8 @@ extern "C" {
         bool print_forward_graph;
         bool print_backward_graph;
 
+        int n_gradient_accumulation;
+
         // ADAM parameters
         struct {
             int n_iter;
@@ -1778,6 +1859,7 @@ extern "C" {
         float loss_after;
 
         struct {
+            struct ggml_tensor * g;  // current gradient
             struct ggml_tensor * m;  // first moment
             struct ggml_tensor * v;  // second moment
             struct ggml_tensor * pf; // past function values
@@ -1894,26 +1976,26 @@ extern "C" {
 
     GGML_API int          gguf_get_n_kv(const struct gguf_context * ctx);
     GGML_API int          gguf_find_key(const struct gguf_context * ctx, const char * key);
-    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int i);
-
-    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int i);
-    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i);
-
-    // results are undefined if the wrong type is used for the key
-    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int i);
-    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int i);
-    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int i);
-    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int i);
-    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int i);
-    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int i);
-    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int i);
-    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int i);
-    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int i);
-    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int i);
-    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int i);
-    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int i);
-    GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int i);
-    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int i);
+    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id);
+
+    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id);
+    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id);
+
+    // will abort if the wrong type is used for the key
+    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int key_id);
+    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int key_id);
+    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int key_id);
+    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int key_id);
+    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int key_id);
+    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int key_id);
+    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int key_id);
+    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int key_id);
+    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int key_id);
+    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
+    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
+    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
+    GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int key_id);
+    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
     GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
 
     GGML_API int    gguf_get_n_tensors    (const struct gguf_context * ctx);
index 304964be4f3f07e1c10db0f46c4e013e813c989b..805759db74fef6f8175e1ab47dfe711b640da100 100644 (file)
@@ -77,7 +77,7 @@ struct free_block {
     size_t size;
 };
 
-#define MAX_FREE_BLOCKS 128
+#define MAX_FREE_BLOCKS 256
 
 struct ggml_allocr {
     void * data;
@@ -187,6 +187,7 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
     }
 
     tensor->data = addr;
+    AT_PRINTF("%s: allocated data at %p\n", __func__, tensor->data);
 
 #ifdef GGML_ALLOCATOR_DEBUG
     add_allocated_tensor(alloc, tensor);
@@ -218,7 +219,8 @@ static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tens
 
     size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
     size = aligned_offset(NULL, size, alloc->alignment);
-    AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
+    AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
+    AT_PRINTF("%s: alloc->data = %p alloc->data+alloc->size = %p alloc->data+alloc->max_size = %p\n", __func__, alloc->data, (char*)alloc->data + alloc->size, (char*)alloc->data + alloc->max_size);
 
 #ifdef GGML_ALLOCATOR_DEBUG
     remove_allocated_tensor(alloc, tensor);
@@ -631,3 +633,7 @@ static size_t ggml_allocr_alloc_graph_tensors_n(
 size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
     return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
 }
+
+size_t ggml_allocr_max_size(struct ggml_allocr * alloc) {
+    return alloc->max_size;
+}
index e0163ae0c6c33c59772a55382ca6444dbcd3111e..989c419cd0ea44a6db666f3b63958f9ae6bb642f 100644 (file)
@@ -1,3 +1,4 @@
+#include <algorithm>
 #include <cstddef>
 #include <cstdint>
 #include <limits>
 #ifdef __HIP_PLATFORM_AMD__
 // for rocblas_initialize()
 #include "rocblas/rocblas.h"
-#endif
+#endif // __HIP_PLATFORM_AMD__
+#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
 #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_OP_N HIPBLAS_OP_N
 #define CUBLAS_OP_T HIPBLAS_OP_T
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
@@ -31,6 +34,9 @@
 #define cublasSetStream hipblasSetStream
 #define cublasSgemm hipblasSgemm
 #define cublasStatus_t hipblasStatus_t
+#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
 #define cudaDeviceProp hipDeviceProp_t
 #define cudaDeviceSynchronize hipDeviceSynchronize
 #define cudaError_t hipError_t
 #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
 #define cudaStreamNonBlocking hipStreamNonBlocking
 #define cudaStreamSynchronize hipStreamSynchronize
-#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
+#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
 #else
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
-#endif
+#endif // defined(GGML_USE_HIPBLAS)
 
 #include "ggml-cuda.h"
 #include "ggml.h"
 
-#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#ifndef CC_TURING
-#define CC_TURING   700
-#endif
+#define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define CC_VOLTA      700
+#define CC_OFFSET_AMD 1000000
+#define CC_RDNA2      (CC_OFFSET_AMD + 1030)
 
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
 
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
+    defined(__gfx1150__) || defined(__gfx1151__)
+#define RDNA3
+#endif
+
+#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
+    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
+#define RDNA2
+#endif
+
 #ifndef __has_builtin
     #define __has_builtin(x) 0
 #endif
@@ -132,7 +148,7 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #endif
     return c;
 }
-#endif
+#endif // defined(GGML_USE_HIPBLAS)
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -144,8 +160,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do {                                                                                \
         cudaError_t err_ = (err);                                                       \
         if (err_ != cudaSuccess) {                                                      \
-            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
+            int id;                                                                     \
+            cudaGetDevice(&id);                                                         \
+            fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                 cudaGetErrorString(err_));                                              \
+            fprintf(stderr, "current device: %d\n", id);                                \
             exit(1);                                                                    \
         }                                                                               \
     } while (0)
@@ -155,8 +174,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do {                                                                                \
         cublasStatus_t err_ = (err);                                                    \
         if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
+            int id;                                                                     \
+            cudaGetDevice(&id);                                                         \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n",                         \
                     err_, __FILE__, __LINE__, cublasGetStatusString(err_));             \
+            fprintf(stderr, "current device: %d\n", id);                                \
             exit(1);                                                                    \
         }                                                                               \
     } while (0)
@@ -165,12 +187,21 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do {                                                                                \
         cublasStatus_t err_ = (err);                                                    \
         if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
+            int id;                                                                     \
+            cudaGetDevice(&id);                                                         \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);  \
+            fprintf(stderr, "current device: %d\n", id);                                \
             exit(1);                                                                    \
         }                                                                               \
     } while (0)
 #endif // CUDART_VERSION >= 11
 
+#if CUDART_VERSION >= 11100
+#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
+#else
+#define GGML_CUDA_ASSUME(x)
+#endif // CUDART_VERSION >= 11100
+
 #ifdef GGML_CUDA_F16
 typedef half dfloat; // dequantize float
 typedef half2 dfloat2;
@@ -207,15 +238,22 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
     return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
 }
 
+template<typename T>
+using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream);
+typedef to_t_cuda_t<float> to_fp32_cuda_t;
+typedef to_t_cuda_t<half> to_fp16_cuda_t;
+
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
-typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
 typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
-typedef void (*ggml_cuda_op_t)(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i,
-    float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main);
+typedef void (*ggml_cuda_op_mul_mat_t)(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, const cudaStream_t & stream);
+typedef void (*ggml_cuda_op_flatten_t)(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream);
 
 // QK = number of values after dequantization
 // QR = QK / number of values before dequantization
@@ -396,11 +434,33 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
 #endif
 
+#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE
+#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
+#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
+
+#define MUL_MAT_SRC1_COL_STRIDE 128
+
+#define MAX_STREAMS 8
+static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
+
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
-    cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
+    cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
 };
 
+// this is faster on Windows
+// probably because the Windows CUDA libraries forget to make this check before invoking the drivers
+inline cudaError_t ggml_cuda_set_device(const int device) {
+    int current_device;
+    CUDA_CHECK(cudaGetDevice(&current_device));
+
+    if (device == current_device) {
+        return cudaSuccess;
+    }
+
+    return cudaSetDevice(device);
+}
+
 static int g_device_count = -1;
 static int g_main_device = 0;
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
@@ -408,13 +468,11 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 static bool g_mul_mat_q = true;
 
 static void * g_scratch_buffer = nullptr;
-static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
+static size_t g_scratch_size = 0; // disabled by default
 static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
-static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
-
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -657,7 +715,8 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
 
 //================================== k-quants
 
-static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
     const int i   = blockIdx.x;
     const block_q2_K * x = (const block_q2_K *) vx;
@@ -669,7 +728,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
     const int is  = 8*n + l/16;
 
     const uint8_t q = x[i].qs[32*n + l];
-    float * y = yy + i*QK_K + 128*n;
+    dst_t * y = yy + i*QK_K + 128*n;
 
     float dall = __low2half(x[i].dm);
     float dmin = __high2half(x[i].dm);
@@ -681,7 +740,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
     const int is = tid/16;  // 0 or 1
     const int il = tid%16;  // 0...15
     const uint8_t q = x[i].qs[il] >> (2*is);
-    float * y = yy + i*QK_K + 16*is + il;
+    dst_t * y = yy + i*QK_K + 16*is + il;
     float dall = __low2half(x[i].dm);
     float dmin = __high2half(x[i].dm);
     y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
@@ -690,7 +749,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
 
 }
 
-static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
     const int i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
@@ -714,7 +774,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float
     float d_all = x[i].d;
     float dl = d_all * (us - 32);
 
-    float * y = yy + i*QK_K + 128*n + 32*j;
+    dst_t * y = yy + i*QK_K + 128*n + 32*j;
     const uint8_t * q = x[i].qs + 32*n;
     const uint8_t * hm = x[i].hmask;
 
@@ -726,7 +786,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float
     const int im  = il/8;    // 0...1
     const int in  = il%8;    // 0...7
 
-    float * y = yy + i*QK_K + 16*is + il;
+    dst_t * y = yy + i*QK_K + 16*is + il;
 
     const uint8_t q = x[i].qs[il] >> (2*is);
     const uint8_t h = x[i].hmask[in] >> (2*is + im);
@@ -754,7 +814,8 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
 }
 #endif
 
-static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
     const int i = blockIdx.x;
@@ -767,7 +828,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
     const int is  = 2*il;
     const int n   = 4;
 
-    float * y = yy + i*QK_K + 64*il + n*ir;
+    dst_t * y = yy + i*QK_K + 64*il + n*ir;
 
     const float dall = __low2half(x[i].dm);
     const float dmin = __high2half(x[i].dm);
@@ -786,7 +847,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
 #else
     const int tid = threadIdx.x;
     const uint8_t * q = x[i].qs;
-    float * y = yy + i*QK_K;
+    dst_t * y = yy + i*QK_K;
     const float d = (float)x[i].dm[0];
     const float m = (float)x[i].dm[1];
     y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
@@ -794,7 +855,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
 #endif
 }
 
-static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q5_K * x = (const block_q5_K *) vx;
 
     const int i = blockIdx.x;
@@ -806,7 +868,7 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
     const int ir  = tid%16;   // ir is in 0...15
     const int is  = 2*il;     // is is in 0...6
 
-    float * y = yy + i*QK_K + 64*il + 2*ir;
+    dst_t * y = yy + i*QK_K + 64*il + 2*ir;
 
     const float dall = __low2half(x[i].dm);
     const float dmin = __high2half(x[i].dm);
@@ -834,13 +896,14 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
     const int is = tid/16; // 0 or 1
     const uint8_t h = x[i].qh[in] >> im;
     const float d = x[i].d;
-    float * y = yy + i*QK_K + tid;
+    dst_t * y = yy + i*QK_K + tid;
     y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
     y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
 #endif
 }
 
-static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
     const int i = blockIdx.x;
@@ -852,7 +915,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float
     const int il  = tid - 32*ip; // 0...32
     const int is  = 8*ip + il/16;
 
-    float * y = yy + i*QK_K + 128*ip + il;
+    dst_t * y = yy + i*QK_K + 128*ip + il;
 
     const float d = x[i].d;
 
@@ -871,7 +934,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float
     const int ip  = tid/16;         // 0 or 1
     const int il  = tid - 16*ip;    // 0...15
 
-    float * y = yy + i*QK_K + 16*ip + il;
+    dst_t * y = yy + i*QK_K + 16*ip + il;
 
     const float d = x[i].d;
 
@@ -1464,6 +1527,14 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
     v.y = x[ib + iqs + 1];
 }
 
+static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const float * x = (const float *) vx;
+
+    // automatic half -> float type cast if dfloat == float
+    v.x = x[ib + iqs + 0];
+    v.y = x[ib + iqs + 1];
+}
+
 static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
     const int ix = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1503,8 +1574,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }
 
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) {
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
     const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
 
     if (i >= k) {
@@ -2107,10 +2178,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI4_0;
     const int kqsx = k % QI4_0;
@@ -2201,10 +2272,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI4_1;
     const int kqsx = k % QI4_1;
@@ -2293,10 +2364,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI5_0;
     const int kqsx = k % QI5_0;
@@ -2407,10 +2478,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset < nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset < nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI5_1;
     const int kqsx = k % QI5_1;
@@ -2513,10 +2584,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI8_0;
     const int kqsx = k % QI8_0;
@@ -2604,10 +2675,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI2_K;
     const int kqsx = k % QI2_K;
@@ -2725,10 +2796,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI3_K;
     const int kqsx = k % QI3_K;
@@ -2943,10 +3014,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI4_K; // == 0 if QK_K == 256
     const int kqsx = k % QI4_K; // == k if QK_K == 256
@@ -3124,10 +3195,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI5_K; // == 0 if QK_K == 256
     const int kqsx = k % QI5_K; // == k if QK_K == 256
@@ -3253,10 +3324,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset <  nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI6_K; // == 0 if QK_K == 256
     const int kqsx = k % QI6_K; // == k if QK_K == 256
@@ -3444,6 +3515,12 @@ static __device__ __forceinline__ void mul_mat_q(
     }
 }
 
+#define  MMQ_X_Q4_0_RDNA2  64
+#define  MMQ_Y_Q4_0_RDNA2  128
+#define NWARPS_Q4_0_RDNA2  8
+#define  MMQ_X_Q4_0_RDNA1  64
+#define  MMQ_Y_Q4_0_RDNA1  64
+#define NWARPS_Q4_0_RDNA1  8
 #define  MMQ_X_Q4_0_AMPERE 64
 #define  MMQ_Y_Q4_0_AMPERE 128
 #define NWARPS_Q4_0_AMPERE 4
@@ -3451,11 +3528,32 @@ static __device__ __forceinline__ void mul_mat_q(
 #define  MMQ_Y_Q4_0_PASCAL 64
 #define NWARPS_Q4_0_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q4_0(
+template <bool need_check> static __global__ void
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q4_0_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    mul_mat_q4_0(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q4_0_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q4_0_RDNA2;
+    const int nwarps = NWARPS_Q4_0_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q4_0_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q4_0_RDNA1;
+    const int nwarps = NWARPS_Q4_0_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+        load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q4_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_0_AMPERE;
     const int nwarps = NWARPS_Q4_0_AMPERE;
@@ -3475,9 +3573,15 @@ template <bool need_check> static __global__ void mul_mat_q4_0(
 #else
     (void) vec_dot_q4_0_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
+#define  MMQ_X_Q4_1_RDNA2  64
+#define  MMQ_Y_Q4_1_RDNA2  128
+#define NWARPS_Q4_1_RDNA2  8
+#define  MMQ_X_Q4_1_RDNA1  64
+#define  MMQ_Y_Q4_1_RDNA1  64
+#define NWARPS_Q4_1_RDNA1  8
 #define  MMQ_X_Q4_1_AMPERE 64
 #define  MMQ_Y_Q4_1_AMPERE 128
 #define NWARPS_Q4_1_AMPERE 4
@@ -3486,14 +3590,33 @@ template <bool need_check> static __global__ void mul_mat_q4_0(
 #define NWARPS_Q4_1_PASCAL 8
 
 template <bool need_check> static __global__ void
-#if __CUDA_ARCH__ < CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
     mul_mat_q4_1(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q4_1_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q4_1_RDNA2;
+    const int nwarps = NWARPS_Q4_1_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q4_1_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q4_1_RDNA1;
+    const int nwarps = NWARPS_Q4_1_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
+        load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q4_1_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_1_AMPERE;
     const int nwarps = NWARPS_Q4_1_AMPERE;
@@ -3513,9 +3636,15 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q4_1_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
+#define  MMQ_X_Q5_0_RDNA2  64
+#define  MMQ_Y_Q5_0_RDNA2  128
+#define NWARPS_Q5_0_RDNA2  8
+#define  MMQ_X_Q5_0_RDNA1  64
+#define  MMQ_Y_Q5_0_RDNA1  64
+#define NWARPS_Q5_0_RDNA1  8
 #define  MMQ_X_Q5_0_AMPERE 128
 #define  MMQ_Y_Q5_0_AMPERE 64
 #define NWARPS_Q5_0_AMPERE 4
@@ -3523,11 +3652,32 @@ template <bool need_check> static __global__ void
 #define  MMQ_Y_Q5_0_PASCAL 64
 #define NWARPS_Q5_0_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q5_0(
+template <bool need_check> static __global__ void
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q5_0_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    mul_mat_q5_0(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q5_0_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q5_0_RDNA2;
+    const int nwarps = NWARPS_Q5_0_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q5_0_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q5_0_RDNA1;
+    const int nwarps = NWARPS_Q5_0_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
+        load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q5_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_0_AMPERE;
     const int nwarps = NWARPS_Q5_0_AMPERE;
@@ -3547,9 +3697,15 @@ template <bool need_check> static __global__ void mul_mat_q5_0(
 #else
     (void) vec_dot_q5_0_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
+#define  MMQ_X_Q5_1_RDNA2  64
+#define  MMQ_Y_Q5_1_RDNA2  128
+#define NWARPS_Q5_1_RDNA2  8
+#define  MMQ_X_Q5_1_RDNA1  64
+#define  MMQ_Y_Q5_1_RDNA1  64
+#define NWARPS_Q5_1_RDNA1  8
 #define  MMQ_X_Q5_1_AMPERE 128
 #define  MMQ_Y_Q5_1_AMPERE 64
 #define NWARPS_Q5_1_AMPERE 4
@@ -3557,11 +3713,32 @@ template <bool need_check> static __global__ void mul_mat_q5_0(
 #define  MMQ_Y_Q5_1_PASCAL 64
 #define NWARPS_Q5_1_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q5_1(
+template <bool need_check> static __global__ void
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q5_1_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+mul_mat_q5_1(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q5_1_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q5_1_RDNA2;
+    const int nwarps = NWARPS_Q5_1_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q5_1_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q5_1_RDNA1;
+    const int nwarps = NWARPS_Q5_1_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
+        load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q5_1_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_1_AMPERE;
     const int nwarps = NWARPS_Q5_1_AMPERE;
@@ -3581,9 +3758,15 @@ template <bool need_check> static __global__ void mul_mat_q5_1(
 #else
     (void) vec_dot_q5_1_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
+#define  MMQ_X_Q8_0_RDNA2  64
+#define  MMQ_Y_Q8_0_RDNA2  128
+#define NWARPS_Q8_0_RDNA2  8
+#define  MMQ_X_Q8_0_RDNA1  64
+#define  MMQ_Y_Q8_0_RDNA1  64
+#define NWARPS_Q8_0_RDNA1  8
 #define  MMQ_X_Q8_0_AMPERE 128
 #define  MMQ_Y_Q8_0_AMPERE 64
 #define NWARPS_Q8_0_AMPERE 4
@@ -3591,11 +3774,32 @@ template <bool need_check> static __global__ void mul_mat_q5_1(
 #define  MMQ_Y_Q8_0_PASCAL 64
 #define NWARPS_Q8_0_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q8_0(
+template <bool need_check> static __global__ void
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q8_0_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    mul_mat_q8_0(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q8_0_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q8_0_RDNA2;
+    const int nwarps = NWARPS_Q8_0_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q8_0_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q8_0_RDNA1;
+    const int nwarps = NWARPS_Q8_0_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
+        load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q8_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q8_0_AMPERE;
     const int nwarps = NWARPS_Q8_0_AMPERE;
@@ -3615,9 +3819,15 @@ template <bool need_check> static __global__ void mul_mat_q8_0(
 #else
     (void) vec_dot_q8_0_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
+#define  MMQ_X_Q2_K_RDNA2  64
+#define  MMQ_Y_Q2_K_RDNA2  128
+#define NWARPS_Q2_K_RDNA2  8
+#define  MMQ_X_Q2_K_RDNA1  128
+#define  MMQ_Y_Q2_K_RDNA1  32
+#define NWARPS_Q2_K_RDNA1  8
 #define  MMQ_X_Q2_K_AMPERE 64
 #define  MMQ_Y_Q2_K_AMPERE 128
 #define NWARPS_Q2_K_AMPERE 4
@@ -3625,11 +3835,32 @@ template <bool need_check> static __global__ void mul_mat_q8_0(
 #define  MMQ_Y_Q2_K_PASCAL 64
 #define NWARPS_Q2_K_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q2_K(
+template <bool need_check> static __global__ void
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q2_K_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+mul_mat_q2_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q2_K_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q2_K_RDNA2;
+    const int nwarps = NWARPS_Q2_K_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q2_K_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q2_K_RDNA1;
+    const int nwarps = NWARPS_Q2_K_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
+        load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q2_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q2_K_AMPERE;
     const int nwarps = NWARPS_Q2_K_AMPERE;
@@ -3649,9 +3880,15 @@ template <bool need_check> static __global__ void mul_mat_q2_K(
 #else
     (void) vec_dot_q2_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
+#define  MMQ_X_Q3_K_RDNA2  128
+#define  MMQ_Y_Q3_K_RDNA2  64
+#define NWARPS_Q3_K_RDNA2  8
+#define  MMQ_X_Q3_K_RDNA1  32
+#define  MMQ_Y_Q3_K_RDNA1  128
+#define NWARPS_Q3_K_RDNA1  8
 #define  MMQ_X_Q3_K_AMPERE 128
 #define  MMQ_Y_Q3_K_AMPERE 128
 #define NWARPS_Q3_K_AMPERE 4
@@ -3660,14 +3897,33 @@ template <bool need_check> static __global__ void mul_mat_q2_K(
 #define NWARPS_Q3_K_PASCAL 8
 
 template <bool need_check> static __global__ void
-#if __CUDA_ARCH__ < CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
     mul_mat_q3_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q3_K_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q3_K_RDNA2;
+    const int nwarps = NWARPS_Q3_K_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q3_K_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q3_K_RDNA1;
+    const int nwarps = NWARPS_Q3_K_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
+        load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q3_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q3_K_AMPERE;
     const int nwarps = NWARPS_Q3_K_AMPERE;
@@ -3687,9 +3943,15 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q3_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
+#define  MMQ_X_Q4_K_RDNA2  64
+#define  MMQ_Y_Q4_K_RDNA2  128
+#define NWARPS_Q4_K_RDNA2  8
+#define  MMQ_X_Q4_K_RDNA1  32
+#define  MMQ_Y_Q4_K_RDNA1  64
+#define NWARPS_Q4_K_RDNA1  8
 #define  MMQ_X_Q4_K_AMPERE 64
 #define  MMQ_Y_Q4_K_AMPERE 128
 #define NWARPS_Q4_K_AMPERE 4
@@ -3698,14 +3960,33 @@ template <bool need_check> static __global__ void
 #define NWARPS_Q4_K_PASCAL 8
 
 template <bool need_check> static __global__ void
-#if __CUDA_ARCH__ < CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
     mul_mat_q4_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q4_K_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q4_K_RDNA2;
+    const int nwarps = NWARPS_Q4_K_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q4_K_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q4_K_RDNA1;
+    const int nwarps = NWARPS_Q4_K_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
+        load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q4_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_K_AMPERE;
     const int nwarps = NWARPS_Q4_K_AMPERE;
@@ -3725,9 +4006,15 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q4_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
+#define  MMQ_X_Q5_K_RDNA2  64
+#define  MMQ_Y_Q5_K_RDNA2  128
+#define NWARPS_Q5_K_RDNA2  8
+#define  MMQ_X_Q5_K_RDNA1  32
+#define  MMQ_Y_Q5_K_RDNA1  64
+#define NWARPS_Q5_K_RDNA1  8
 #define  MMQ_X_Q5_K_AMPERE 64
 #define  MMQ_Y_Q5_K_AMPERE 128
 #define NWARPS_Q5_K_AMPERE 4
@@ -3735,11 +4022,32 @@ template <bool need_check> static __global__ void
 #define  MMQ_Y_Q5_K_PASCAL 64
 #define NWARPS_Q5_K_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q5_K(
+template <bool need_check> static __global__ void
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q5_K_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+mul_mat_q5_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q5_K_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q5_K_RDNA2;
+    const int nwarps = NWARPS_Q5_K_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q5_K_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q5_K_RDNA1;
+    const int nwarps = NWARPS_Q5_K_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
+        load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q5_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_K_AMPERE;
     const int nwarps = NWARPS_Q5_K_AMPERE;
@@ -3759,9 +4067,15 @@ template <bool need_check> static __global__ void mul_mat_q5_K(
 #else
     (void) vec_dot_q5_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
+#define  MMQ_X_Q6_K_RDNA2  64
+#define  MMQ_Y_Q6_K_RDNA2  128
+#define NWARPS_Q6_K_RDNA2  8
+#define  MMQ_X_Q6_K_RDNA1  32
+#define  MMQ_Y_Q6_K_RDNA1  64
+#define NWARPS_Q6_K_RDNA1  8
 #define  MMQ_X_Q6_K_AMPERE 64
 #define  MMQ_Y_Q6_K_AMPERE 64
 #define NWARPS_Q6_K_AMPERE 4
@@ -3770,14 +4084,33 @@ template <bool need_check> static __global__ void mul_mat_q5_K(
 #define NWARPS_Q6_K_PASCAL 8
 
 template <bool need_check> static __global__ void
-#if __CUDA_ARCH__ < CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
     mul_mat_q6_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if __CUDA_ARCH__ >= CC_TURING
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    const int mmq_x  =  MMQ_X_Q6_K_RDNA2;
+    const int mmq_y  =  MMQ_Y_Q6_K_RDNA2;
+    const int nwarps = NWARPS_Q6_K_RDNA2;
+#else
+    const int mmq_x  =  MMQ_X_Q6_K_RDNA1;
+    const int mmq_y  =  MMQ_Y_Q6_K_RDNA1;
+    const int nwarps = NWARPS_Q6_K_RDNA1;
+#endif // defined(RDNA3) || defined(RDNA2)
+
+    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
+        load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q6_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q6_K_AMPERE;
     const int nwarps = NWARPS_Q6_K_AMPERE;
@@ -3797,7 +4130,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q6_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
@@ -4042,8 +4375,10 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 }
 
 // rope == RoPE == rotary positional embedding
-static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
-                                const float p_delta, const int p_delta_rows, const float theta_scale) {
+
+template<typename T, bool has_pos>
+static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
+                            const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
@@ -4052,8 +4387,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int i = row*ncols + col;
+    const int i2 = row/p_delta_rows;
 
-    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+    const int p = has_pos ? pos[i2] : 0;
+    const float p0 = p*freq_scale;
+    const float theta = p0*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
 
@@ -4064,8 +4402,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
-                                const float p_delta, const int p_delta_rows, const float theta_scale) {
+template<typename T, bool has_pos>
+static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
+                                 const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
@@ -4074,8 +4413,11 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int i = row*ncols + col/2;
+    const int i2 = row/p_delta_rows;
 
-    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+    const int p = has_pos ? pos[i2] : 0;
+    const float p0 = p*freq_scale;
+    const float theta = p0*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
 
@@ -4086,8 +4428,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
     dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0,
-                                    const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
+static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
+                                    const int p_delta_rows, const float theta_scale, const int n_ctx) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int half_n_dims = ncols/4;
 
@@ -4097,11 +4439,13 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
 
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
     const int i = row*ncols + col;
+    const int i2 = row/p_delta_rows;
 
     const float col_theta_scale = powf(theta_scale, col);
-    const float p = p0 + p_delta*(row/p_delta_rows);
+     // FIXME: this is likely wrong
+    const int p = pos != nullptr ? pos[i2] : 0;
 
-    const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale;
+    const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
 
@@ -4111,7 +4455,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     dst[i + 0]           = x0*cos_theta - x1*sin_theta;
     dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
 
-    const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale;
+    const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
     const float sin_block_theta = sinf(block_theta);
     const float cos_block_theta = cosf(block_theta);
 
@@ -4265,32 +4609,38 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
     quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
 }
 
-static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4299,7 +4649,8 @@ static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cu
 #endif
 }
 
-static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4308,12 +4659,14 @@ static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cu
 #endif
 }
 
-static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
 }
 
-static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4322,7 +4675,8 @@ static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cu
 #endif
 }
 
-static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4513,6 +4867,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
     dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
+static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    dequantize_block<1, 1, convert_f32><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
@@ -4522,6 +4881,35 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
+static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_row_q5_0_cuda;
+        case GGML_TYPE_Q5_1:
+            return dequantize_row_q5_1_cuda;
+        case GGML_TYPE_Q8_0:
+            return dequantize_row_q8_0_cuda;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_F32:
+            return convert_fp32_to_fp16_cuda;
+        default:
+            return nullptr;
+    }
+}
+
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -4560,7 +4948,15 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q4_0_RDNA2;
+        mmq_y  =  MMQ_Y_Q4_0_RDNA2;
+        nwarps = NWARPS_Q4_0_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q4_0_RDNA1;
+        mmq_y  =  MMQ_Y_Q4_0_RDNA1;
+        nwarps = NWARPS_Q4_0_RDNA1;
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q4_0_AMPERE;
         mmq_y  =  MMQ_Y_Q4_0_AMPERE;
         nwarps = NWARPS_Q4_0_AMPERE;
@@ -4597,7 +4993,15 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q4_1_RDNA2;
+        mmq_y  =  MMQ_Y_Q4_1_RDNA2;
+        nwarps = NWARPS_Q4_1_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q4_1_RDNA1;
+        mmq_y  =  MMQ_Y_Q4_1_RDNA1;
+        nwarps = NWARPS_Q4_1_RDNA1;
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q4_1_AMPERE;
         mmq_y  =  MMQ_Y_Q4_1_AMPERE;
         nwarps = NWARPS_Q4_1_AMPERE;
@@ -4634,7 +5038,15 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q5_0_RDNA2;
+        mmq_y  =  MMQ_Y_Q5_0_RDNA2;
+        nwarps = NWARPS_Q5_0_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q5_0_RDNA1;
+        mmq_y  =  MMQ_Y_Q5_0_RDNA1;
+        nwarps = NWARPS_Q5_0_RDNA1;
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q5_0_AMPERE;
         mmq_y  =  MMQ_Y_Q5_0_AMPERE;
         nwarps = NWARPS_Q5_0_AMPERE;
@@ -4671,7 +5083,15 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q5_1_RDNA2;
+        mmq_y  =  MMQ_Y_Q5_1_RDNA2;
+        nwarps = NWARPS_Q5_1_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q5_1_RDNA1;
+        mmq_y  =  MMQ_Y_Q5_1_RDNA1;
+        nwarps = NWARPS_Q5_1_RDNA1;
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q5_1_AMPERE;
         mmq_y  =  MMQ_Y_Q5_1_AMPERE;
         nwarps = NWARPS_Q5_1_AMPERE;
@@ -4708,7 +5128,15 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q8_0_RDNA2;
+        mmq_y  =  MMQ_Y_Q8_0_RDNA2;
+        nwarps = NWARPS_Q8_0_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q8_0_RDNA1;
+        mmq_y  =  MMQ_Y_Q8_0_RDNA1;
+        nwarps = NWARPS_Q8_0_RDNA1;
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q8_0_AMPERE;
         mmq_y  =  MMQ_Y_Q8_0_AMPERE;
         nwarps = NWARPS_Q8_0_AMPERE;
@@ -4745,7 +5173,15 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q2_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q2_K_RDNA2;
+        nwarps = NWARPS_Q2_K_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q2_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q2_K_RDNA1;
+        nwarps = NWARPS_Q2_K_RDNA1;
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q2_K_AMPERE;
         mmq_y  =  MMQ_Y_Q2_K_AMPERE;
         nwarps = NWARPS_Q2_K_AMPERE;
@@ -4784,7 +5220,15 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q3_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q3_K_RDNA2;
+        nwarps = NWARPS_Q3_K_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q3_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q3_K_RDNA1;
+        nwarps = NWARPS_Q3_K_RDNA1;
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q3_K_AMPERE;
         mmq_y  =  MMQ_Y_Q3_K_AMPERE;
         nwarps = NWARPS_Q3_K_AMPERE;
@@ -4822,7 +5266,15 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q4_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q4_K_RDNA2;
+        nwarps = NWARPS_Q4_K_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q4_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q4_K_RDNA1;
+        nwarps = NWARPS_Q4_K_RDNA1;
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q4_K_AMPERE;
         mmq_y  =  MMQ_Y_Q4_K_AMPERE;
         nwarps = NWARPS_Q4_K_AMPERE;
@@ -4859,7 +5311,15 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q5_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q5_K_RDNA2;
+        nwarps = NWARPS_Q5_K_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q5_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q5_K_RDNA1;
+        nwarps = NWARPS_Q5_K_RDNA1;
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q5_K_AMPERE;
         mmq_y  =  MMQ_Y_Q5_K_AMPERE;
         nwarps = NWARPS_Q5_K_AMPERE;
@@ -4896,7 +5356,15 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
     const int compute_capability = g_compute_capabilities[id];
 
     int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_TURING) {
+    if (compute_capability >= CC_RDNA2) {
+        mmq_x  =  MMQ_X_Q6_K_RDNA2;
+        mmq_y  =  MMQ_Y_Q6_K_RDNA2;
+        nwarps = NWARPS_Q6_K_RDNA2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        mmq_x  =  MMQ_X_Q6_K_RDNA1;
+        mmq_y  =  MMQ_Y_Q6_K_RDNA1;
+        nwarps = NWARPS_Q6_K_RDNA1;
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q6_K_AMPERE;
         mmq_y  =  MMQ_Y_Q6_K_AMPERE;
         nwarps = NWARPS_Q6_K_AMPERE;
@@ -4968,31 +5436,41 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
 }
 
-static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
-                          const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+template<typename T>
+static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+                          const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
-    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
+    if (pos == nullptr) {
+        rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    } else {
+        rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    }
 }
 
-static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
-                          const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+template<typename T>
+static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+                          const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
-    rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
+    if (pos == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    } else {
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    }
 }
 
-static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
-                              const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
+static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+                              const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
     GGML_ASSERT(ncols % 4 == 0);
     const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
     const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
     const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx);
+    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx);
 }
 
 static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
@@ -5130,25 +5608,30 @@ void ggml_init_cublas() {
         GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
         int64_t total_vram = 0;
         fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
-        for (int id = 0; id < g_device_count; ++id) {
+        for (int64_t id = 0; id < g_device_count; ++id) {
             cudaDeviceProp prop;
             CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
-            fprintf(stderr, "  Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
+            fprintf(stderr, "  Device %ld: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
 
             g_tensor_split[id] = total_vram;
             total_vram += prop.totalGlobalMem;
-
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            g_compute_capabilities[id] = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
+#else
             g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         }
-        for (int id = 0; id < g_device_count; ++id) {
+        for (int64_t id = 0; id < g_device_count; ++id) {
             g_tensor_split[id] /= total_vram;
         }
 
-        for (int id = 0; id < g_device_count; ++id) {
-            CUDA_CHECK(cudaSetDevice(id));
+        for (int64_t id = 0; id < g_device_count; ++id) {
+            CUDA_CHECK(ggml_cuda_set_device(id));
 
-            // create main stream
-            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));
+            // create cuda streams
+            for (int64_t is = 0; is < MAX_STREAMS; ++is) {
+                CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[id][is], cudaStreamNonBlocking));
+            }
 
             // create cublas handle
             CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
@@ -5217,7 +5700,8 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     if (src->backend == GGML_BACKEND_CPU) {
         kind = cudaMemcpyHostToDevice;
         src_ptr = (char *) src->data;
-    } else if (src->backend == GGML_BACKEND_GPU) {
+    } else if (src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT) {
+        GGML_ASSERT(src->backend != GGML_BACKEND_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
         kind = cudaMemcpyDeviceToDevice;
         struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
         int id;
@@ -5256,240 +5740,209 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
 }
 
 inline void ggml_cuda_op_add(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
-    GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i  != nullptr);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
 
-    // compute
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
+        add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
+        add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
     } else {
         GGML_ASSERT(false);
     }
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) i02;
-    (void) i1;
 }
 
 inline void ggml_cuda_op_mul(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i  != nullptr);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
 
-    mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
+    mul_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
 
     (void) dst;
-    (void) src0_ddq_i;
-    (void) i02;
-    (void) i1;
 }
 
 inline void ggml_cuda_op_gelu(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
-
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    // compute
-    gelu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+    gelu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_silu(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
-
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    // compute
-    silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+    silu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_norm(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    const int64_t nrows = ggml_nrows(src0);
 
-    // compute
-    norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+    norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_rms_norm(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    const int64_t nrows = ggml_nrows(src0);
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    // compute
-    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);
+    rms_norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, eps, main_stream);
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_mul_mat_q(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
-
-    GGML_ASSERT(src0_ddq_i != nullptr);
-    GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, const cudaStream_t & stream) {
 
     const int64_t ne00 = src0->ne[0];
 
     const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
     GGML_ASSERT(ne10 % QK8_1 == 0);
 
     const int64_t ne0 = dst->ne[0];
 
-    const int64_t i01_diff = i01_high - i01_low;
+    const int64_t row_diff = row_high - row_low;
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
 
     // the main device has a larger memory buffer to hold the results from all GPUs
     // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
-    const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
-
-    const int64_t padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ?
-        ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
-    size_t as;
-    void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne11*sizeof(block_q8_1)/QK8_1, &as);
-    quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, cudaStream_main);
+    const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            ggml_mul_mat_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q4_1:
-            ggml_mul_mat_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q5_0:
-            ggml_mul_mat_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q5_1:
-            ggml_mul_mat_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q8_0:
-            ggml_mul_mat_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q2_K:
-            ggml_mul_mat_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q3_K:
-            ggml_mul_mat_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q4_K:
-            ggml_mul_mat_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q5_K:
-            ggml_mul_mat_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         case GGML_TYPE_Q6_K:
-            ggml_mul_mat_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+            ggml_mul_mat_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
             break;
         default:
             GGML_ASSERT(false);
             break;
     }
 
-    ggml_cuda_pool_free(src1_q8_1, as);
-
     (void) src1;
     (void) dst;
-    (void) src0_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_ddf_i;
 }
 
 static int64_t get_row_rounding(ggml_type type) {
-    int max_compute_capability = INT_MIN;
-    for (int id = 0; id < g_device_count; ++id) {
-        if (max_compute_capability < g_compute_capabilities[id]
-                && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
-            max_compute_capability = g_compute_capabilities[id];
+    int64_t min_compute_capability = INT_MAX;
+    int64_t max_compute_capability = INT_MIN;
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        if (g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
+            if (min_compute_capability > g_compute_capabilities[id]) {
+                min_compute_capability = g_compute_capabilities[id];
+            }
+            if (max_compute_capability < g_compute_capabilities[id]) {
+                max_compute_capability = g_compute_capabilities[id];
+            }
         }
     }
 
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
     switch(type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
-            return max_compute_capability >= CC_TURING ? 128 : 64;
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+            return max_compute_capability >= CC_RDNA2 ? 128 : 64;
+        case GGML_TYPE_F16:
+            return 1;
+        case GGML_TYPE_Q2_K:
+            return max_compute_capability >= CC_RDNA2 ? 128 : 32;
+        case GGML_TYPE_Q3_K:
+            return min_compute_capability < CC_RDNA2 ? 128 : 64;
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+            return max_compute_capability >= CC_RDNA2 ? 128 : 64;
+        default:
+            GGML_ASSERT(false);
+    }
+#else
+    switch(type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
@@ -5500,222 +5953,271 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
-            return max_compute_capability >= CC_TURING ? 128 : 64;
+            return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q6_K:
             return 64;
         default:
             GGML_ASSERT(false);
     }
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
-inline void ggml_cuda_op_mul_mat_vec(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
-
-    GGML_ASSERT(src0_ddq_i != nullptr);
-    GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+inline void ggml_cuda_op_mul_mat_vec_q(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, const cudaStream_t & stream) {
 
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = i01_high - i01_low;
+    const int64_t row_diff = row_high - row_low;
 
-#ifdef GGML_CUDA_FORCE_DMMV
-    const bool use_mul_mat_vec_q = false;
-    (void) g_compute_capabilities[0];
-#else
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
 
-    bool mul_mat_vec_q_implemented =
-        src0->type == GGML_TYPE_Q4_0 ||
-        src0->type == GGML_TYPE_Q4_1 ||
-        src0->type == GGML_TYPE_Q5_0 ||
-        src0->type == GGML_TYPE_Q5_1 ||
-        src0->type == GGML_TYPE_Q8_0;
-#if QK_K == 256
-    mul_mat_vec_q_implemented = mul_mat_vec_q_implemented ||
-        src0->type == GGML_TYPE_Q2_K ||
-        src0->type == GGML_TYPE_Q3_K ||
-        src0->type == GGML_TYPE_Q4_K ||
-        src0->type == GGML_TYPE_Q5_K ||
-        src0->type == GGML_TYPE_Q6_K;
-#endif // QK_K == 256
-
-    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= MIN_CC_DP4A && mul_mat_vec_q_implemented;
-#endif
+    (void) src1;
+    (void) dst;
+    (void) src1_ddf_i;
+    (void) src1_ncols;
+    (void) src1_padded_row_size;
+}
 
-    if (use_mul_mat_vec_q) {
-        const int64_t padded_row_size = ne00 % MATRIX_ROW_PADDING == 0 ?
-            ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
-        size_t as;
-        void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
-        quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, 1, padded_row_size, cudaStream_main);
-
-        switch (src0->type) {
-            case GGML_TYPE_Q4_0:
-                mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q4_1:
-                mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q5_0:
-                mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q5_1:
-                mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q8_0:
-                mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q2_K:
-                mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q3_K:
-                mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q4_K:
-                mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q5_K:
-                mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q6_K:
-                mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+inline void ggml_cuda_op_dequantize_mul_mat_vec(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, const cudaStream_t & stream) {
 
-        ggml_cuda_pool_free(src1_q8_1, as);
-    } else {
-        // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+    const int64_t ne00 = src0->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
 #ifdef GGML_CUDA_F16
-        size_t ash;
-        dfloat * src1_dfloat = nullptr; // dfloat == half
-
-        bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
-            src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
-            src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
-
-        if (src1_convert_f16) {
-            src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
-            ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
-                                    ne00, 1, sizeof(float), 0, 0,
-                                    ne00, 1, sizeof(half),  0, 0, cudaStream_main);
-        }
+    size_t ash;
+    dfloat * src1_dfloat = nullptr; // dfloat == half
+
+    bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+        src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+        src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+    if (src1_convert_f16) {
+        src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+        ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
+                                ne00, 1, sizeof(float), 0, 0,
+                                ne00, 1, sizeof(half),  0, 0, stream);
+    }
 #else
-        dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
+    const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
 #endif // GGML_CUDA_F16
 
-        switch (src0->type) {
-            case GGML_TYPE_Q4_0:
-                dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q4_1:
-                dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q5_0:
-                dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q5_1:
-                dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q8_0:
-                dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q2_K:
-                dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q3_K:
-                dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q4_K:
-                dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q5_K:
-                dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_Q6_K:
-                dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            case GGML_TYPE_F16:
-                convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_F16:
+            convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
 
 #ifdef GGML_CUDA_F16
-        if (src1_convert_f16) {
-            ggml_cuda_pool_free(src1_dfloat, ash);
-        }
-#endif // GGML_CUDA_F16
+    if (src1_convert_f16) {
+        ggml_cuda_pool_free(src1_dfloat, ash);
     }
+#endif // GGML_CUDA_F16
 
     (void) src1;
     (void) dst;
-    (void) src0_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_ddq_i;
+    (void) src1_ncols;
+    (void) src1_padded_row_size;
 }
 
 inline void ggml_cuda_op_mul_mat_cublas(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, const cudaStream_t & stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src0_dd_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(dst_dd_i != nullptr);
 
-    const float alpha = 1.0f;
-    const float beta = 0.0f;
 
     const int64_t ne00 = src0->ne[0];
 
     const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
 
     const int64_t ne0 = dst->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    const int64_t row_diff = row_high - row_low;
 
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
 
     // the main device has a larger memory buffer to hold the results from all GPUs
     // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
+    int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
+
+    const int compute_capability = g_compute_capabilities[id];
+
+    if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+        // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
+        half * src0_as_f16 = nullptr;
+        size_t src0_as = 0;
+        if (src0->type != GGML_TYPE_F16) {
+            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
+            size_t ne = row_diff*ne00;
+            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
+            to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
+        }
+        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
+
+        half * src1_as_f16 = nullptr;
+        size_t src1_as = 0;
+        if (src1->type != GGML_TYPE_F16) {
+            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
+            size_t ne = src1_ncols*ne10;
+            src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
+            to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
+        }
+        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
+
+        size_t dst_as = 0;
+        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
+
+        const half alpha_f16 = 1.0f;
+        const half beta_f16 = 0.0f;
+
+        CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
+        CUBLAS_CHECK(
+            cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                    row_diff, src1_ncols, ne10,
+                    &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
+                                src1_ptr, CUDA_R_16F, ne10,
+                    &beta_f16,   dst_f16, CUDA_R_16F, ldc,
+                    CUBLAS_COMPUTE_16F,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
-    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], cudaStream_main));
-    CUBLAS_CHECK(
-        cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
-                i01_diff, ne11, ne10,
-                &alpha, src0_ddf_i, ne00,
-                        src1_ddf_i, ne10,
-                &beta,  dst_ddf_i,  ldc));
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
+
+        ggml_cuda_pool_free(dst_f16, dst_as);
+
+        if (src0_as != 0) {
+            ggml_cuda_pool_free(src0_as_f16, src0_as);
+        }
+
+        if (src1_as != 0) {
+            ggml_cuda_pool_free(src1_as_f16, src1_as);
+        }
+    }
+    else {
+        float * src0_ddq_as_f32 = nullptr;
+        size_t src0_as = 0;
+
+        if (src0->type != GGML_TYPE_F32) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+            GGML_ASSERT(to_fp32_cuda != nullptr);
+            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
+            to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
+        }
+        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
+
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+
+        CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
+        CUBLAS_CHECK(
+            cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                    row_diff, src1_ncols, ne10,
+                    &alpha, src0_ddf_i, ne00,
+                            src1_ddf_i,  ne10,
+                    &beta,  dst_dd_i,   ldc));
+
+        if (src0_as != 0) {
+            ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
+        }
+    }
 
     (void) dst;
-    (void) src0_ddq_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_ddq_i;
+    (void) src1_padded_row_size;
 }
 
 inline void ggml_cuda_op_rope(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t i01_diff = i01_high - i01_low;
+    const int64_t ne2 = dst->ne[2];
+    const int64_t nrows = ggml_nrows(src0);
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
     const int n_ctx  = ((int32_t *) dst->op_params)[3];
@@ -5726,41 +6228,56 @@ inline void ggml_cuda_op_rope(
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
+
+    const int32_t * pos = nullptr;
+    if ((mode & 1) == 0) {
+        GGML_ASSERT(src1->type == GGML_TYPE_I32);
+        GGML_ASSERT(src1->ne[0] == ne2);
+        pos = (const int32_t *) src1_dd;
+    }
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
     // compute
     if (is_glm) {
-        rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, n_ctx, cudaStream_main);
+        GGML_ASSERT(false);
+        rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
     } else if (is_neox) {
         GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
-        rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
+        if (src0->type == GGML_TYPE_F32) {
+            rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else {
+            GGML_ASSERT(false);
+        }
     } else {
-        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
+        if (src0->type == GGML_TYPE_F32) {
+            rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else {
+            GGML_ASSERT(false);
+        }
     }
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_alibi(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
-    const int64_t i01_diff = i01_high - i01_low;
+    const int64_t nrows = ggml_nrows(src0);
 
     const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
@@ -5775,335 +6292,393 @@ inline void ggml_cuda_op_alibi(
     const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
-    // compute
-    alibi_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_heads_log2_floor, m0, m1, cudaStream_main);
+    alibi_f32_cuda(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream);
 
     (void) src1;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_diag_mask_inf(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t i01_diff = i01_high - i01_low;
+    const int nrows0 = ggml_nrows(src0);
 
     const int n_past = ((int32_t *) dst->op_params)[0];
 
-    // compute
-    diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
+    diag_mask_inf_f32_cuda(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_soft_max(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
+    const int64_t nrows = ggml_nrows(src0);
 
-    // compute
-    soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_scale(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
-    cudaStream_t & cudaStream_main){
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
-    GGML_ASSERT(dst_ddf_i != nullptr);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     const float scale = ((float *) src1->data)[0];
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t i01_diff = i01_high - i01_low;
-
-    // compute
-    scale_f32_cuda(src0_ddf_i, dst_ddf_i, scale, ne00*i01_diff, cudaStream_main);
+    scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
     CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
-    (void) src0_ddq_i;
-    (void) src1_ddf_i;
-    (void) i02;
-    (void) i1;
+    (void) src1_dd;
+}
+
+static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) {
+    const int64_t nrows0 = ggml_nrows(src0);
+
+    const bool use_src1 = src1 != nullptr;
+    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(              dst->backend != GGML_BACKEND_GPU_SPLIT);
+
+    struct ggml_tensor_extra_gpu * src0_extra =            (ggml_tensor_extra_gpu *) src0->extra;
+    struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+    struct ggml_tensor_extra_gpu * dst_extra  =            (ggml_tensor_extra_gpu *)  dst->extra;
+
+    const bool src0_on_device =             src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
+    const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
+    const bool  dst_on_device =              dst->backend == GGML_BACKEND_GPU;
+
+    const bool src1_stays_on_host = use_src1 && dst->op == GGML_OP_SCALE;
+
+    // dd = data device
+    float * src0_ddf = nullptr;
+    float * src1_ddf = nullptr;
+    float *  dst_ddf = nullptr;
+
+    // as = actual size
+    size_t src0_asf = 0;
+    size_t src1_asf = 0;
+    size_t  dst_asf = 0;
+
+    ggml_cuda_set_device(g_main_device);
+    const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    if (src0_on_device) {
+        src0_ddf = (float *) src0_extra->data_device[g_main_device];
+    } else {
+        src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf);
+        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
+    }
+
+    if (use_src1 && !src1_stays_on_host) {
+        if (src1_on_device) {
+            src1_ddf = (float *) src1_extra->data_device[g_main_device];
+        } else {
+            src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf);
+            CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
+        }
+    }
+    if (dst_on_device) {
+        dst_ddf = (float *) dst_extra->data_device[g_main_device];
+    } else {
+        dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf);
+    }
+
+    // do the computation
+    op(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+    CUDA_CHECK(cudaGetLastError());
+
+    // copy dst to host if necessary
+    if (!dst_on_device) {
+        CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
+    }
+
+    if (src0_asf > 0) {
+        ggml_cuda_pool_free(src0_ddf, src0_asf);
+    }
+    if (src1_asf > 0) {
+        ggml_cuda_pool_free(src1_ddf, src1_asf);
+    }
+    if (dst_asf > 0) {
+        ggml_cuda_pool_free(dst_ddf, dst_asf);
+    }
+
+    if (dst->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
+}
+
+static void ggml_cuda_set_peer_access(const int n_tokens) {
+    static bool peer_access_enabled = false;
+
+    const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
+
+    if (peer_access_enabled == enable_peer_access) {
+        return;
+    }
+
+#ifdef NDEBUG
+    for (int id = 0; id < g_device_count; ++id) {
+        CUDA_CHECK(ggml_cuda_set_device(id));
+
+        for (int id_other = 0; id_other < g_device_count; ++id_other) {
+            if (id == id_other) {
+                continue;
+            }
+            if (id != g_main_device && id_other != g_main_device) {
+                continue;
+            }
+
+            int can_access_peer;
+            CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
+            if (can_access_peer) {
+                if (enable_peer_access) {
+                    CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0));
+                } else {
+                    CUDA_CHECK(cudaDeviceDisablePeerAccess(id_other));
+                }
+            }
+        }
+    }
+#endif // NDEBUG
+
+    peer_access_enabled = enable_peer_access;
 }
 
-static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-                         ggml_cuda_op_t op, bool src0_needs_f32, bool flatten_rows) {
+static void ggml_cuda_op_mul_mat(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
+    const bool convert_src1_to_q8_1) {
+
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
     const int64_t nrows0 = ggml_nrows(src0);
 
-    const bool use_src1 = src1 != nullptr;
-    const int64_t ne10 = use_src1 ? src1->ne[0] : 1;
-    const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
-    const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
-    const int64_t ne13 = use_src1 ? src1->ne[3] : 1;
-    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    const int64_t nrows1 = ggml_nrows(src1);
 
     GGML_ASSERT(ne03 == ne13);
 
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    ggml_cuda_set_peer_access(ne11);
 
     GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
-    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
 
-    // strides for iteration over dims 3 and 2
-    const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13;
-    const int64_t num_iters = flatten_rows ? 1 : num_iters_0;
-    const int64_t stride_mod = flatten_rows ? num_iters_0 : 1;
-    const int64_t src0_stride = ne00 * ne01 * stride_mod;
-    const int64_t src1_stride = ne10 * ne11 * stride_mod;
-    const int64_t dst_stride = ne0 * ne1 * stride_mod;
+    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
 
-    const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
-    const int64_t i03_max = flatten_rows ? 1 : ne03;
-    const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12);
-    const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02;
-    GGML_ASSERT(!(flatten_rows && ne02 < ne12));
+    const int64_t i02_divisor = ne12 / ne02;
 
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
+    const size_t q8_1_ts = sizeof(block_q8_1);
+    const size_t q8_1_bs = QK8_1;
 
-    struct ggml_tensor_extra_gpu * src0_extra =            (ggml_tensor_extra_gpu *) src0->extra;
-    struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
-    struct ggml_tensor_extra_gpu * dst_extra  =            (ggml_tensor_extra_gpu *) dst->extra;
+    struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    struct ggml_tensor_extra_gpu *  dst_extra = (ggml_tensor_extra_gpu *)  dst->extra;
 
     const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
     const bool src0_is_contiguous = ggml_is_contiguous(src0);
-    const bool src0_is_f32 = src0->type == GGML_TYPE_F32;
 
-    const bool src1_is_contiguous = use_src1 && ggml_is_contiguous(src1);
-    const bool src1_stays_on_host = use_src1 && (
-        dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
+    const bool src1_is_contiguous = ggml_is_contiguous(src1);
+    const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ?
+        ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
 
     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
+    GGML_ASSERT(!(split && ne02 > 1));
+    GGML_ASSERT(!(split && ne03 > 1));
     GGML_ASSERT(!(split && ne02 < ne12));
 
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
-
     // dd = data device
-    char  * src0_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // quantized
-    float * src0_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
-    float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
-    float *  dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
-
-    // asq = actual size quantized, asf = actual size float
-    size_t src0_asq[GGML_CUDA_MAX_DEVICES] = {0};
-    size_t src0_asf[GGML_CUDA_MAX_DEVICES] = {0};
-    size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
-    size_t  dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
+    char  *  src0_dd[GGML_CUDA_MAX_DEVICES] = {nullptr};
+    float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
+    char  * src1_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // q8_1
+    float *   dst_dd[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
-    // if multiple devices are used they need to wait for the main device
-    // here an event is recorded that signifies that the main device has finished calculating the input data
-    if (split && g_device_count > 1) {
-        CUDA_CHECK(cudaSetDevice(g_main_device));
-        CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
-    }
+    // as = actual size
+    size_t  src0_as[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t src1_asq[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t   dst_as[GGML_CUDA_MAX_DEVICES] = {0};
 
-    for (int id = 0; id < g_device_count; ++id) {
-        if (!split && id != g_main_device) {
-            continue;
-        }
+    int64_t  row_low[GGML_CUDA_MAX_DEVICES];
+    int64_t row_high[GGML_CUDA_MAX_DEVICES];
 
-        const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU && id == g_main_device;
-        const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        // by default, use all rows
+        row_low[id]  = 0;
+        row_high[id] = ne01;
 
-        int64_t row_low, row_high;
+        // for multi GPU, get the row boundaries from tensor split
+        // and round to mul_mat_q tile sizes
         if (split) {
             const int64_t rounding = get_row_rounding(src0->type);
 
-            row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
-            row_low -= row_low % rounding;
+            if (id != 0) {
+                row_low[id]  = ne01*g_tensor_split[id];
+                row_low[id] -= row_low[id] % rounding;
+            }
 
-            if (id == g_device_count - 1) {
-                row_high = nrows0;
-            } else {
-                row_high = nrows0*g_tensor_split[id + 1];
-                row_high -= row_high % rounding;
+            if (id != g_device_count - 1) {
+                row_high[id]  = ne01*g_tensor_split[id + 1];
+                row_high[id] -= row_high[id] % rounding;
             }
-        } else {
-            row_low = 0;
-            row_high = nrows0*i02_divisor;
         }
-        if (row_low == row_high) {
+    }
+
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
             continue;
         }
 
-        int64_t row_diff = row_high - row_low;
+        const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
+        const bool  dst_on_device =  dst->backend == GGML_BACKEND_GPU && id == g_main_device;
 
-        cudaSetDevice(id);
-        cudaStream_t cudaStream_main = g_cudaStreams_main[id];
-
-        // wait for main GPU data if necessary
-        if (split && id != g_main_device) {
-            CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
-        }
+        ggml_cuda_set_device(id);
+        const cudaStream_t stream = g_cudaStreams[id][0];
 
         if (src0_on_device && src0_is_contiguous) {
-            if (src0_is_f32) {
-                src0_ddf[id] = (float *) src0_extra->data_device[id];
-            } else {
-                src0_ddq[id] = (char *) src0_extra->data_device[id];
-            }
+            src0_dd[id] = (char *) src0_extra->data_device[id];
         } else {
-            if (src0_is_f32) {
-                src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
-            } else {
-                src0_ddq[id] = (char *) ggml_cuda_pool_malloc(row_diff*ne00 * src0_ts/src0_bs, &src0_asq[id]);
-            }
+            const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
+            src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
         }
 
-        if (src0_needs_f32 && !src0_is_f32) {
-            src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
+        if (src1_on_device && src1_is_contiguous) {
+            src1_ddf[id] = (float *) src1_extra->data_device[id];
+        } else {
+            src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
         }
 
-        if (use_src1 && !src1_stays_on_host) {
-            if (src1_on_device && src1_is_contiguous) {
-                src1_ddf[id] = (float *) src1_extra->data_device[id];
-            } else {
-                src1_ddf[id] = (float *) ggml_cuda_pool_malloc(num_iters*src1_stride * sizeof(float), &src1_asf[id]);
+        if (convert_src1_to_q8_1) {
+            src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
+
+            if (split && src1_on_device && src1_is_contiguous) {
+                quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
+                CUDA_CHECK(cudaGetLastError());
             }
         }
+
         if (dst_on_device) {
-            dst_ddf[id] = (float *) dst_extra->data_device[id];
+            dst_dd[id] = (float *) dst_extra->data_device[id];
         } else {
-            size_t size_dst_ddf = split ? row_diff*ne1 * sizeof(float) : num_iters*dst_stride * sizeof(float);
-            dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
+            const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
+            dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
         }
+    }
 
-        for (int64_t i03 = 0; i03 < i03_max; i03++) {
-            const int64_t i13 = i03 % ne13;
-            for (int64_t i02 = 0; i02 < i02_max; i02++) {
-                const int64_t i12 = i02 % ne12;
+    // if multiple devices are used they need to wait for the main device
+    // here an event is recorded that signals that the main device has finished calculating the input data
+    if (split && g_device_count > 1) {
+        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+        CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
+    }
 
-                const int64_t i0 = i03*i02_max + i02;
+    const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
+    for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
+        const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
+        const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
 
-                // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
-                const int64_t i0_offset_low = row_low/rows_per_iter;
-                const int64_t i0_offset_high = row_high/rows_per_iter;
+        for (int64_t id = 0; id < g_device_count; ++id) {
+            if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
+                continue;
+            }
 
-                int64_t i01_low = 0;
-                int64_t i01_high = rows_per_iter;
-                if (split) {
-                    if (i0 < i0_offset_low || i0 > i0_offset_high) {
-                        continue;
-                    }
-                    if (i0 == i0_offset_low) {
-                        i01_low = row_low % rows_per_iter;
-                    }
-                    if (i0 == i0_offset_high) {
-                        i01_high = row_high % rows_per_iter;
-                    }
-                }
+            const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
+            const bool  dst_on_device =  dst->backend == GGML_BACKEND_GPU && id == g_main_device;
+            const int64_t row_diff = row_high[id] - row_low[id];
 
-                // There is possibly a bug in the Windows nvcc compiler regarding instruction reordering or optimizing out local variables.
-                // Removing the first assert or changing the order of the arguments causes the second assert to fail.
-                // Removing both asserts results in i01_high becoming 0 which in turn results in garbage output.
-                // The root cause seems to be a problem with i0_offset_high becoming 0 when it should always be >0 (for single GPU).
-                GGML_ASSERT(i01_low == 0 || g_device_count > 1);
-                GGML_ASSERT(i01_high == rows_per_iter || g_device_count > 1);
+            ggml_cuda_set_device(id);
+            const cudaStream_t stream = g_cudaStreams[id][is];
 
-                const int64_t i01_diff = i01_high - i01_low;
-                if (i01_diff == 0) {
-                    continue;
-                }
-                const int64_t i11 = i13*ne12 + i12;
+            // wait for main GPU data if necessary
+            if (split && (id != g_main_device || is != 0)) {
+                CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0], 0));
+            }
+
+            for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
+                const int64_t i03 = i0 / ne12;
+                const int64_t i02 = i0 % ne12;
+
+                const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
 
                 // for split tensors the data begins at i0 == i0_offset_low
-                char  * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs;
-                float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride;
-                float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
-                float * dst_ddf_i  =  dst_ddf[id] + (i0             - i0_offset_low)*dst_stride;
-
-                // for split tensors the data pointer needs to be rounded down
-                // to the bin edge for i03, i02 bins beyond the first
-                if (i0 - i0_offset_low > 0) {
-                    GGML_ASSERT(!flatten_rows);
-                    src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs;
-                    src0_ddf_i -= (row_low % ne01)*ne00;
-                    dst_ddf_i  -= (row_low % ne0)*ne1;
-                }
+                char  *  src0_dd_i =  src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs;
+                float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10;
+                char  * src1_ddq_i = src1_ddq[id] +  src1_ddq_i_offset;
+                float *   dst_dd_i =   dst_dd[id] + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
 
                 // the main device memory buffer can be on VRAM scratch, with space for all partial results
                 // in that case an offset on dst_ddf_i is needed
                 if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) {
-                    dst_ddf_i += i01_low; // offset is 0 if no tensor split
+                    dst_dd_i += row_low[id]; // offset is 0 if no tensor split
                 }
 
                 // copy src0, src1 to device if necessary
-                if (use_src1 && !src1_stays_on_host) {
-                    if (src1->backend == GGML_BACKEND_CPU) {
-                        GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
-                        int64_t nrows1 = flatten_rows ? nrows0 : ne11;
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
-                    } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
-                        if (id != g_main_device) {
-                            GGML_ASSERT(!flatten_rows);
+                if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
+                    if (id != g_main_device) {
+                        if (convert_src1_to_q8_1) {
+                            char * src1_ddq_i_source = src1_ddq[g_main_device] + src1_ddq_i_offset;
+                            CUDA_CHECK(cudaMemcpyAsync(src1_ddq_i, src1_ddq_i_source, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs,
+                                                    cudaMemcpyDeviceToDevice, stream));
+                        } else {
                             float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
-                            src1_ddf_i_source += i11*src1_stride;
-                            CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
-                                                    cudaMemcpyDeviceToDevice, cudaStream_main));
+                            src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
+                            CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_ncols*ne10*sizeof(float),
+                                                    cudaMemcpyDeviceToDevice, stream));
                         }
-                    } else if (src1_on_device && !src1_is_contiguous) {
-                        GGML_ASSERT(!split);
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main));
-                    } else {
-                        GGML_ASSERT(false);
                     }
+                } else if (src1->backend == GGML_BACKEND_CPU || (src1_on_device && !src1_is_contiguous)) {
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+                                   src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
+                } else {
+                    GGML_ASSERT(false);
                 }
 
-                if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
-                    if (src0_is_f32) {
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
-                    } else {
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
-                    }
+                if (convert_src1_to_q8_1 && src1->backend == GGML_BACKEND_CPU) {
+                    quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
+                    CUDA_CHECK(cudaGetLastError());
                 }
 
-                // convert src0 to f32 if it is necessary for the ggml_cuda_op
-                if (src0_needs_f32 && !src0_is_f32) {
-                    to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main);
-                    CUDA_CHECK(cudaGetLastError());
+                if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, row_low[id], row_high[id], stream));
                 }
 
                 // do the computation
-                op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
+                op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
+                   row_low[id], row_high[id], src1_ncols, src1_padded_col_size, stream);
                 CUDA_CHECK(cudaGetLastError());
 
                 // copy dst to host or other device if necessary
@@ -6125,95 +6700,86 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                         // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
                         // Instead they need to be copied to the correct slice in ne0 = dst row index.
                         // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
-                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i01_low*sizeof(float) + i02*nb2 + i03*nb3);
-                        CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_ddf_i, i01_diff*sizeof(float),
-                                                     i01_diff*sizeof(float), ne1, kind, cudaStream_main));
+                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+                        dhf_dst_i += src1_col_0*ne0 + row_low[id];
+                        CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_dd_i, row_diff*sizeof(float),
+                                                    row_diff*sizeof(float), src1_ncols, kind, stream));
                     } else {
                         float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
-                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
+                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+                        dhf_dst_i += src1_col_0*ne0;
+                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), kind, stream));
                     }
                 }
 
-                // signify to main device that other device is done
-                if (split && g_device_count > 1 && id != g_main_device) {
-                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
+                // add event for the main device to wait on until other device is done
+                if (split && (id != g_main_device || is != 0)) {
+                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream));
                 }
             }
         }
     }
 
-    // wait until each device is finished, then free their buffers
-    for (int id = 0; id < g_device_count; ++id) {
-        if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
-            continue;
-        }
-
-        CUDA_CHECK(cudaSetDevice(id));
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        CUDA_CHECK(ggml_cuda_set_device(id));
 
-        if (src0_asq[id] > 0) {
-            ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
-        }
-        if (src0_asf[id] > 0) {
-            ggml_cuda_pool_free(src0_ddf[id], src0_asf[id]);
+        // free buffers again when done
+        if (src0_as[id] > 0) {
+            ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
         }
         if (src1_asf[id] > 0) {
             ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
         }
-        if (dst_asf[id] > 0) {
-            ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
+        if (src1_asq[id] > 0) {
+            ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
+        }
+        if (dst_as[id] > 0) {
+            ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
         }
     }
 
     // main device waits for all other devices to be finished
     if (split && g_device_count > 1) {
-        CUDA_CHECK(cudaSetDevice(g_main_device));
-        for (int id = 0; id < g_device_count; ++id) {
-            if (id != g_main_device && src0_extra->events[id]) {
-                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
+        int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
+        is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS;
+
+        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+        for (int64_t id = 0; id < g_device_count; ++id) {
+            for (int64_t is = 0; is < is_max; ++is) {
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
             }
         }
     }
 
     if (dst->backend == GGML_BACKEND_CPU) {
-        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaDeviceSynchronize());
     }
 }
 
-void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op.
-    // Due to flatten_rows == true this does in practice not make a difference however.
-    // Better solution would be nice but right now that would require disproportionate changes.
-    GGML_ASSERT(
-        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
-        src1->type == GGML_TYPE_F32 &&
-        (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16));
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true);
+static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
 }
 
-void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten
+static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }
 
-void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_gelu, true, true);
+static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
 }
 
-void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
+static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }
 
-void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_norm, true, true);
+static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }
 
-void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
+static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
@@ -6223,17 +6789,13 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     const int64_t ne1 = dst->ne[1];
 
     // TODO: find the optimal values for these
-    if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
-        src1->type == GGML_TYPE_F32 &&
-        dst->type == GGML_TYPE_F32 &&
-        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-        return true;
-    }
-
-    return false;
+    return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+            src1->type == GGML_TYPE_F32 &&
+             dst->type == GGML_TYPE_F32 &&
+            (ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
 }
 
-void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
@@ -6247,8 +6809,8 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
 
     const int64_t ne12 = src1->ne[2];
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -6259,10 +6821,10 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
-    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main);
+    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
 }
 
-void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
@@ -6278,8 +6840,8 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -6290,38 +6852,49 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
-    const int row_stride_x = nb01 / sizeof(half);
-    const int channel_stride_x = nb02 / sizeof(half);
+    const int64_t row_stride_x = nb01 / sizeof(half);
+    const int64_t channel_stride_x = nb02 / sizeof(half);
 
-    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main);
+    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
-void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
 
+    int64_t min_compute_capability = INT_MAX;
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        if (min_compute_capability > g_compute_capabilities[id]
+                && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
+            min_compute_capability = g_compute_capabilities[id];
+        }
+    }
+
     if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
     } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
     }else if (src0->type == GGML_TYPE_F32) {
-        ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
+        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
         if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
-            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
-        } else {
-            int min_compute_capability = INT_MAX;
-            for (int id = 0; id < g_device_count; ++id) {
-                if (min_compute_capability > g_compute_capabilities[id]
-                        && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
-                    min_compute_capability = g_compute_capabilities[id];
-                }
-            }
 
+#ifdef GGML_CUDA_FORCE_DMMV
+            const bool use_mul_mat_vec_q = false;
+#else
+            const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+#endif // GGML_CUDA_FORCE_DMMV
+
+            if (use_mul_mat_vec_q) {
+                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
+            } else {
+                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
+            }
+        } else {
             if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
-                ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
+                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
             } else {
-                ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
+                ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
             }
         }
     } else {
@@ -6329,12 +6902,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
     }
 }
 
-void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_scale, true, true);
+static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }
 
-void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
 
@@ -6360,8 +6932,8 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     const int64_t nb11 = src1->nb[1];
     const int64_t nb12 = src1->nb[2];
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
@@ -6371,52 +6943,49 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
 
     if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, cudaStream_main);
+                              ne10, ne11, nb10, nb11, nb12, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, cudaStream_main);
+                              ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
+        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+                ggml_type_name(src0->type), ggml_type_name(src1->type));
         GGML_ASSERT(false);
     }
 
     (void) dst;
 }
 
-void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_cpy(src0, dst, nullptr);
     (void) src1;
 }
 
-void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true);
+static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf);
 }
 
-void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, true, true);
+static void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max);
 }
 
-void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+static void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
-
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, true);
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope);
 }
 
-void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_alibi, true, true);
+static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }
 
-void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
     (void) dst;
 }
 
 void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
-    int nrows = ggml_nrows(tensor);
+    const int64_t nrows = ggml_nrows(tensor);
 
     const int64_t ne0 = tensor->ne[0];
 
@@ -6426,14 +6995,14 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
     struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
     memset(extra, 0, sizeof(*extra));
 
-    for (int id = 0; id < g_device_count; ++id) {
+    for (int64_t id = 0; id < g_device_count; ++id) {
         if (backend == GGML_BACKEND_GPU && id != g_main_device) {
             continue;
         }
 
-        cudaSetDevice(id);
+        ggml_cuda_set_device(id);
 
-        int row_low, row_high;
+        int64_t row_low, row_high;
         if (backend == GGML_BACKEND_GPU) {
             row_low = 0;
             row_high = nrows;
@@ -6483,7 +7052,9 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         extra->data_device[id] = buf;
 
         if (backend == GGML_BACKEND_GPU_SPLIT) {
-            CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
+            for (int64_t is = 0; is < MAX_STREAMS; ++is) {
+                CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming));
+            }
         }
     }
 
@@ -6497,15 +7068,17 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
 
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
-    for (int id = 0; id < g_device_count; ++id) {
+    for (int64_t id = 0; id < g_device_count; ++id) {
         if (extra->data_device[id] != nullptr) {
-            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(ggml_cuda_set_device(id));
             CUDA_CHECK(cudaFree(extra->data_device[id]));
         }
 
-        if (extra->events[id] != nullptr) {
-            CUDA_CHECK(cudaSetDevice(id));
-            CUDA_CHECK(cudaEventDestroy(extra->events[id]));
+        for (int64_t is = 0; is < MAX_STREAMS; ++is) {
+            if (extra->events[id][is] != nullptr) {
+                CUDA_CHECK(ggml_cuda_set_device(id));
+                CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
+            }
         }
     }
 
@@ -6528,11 +7101,13 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     return extra;
 }
 
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
+static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
 
+    tensor->backend = GGML_BACKEND_GPU;
+
     // recursively assign CUDA buffers until a compute tensor is found
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
@@ -6544,8 +7119,6 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
     }
 
-    tensor->backend = GGML_BACKEND_GPU;
-
     if (scratch && no_alloc) {
         return;
     }
@@ -6557,7 +7130,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         force_inplace;
     const size_t size = ggml_nbytes(tensor);
 
-    CUDA_CHECK(cudaSetDevice(g_main_device));
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
         struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
@@ -6606,6 +7179,7 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
         return;
     }
     if (g_scratch_buffer == nullptr) {
+        ggml_cuda_set_device(g_main_device);
         CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
     }
 
@@ -6629,6 +7203,15 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
     tensor->extra = extra;
 }
 
+void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(ggml_is_contiguous(tensor));
+
+    struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice));
+}
+
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
     ggml_cuda_assign_buffers_impl(tensor, true, false, false);
 }
@@ -6645,7 +7228,7 @@ void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
     ggml_cuda_assign_buffers_impl(tensor, false, true, false);
 }
 
-void ggml_cuda_set_main_device(int main_device) {
+void ggml_cuda_set_main_device(const int main_device) {
     if (main_device >= g_device_count) {
         fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
                 main_device, g_device_count, g_main_device);
@@ -6659,12 +7242,17 @@ void ggml_cuda_set_main_device(int main_device) {
     }
 }
 
-void ggml_cuda_set_mul_mat_q(bool mul_mat_q) {
+void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
     g_mul_mat_q = mul_mat_q;
 }
 
-void ggml_cuda_set_scratch_size(size_t scratch_size) {
-    g_scratch_size = scratch_size;
+void ggml_cuda_set_scratch_size(const size_t scratch_size) {
+    // this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously
+    // it still won't always work as expected, but it's better than nothing
+    if (scratch_size > g_scratch_size) {
+        ggml_cuda_free_scratch();
+    }
+    g_scratch_size = std::max(g_scratch_size, scratch_size);
 }
 
 void ggml_cuda_free_scratch() {
index a72e82069b9f1430fdbcc007f93ae26ea6c2f924..fda704b6656234ae63e1399fc680c5e5cf6c4a0d 100644 (file)
@@ -31,6 +31,7 @@ GGML_API void   ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tens
 
 GGML_API void   ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
 GGML_API void   ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+GGML_API void   ggml_cuda_copy_to_device(struct ggml_tensor * tensor);
 
 GGML_API void   ggml_cuda_set_main_device(int main_device);
 GGML_API void   ggml_cuda_set_mul_mat_q(bool mul_mat_q);
index fca28d37ef97069ca7800fc890ea4d013c46437f..790cf0bf7b963d9c1f52a45dbac5be71e0662256 100644 (file)
@@ -19,6 +19,8 @@
 
 #pragma once
 
+#include "ggml.h"
+
 #include <stddef.h>
 #include <stdbool.h>
 
@@ -33,6 +35,8 @@ struct ggml_cgraph;
 extern "C" {
 #endif
 
+void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data);
+
 struct ggml_metal_context;
 
 // number of command buffers to use
index 1139ee3114610e05726b7e4319843eade838271c..866fed43442a8c487bec4e117dd955e01ee92e1d 100644 (file)
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
-// TODO: temporary - reuse llama.cpp logging
 #ifdef GGML_METAL_NDEBUG
-#define metal_printf(...)
+#define GGML_METAL_LOG_INFO(...)
+#define GGML_METAL_LOG_WARN(...)
+#define GGML_METAL_LOG_ERROR(...)
 #else
-#define metal_printf(...) fprintf(stderr, __VA_ARGS__)
+#define GGML_METAL_LOG_INFO(...)  ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+#define GGML_METAL_LOG_WARN(...)  ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+#define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
 #endif
 
 #define UNUSED(x) (void)(x)
@@ -100,7 +103,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
-    GGML_METAL_DECL_KERNEL(rope);
+    GGML_METAL_DECL_KERNEL(rope_f32);
+    GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -120,8 +124,37 @@ static NSString * const msl_library_source = @"see metal.metal";
 @implementation GGMLMetalClass
 @end
 
+ggml_log_callback ggml_metal_log_callback = NULL;
+void * ggml_metal_log_user_data = NULL;
+
+void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
+    ggml_metal_log_callback  = log_callback;
+    ggml_metal_log_user_data = user_data;
+}
+
+static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
+    if (ggml_metal_log_callback != NULL) {
+        va_list args;
+        va_start(args, format);
+        char buffer[128];
+        int len = vsnprintf(buffer, 128, format, args);
+        if (len < 128) {
+            ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
+        } else {
+            char* buffer2 = malloc(len+1);
+            vsnprintf(buffer2, len+1, format, args);
+            buffer2[len] = 0;
+            ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
+            free(buffer2);
+        }
+        va_end(args);
+    }
+}
+
+
+
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
-    metal_printf("%s: allocating\n", __func__);
+    GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
     id <MTLDevice> device;
     NSString * s;
@@ -131,14 +164,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     NSArray * devices = MTLCopyAllDevices();
     for (device in devices) {
         s = [device name];
-        metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
+        GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
     }
 #endif
 
     // Pick and show default Metal device
     device = MTLCreateSystemDefaultDevice();
     s = [device name];
-    metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]);
+    GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
 
     // Configure context
     struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
@@ -165,7 +198,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
 
         if (error) {
-            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
+            GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
@@ -179,11 +212,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
         NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
         NSString * path   = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
-        metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
+        GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);
 
         NSString * src  = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
         if (error) {
-            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
+            GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
 
@@ -195,7 +228,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
 #endif
         if (error) {
-            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
+            GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
@@ -207,11 +240,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #define GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
         ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-        metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+        GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
                 (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
                 (int) ctx->pipeline_##name.threadExecutionWidth); \
         if (error) { \
-            metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+          GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
             return NULL; \
         }
 
@@ -261,7 +294,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
-        GGML_METAL_ADD_KERNEL(rope);
+        GGML_METAL_ADD_KERNEL(rope_f32);
+        GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -270,13 +304,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    metal_printf("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
 #if TARGET_OS_OSX
-    metal_printf("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
     if (ctx->device.maxTransferRate != 0) {
-        metal_printf("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+        GGML_METAL_LOG_INFO("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
-        metal_printf("%s: maxTransferRate               = built-in GPU\n", __func__);
+        GGML_METAL_LOG_INFO("%s: maxTransferRate               = built-in GPU\n", __func__);
     }
 #endif
 
@@ -284,7 +318,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 }
 
 void ggml_metal_free(struct ggml_metal_context * ctx) {
-    metal_printf("%s: deallocating\n", __func__);
+    GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
 #define GGML_METAL_DEL_KERNEL(name) \
     [ctx->function_##name release]; \
     [ctx->pipeline_##name release];
@@ -335,7 +369,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
-    GGML_METAL_DEL_KERNEL(rope);
+    GGML_METAL_DEL_KERNEL(rope_f32);
+    GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
@@ -360,7 +395,7 @@ void * ggml_metal_host_malloc(size_t n) {
     void * data = NULL;
     const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
     if (result != 0) {
-        metal_printf("%s: error: posix_memalign failed\n", __func__);
+        GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
         return NULL;
     }
 
@@ -388,7 +423,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
 // Metal buffer based on the host memory pointer
 //
 static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
-    //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+    //GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
 
     const int64_t tsize = ggml_nbytes(t);
 
@@ -400,13 +435,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
         if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
 
-            //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+            //GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
 
             return ctx->buffers[i].metal;
         }
     }
 
-    metal_printf("%s: error: buffer is nil\n", __func__);
+    GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
 
     return nil;
 }
@@ -418,7 +453,7 @@ bool ggml_metal_add_buffer(
                          size_t   size,
                          size_t   max_size) {
     if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
-        metal_printf("%s: too many buffers\n", __func__);
+        GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
         return false;
     }
 
@@ -428,7 +463,7 @@ bool ggml_metal_add_buffer(
             const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
 
             if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
-                metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
+                GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
                 return false;
             }
         }
@@ -449,11 +484,11 @@ bool ggml_metal_add_buffer(
             ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
             if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+                GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
                 return false;
             }
 
-            metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+            GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
 
             ++ctx->n_buffers;
         } else {
@@ -473,13 +508,13 @@ bool ggml_metal_add_buffer(
                 ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
                 if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                    metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+                    GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
                     return false;
                 }
 
-                metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+                GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
                 if (i + size_step < size) {
-                    metal_printf("\n");
+                    GGML_METAL_LOG_INFO("\n");
                 }
 
                 ++ctx->n_buffers;
@@ -487,17 +522,17 @@ bool ggml_metal_add_buffer(
         }
 
 #if TARGET_OS_OSX
-        metal_printf(", (%8.2f / %8.2f)",
+        GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
                 ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
                 ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
         if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
-            metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
+            GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
         } else {
-            metal_printf("\n");
+            GGML_METAL_LOG_INFO("\n");
         }
 #else
-        metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
+        GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
 #endif
     }
 
@@ -610,7 +645,7 @@ void ggml_metal_graph_find_concurrency(
     }
 
     if (ctx->concur_list_len > GGML_MAX_CONCUR) {
-        metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
+        GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
     }
 }
 
@@ -664,7 +699,7 @@ void ggml_metal_graph_compute(
                     continue;
                 }
 
-                //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+                //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
                 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
                 struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@@ -708,17 +743,17 @@ void ggml_metal_graph_compute(
                 id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
                 id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(ctx, dst,  &offs_dst)  : nil;
 
-                //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+                //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
                 //if (src0) {
-                //    metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+                //    GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
                 //            ggml_is_contiguous(src0), src0->name);
                 //}
                 //if (src1) {
-                //    metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+                //    GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
                 //            ggml_is_contiguous(src1), src1->name);
                 //}
                 //if (dst) {
-                //    metal_printf("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2,
+                //    GGML_METAL_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2,
                 //            dst->name);
                 //}
 
@@ -736,25 +771,59 @@ void ggml_metal_graph_compute(
                             GGML_ASSERT(ggml_is_contiguous(src0));
                             GGML_ASSERT(ggml_is_contiguous(src1));
 
-                            // utilize float4
-                            GGML_ASSERT(ne00 % 4 == 0);
-                            const int64_t nb = ne00/4;
+                            bool bcast_row = false;
 
-                            if (ggml_nelements(src1) == ne10) {
+                            int64_t nb = ne00;
+
+                            if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
                                 // src1 is a row
                                 GGML_ASSERT(ne11 == 1);
+
+                                nb = ne00 / 4;
                                 [encoder setComputePipelineState:ctx->pipeline_add_row];
+
+                                bcast_row = true;
                             } else {
                                 [encoder setComputePipelineState:ctx->pipeline_add];
                             }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&nb     length:sizeof(nb) atIndex:3];
-
-                            const int64_t n = ggml_nelements(dst)/4;
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
+                            [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
+                            [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
+                            [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
+                            [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
+                            [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
+                            [encoder setBytes:&nb   length:sizeof(nb)   atIndex:27];
+
+                            if (bcast_row) {
+                                const int64_t n = ggml_nelements(dst)/4;
+
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            } else {
+                                const int nth = MIN(1024, ne0);
 
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                            }
                         } break;
                     case GGML_OP_MUL:
                         {
@@ -830,13 +899,13 @@ void ggml_metal_graph_compute(
                                 } break;
                             default:
                                 {
-                                    metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                                    GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                                     GGML_ASSERT(false);
                                 }
                         } break;
                     case GGML_OP_SOFT_MAX:
                         {
-                            const int nth = 32;
+                            const int nth = MIN(32, ne00);
 
                             if (ne00%4 == 0) {
                                 [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
@@ -889,7 +958,7 @@ void ggml_metal_graph_compute(
                                 src1t == GGML_TYPE_F32 &&
                                 [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                                 ne00%32 == 0 &&
-                                ne11 > 1) {
+                                ne11 > 2) {
                                 switch (src0->type) {
                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break;
                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
@@ -1019,7 +1088,7 @@ void ggml_metal_graph_compute(
                                         } break;
                                     default:
                                         {
-                                            metal_printf("Asserting on type %d\n",(int)src0t);
+                                            GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
                                             GGML_ASSERT(false && "not implemented");
                                         }
                                 };
@@ -1100,7 +1169,7 @@ void ggml_metal_graph_compute(
                             float eps;
                             memcpy(&eps, dst->op_params, sizeof(float));
 
-                            const int nth = 512;
+                            const int nth = MIN(512, ne00);
 
                             [encoder setComputePipelineState:ctx->pipeline_rms_norm];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1119,7 +1188,7 @@ void ggml_metal_graph_compute(
                             float eps;
                             memcpy(&eps, dst->op_params, sizeof(float));
 
-                            const int nth = 256;
+                            const int nth = MIN(256, ne00);
 
                             [encoder setComputePipelineState:ctx->pipeline_norm];
                             [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
@@ -1137,17 +1206,16 @@ void ggml_metal_graph_compute(
                         {
                             GGML_ASSERT((src0t == GGML_TYPE_F32));
 
+                            const int nth = MIN(1024, ne00);
+
                             const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
                             const int n_head = ((int32_t *) dst->op_params)[1];
                             float max_bias;
                             memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
-                            if (__builtin_popcount(n_head) != 1) {
-                                GGML_ASSERT(false && "only power-of-two n_head implemented");
-                            }
-
                             const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
                             const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+                            const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
                             [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1168,14 +1236,18 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
                             [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
                             [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
-                            [encoder setBytes:&m0  length:sizeof(    float) atIndex:18];
-
-                            const int nth = 32;
+                            [encoder setBytes:&m0   length:sizeof(   float) atIndex:18];
+                            [encoder setBytes:&m1   length:sizeof(   float) atIndex:19];
+                            [encoder setBytes:&n_heads_log2_floor   length:sizeof(int) atIndex:20];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_ROPE:
                         {
+                            GGML_ASSERT(ne10 == ne02);
+
+                            const int nth = MIN(1024, ne00);
+
                             const int n_past = ((int32_t *) dst->op_params)[0];
                             const int n_dims = ((int32_t *) dst->op_params)[1];
                             const int mode   = ((int32_t *) dst->op_params)[2];
@@ -1185,38 +1257,44 @@ void ggml_metal_graph_compute(
                             memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
                             memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
-                            [encoder setComputePipelineState:ctx->pipeline_rope];
+                            switch (src0->type) {
+                                case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
+                                case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
+                                default: GGML_ASSERT(false);
+                            };
+
                             [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
-                            [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
-                            [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
-                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
-                            [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
-                            [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
-                            [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
-                            [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
-                            [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
-                            [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
-                            [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
-                            [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
-                            [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
-                            [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
-                            [encoder setBytes:&n_past  length:sizeof(     int) atIndex:18];
-                            [encoder setBytes:&n_dims  length:sizeof(     int) atIndex:19];
-                            [encoder setBytes:&mode    length:sizeof(     int) atIndex:20];
-                            [encoder setBytes:&freq_base  length:sizeof(float) atIndex:21];
-                            [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
+                            [encoder setBuffer:id_src1 offset:offs_src1        atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:2];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:6];
+                            [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:10];
+                            [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:14];
+                            [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:18];
+                            [encoder setBytes:&n_past  length:sizeof(     int) atIndex:19];
+                            [encoder setBytes:&n_dims  length:sizeof(     int) atIndex:20];
+                            [encoder setBytes:&mode    length:sizeof(     int) atIndex:21];
+                            [encoder setBytes:&freq_base  length:sizeof(float) atIndex:22];
+                            [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
 
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_DUP:
                     case GGML_OP_CPY:
                     case GGML_OP_CONT:
                         {
-                            const int nth = 32;
+                            const int nth = MIN(1024, ne00);
 
                             switch (src0t) {
                                 case GGML_TYPE_F32:
@@ -1261,7 +1339,7 @@ void ggml_metal_graph_compute(
                         } break;
                     default:
                         {
-                            metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                            GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                             GGML_ASSERT(false);
                         }
                 }
@@ -1286,7 +1364,7 @@ void ggml_metal_graph_compute(
 
         MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
         if (status != MTLCommandBufferStatusCompleted) {
-            metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+            GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
             GGML_ASSERT(false);
         }
     }
index 3087ecda812d90d510ffc6859ff3d4dc52f12e7e..5a860098f157c49c131101f9d2b36d96d60d6e60 100644 (file)
@@ -24,12 +24,59 @@ typedef struct {
     int8_t  qs[QK8_0]; // quants
 } block_q8_0;
 
+// general-purpose kernel for addition of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
+// cons: not very efficient
 kernel void kernel_add(
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] + src1[tpig];
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant  int64_t & nb00,
+        constant  int64_t & nb01,
+        constant  int64_t & nb02,
+        constant  int64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant  int64_t & nb10,
+        constant  int64_t & nb11,
+        constant  int64_t & nb12,
+        constant  int64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant  int64_t & nb0,
+        constant  int64_t & nb1,
+        constant  int64_t & nb2,
+        constant  int64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
+
+        src0_ptr += ntg.x*nb00;
+        src1_ptr += ntg.x*nb10;
+        dst_ptr  += ntg.x*nb0;
+    }
 }
 
 // assumption: src1 is a row
@@ -38,7 +85,7 @@ kernel void kernel_add_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb,
+        constant    int64_t & nb [[buffer(27)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
@@ -118,7 +165,7 @@ kernel void kernel_soft_max(
     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    float lmax = psrc0[tpitg[0]];
+    float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
     for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
         lmax = MAX(lmax, psrc0[i00]);
     }
@@ -158,7 +205,7 @@ kernel void kernel_soft_max_4(
     device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
-    float4 lmax4 = psrc4[tpitg[0]];
+    float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
     for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
         lmax4 = fmax(lmax4, psrc4[i00]);
     }
@@ -783,7 +830,9 @@ kernel void kernel_alibi_f32(
         constant  uint64_t & nb1,
         constant  uint64_t & nb2,
         constant  uint64_t & nb3,
-        constant      float & m0,
+        constant     float & m0,
+        constant     float & m1,
+        constant       int & n_heads_log2_floor,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -799,37 +848,73 @@ kernel void kernel_alibi_f32(
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
 
     device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-    float m_k = pow(m0, i2 + 1);
+    float m_k;
+    if (i2 < n_heads_log2_floor) {
+        m_k = pow(m0, i2 + 1);
+    } else {
+        m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
+    }
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
         dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
     }
 }
 
+typedef void (rope_t)(
+        device const    void * src0,
+        device const int32_t * src1,
+        device         float * dst,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant     int64_t & ne03,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant    uint64_t & nb03,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & ne2,
+        constant     int64_t & ne3,
+        constant    uint64_t & nb0,
+        constant    uint64_t & nb1,
+        constant    uint64_t & nb2,
+        constant    uint64_t & nb3,
+        constant         int & n_past,
+        constant         int & n_dims,
+        constant         int & mode,
+        constant       float & freq_base,
+        constant       float & freq_scale,
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint3 tptg[[threads_per_threadgroup]],
+        uint3 tgpig[[threadgroup_position_in_grid]]);
+
+template<typename T>
 kernel void kernel_rope(
-        device const  void * src0,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        constant       int & n_past,
-        constant       int & n_dims,
-        constant       int & mode,
-        constant     float & freq_base,
-        constant     float & freq_scale,
+        device const    void * src0,
+        device const int32_t * src1,
+        device         float * dst,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant     int64_t & ne03,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant    uint64_t & nb03,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & ne2,
+        constant     int64_t & ne3,
+        constant    uint64_t & nb0,
+        constant    uint64_t & nb1,
+        constant    uint64_t & nb2,
+        constant    uint64_t & nb3,
+        constant         int & n_past,
+        constant         int & n_dims,
+        constant         int & mode,
+        constant       float & freq_base,
+        constant       float & freq_scale,
         uint  tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
         uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -839,7 +924,9 @@ kernel void kernel_rope(
 
     const bool is_neox = mode & 2;
 
-    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+    device const int32_t * pos = src1;
+
+    const int64_t p = pos[i2];
 
     const float theta_0 = freq_scale * (float)p;
     const float inv_ndims = -1.f/n_dims;
@@ -851,11 +938,11 @@ kernel void kernel_rope(
             const float cos_theta = cos(theta);
             const float sin_theta = sin(theta);
 
-            device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-            const float x0 = src[0];
-            const float x1 = src[1];
+            const T x0 = src[0];
+            const T x1 = src[1];
 
             dst_data[0] = x0*cos_theta - x1*sin_theta;
             dst_data[1] = x0*sin_theta + x1*cos_theta;
@@ -870,8 +957,8 @@ kernel void kernel_rope(
 
                 const int64_t i0 = ib*n_dims + ic/2;
 
-                device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                 const float x0 = src[0];
                 const float x1 = src[n_dims/2];
@@ -883,6 +970,9 @@ kernel void kernel_rope(
     }
 }
 
+template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
+template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device       half * dst,
@@ -1273,8 +1363,8 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     float yl[32];
 
-    const uint16_t kmask1 = 0x3030;
-    const uint16_t kmask2 = 0x0f0f;
+    //const uint16_t kmask1 = 0x3030;
+    //const uint16_t kmask2 = 0x0f0f;
 
     const int tid = tiisg/4;
     const int ix  = tiisg%4;
index 777048d01115779f666457d70aeed3a833bc44b8..7e4069d76b2595737edfeb6c3415a4d8ef8949fe 100644 (file)
@@ -847,7 +847,7 @@ std::array<std::string, 2> mul_str_values = {
     "mul_f32", "float"
 };
 
-std::string& replace(std::string& s, const std::string& from, const std::string& to) {
+static std::string& replace(std::string& s, const std::string& from, const std::string& to) {
     size_t pos = 0;
     while ((pos = s.find(from, pos)) != std::string::npos) {
          s.replace(pos, from.length(), to);
@@ -856,7 +856,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
     return s;
 }
 
-std::string generate_kernels() {
+static std::string generate_kernels() {
     std::stringstream src;
     src << program_source << '\n';
     src << k_quants_source << '\n';
@@ -1476,10 +1476,15 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
 
     const int nb2  = dst->nb[2];
     const int nb3  = dst->nb[3];
 
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
     const float alpha = 1.0f;
     const float beta = 0.0f;
     const int x_ne = ne01 * ne00;
@@ -1498,13 +1503,22 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
     cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
     cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
 
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
+    int64_t pi02 = -1;
+    int64_t pi03 = -1;
+
+    for (int64_t i13 = 0; i13 < ne13; i13++) {
+        int64_t i03 = i13 / r3;
+
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            int64_t i02 = i12 / r2;
+
             // copy data to device
-            if (src0->backend != GGML_BACKEND_GPU) {
+            if (src0->backend != GGML_BACKEND_GPU && (i02 != pi02 || i03 != pi03)) {
                 CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
+                pi02 = i02;
+                pi03 = i03;
             }
-            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));
+            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
 
             CL_CHECK(clFinish(queue));
 
@@ -1525,7 +1539,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
             }
 
             // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
             CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL));
         }
     }
@@ -1547,6 +1561,8 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
 
     const int nb10 = src1->nb[0];
     const int nb11 = src1->nb[1];
@@ -1556,6 +1572,9 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     const int nb2  = dst->nb[2];
     const int nb3  = dst->nb[3];
 
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
     const ggml_fp16_t alpha = ggml_fp32_to_fp16(1.0f);
     const ggml_fp16_t beta = ggml_fp32_to_fp16(0.0f);
     const int x_ne = ne01 * ne00;
@@ -1577,32 +1596,41 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     bool src1_cont_rows = nb10 == sizeof(float);
     bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
 
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
+    int64_t pi02 = -1;
+    int64_t pi03 = -1;
+
+    for (int64_t i13 = 0; i13 < ne13; i13++) {
+        int64_t i03 = i13 / r3;
+
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            int64_t i02 = i12 / r2;
+
             // copy src0 to device
-            if (src0->backend != GGML_BACKEND_GPU) {
+            if (src0->backend != GGML_BACKEND_GPU && (i02 != pi02 || i03 != pi03)) {
                 CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
+                pi02 = i02;
+                pi03 = i03;
             }
 
             // convert src1 to fp16
             // TODO: use multiple threads
-            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
-            char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
+            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i13 * ne12 + i12);
+            char * src1i = (char *) src1->data + i13*nb13 + i12*nb12;
             if (src1_cont_rows) {
                 if (src1_cont_cols) {
                     ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
                 }
                 else {
-                    for (int64_t i01 = 0; i01 < ne11; i01++) {
-                        ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
+                    for (int64_t i11 = 0; i11 < ne11; i11++) {
+                        ggml_fp32_to_fp16_row((float *) (src1i + i11*nb11), tmp + i11*ne10, ne10);
                     }
                 }
             }
             else {
-                for (int64_t i01 = 0; i01 < ne11; i01++) {
-                    for (int64_t i00 = 0; i00 < ne10; i00++) {
+                for (int64_t i11 = 0; i11 < ne11; i11++) {
+                    for (int64_t i10 = 0; i10 < ne10; i10++) {
                         // very slow due to no inlining
-                        tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
+                        tmp[i11*ne10 + i10] = ggml_fp32_to_fp16(*(float *) (src1i + i11*nb11 + i10*nb10));
                     }
                 }
             }
@@ -1631,7 +1659,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
             // copy dst to host, then convert to float
             CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL));
 
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
 
             ggml_fp16_to_fp32_row(tmp, d, d_ne);
         }
@@ -1652,12 +1680,17 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
 
     const int nb2  = dst->nb[2];
     const int nb3  = dst->nb[3];
     const ggml_type type = src0->type;
     const bool mul_mat_vec = ne11 == 1;
 
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
     const float alpha = 1.0f;
     const float beta = 0.0f;
     const int x_ne = ne01 * ne00;
@@ -1690,12 +1723,23 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
     size_t ev_idx = 0;
     std::vector<cl_event> events;
 
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
+    int64_t pi02 = -1;
+    int64_t pi03 = -1;
+
+    for (int64_t i13 = 0; i13 < ne13; i13++) {
+        int64_t i03 = i13 / r3;
+
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            int64_t i02 = i12 / r2;
+
             // copy src0 to device if necessary
             if (src0->backend == GGML_BACKEND_CPU) {
-                events.emplace_back();
-                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
+                if (i02 != pi02 || i03 != pi03) {
+                    events.emplace_back();
+                    CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
+                    pi02 = i02;
+                    pi03 = i03;
+                }
             } else if (src0->backend == GGML_BACKEND_GPU) {
                 d_Q = (cl_mem) src0->extra;
             } else {
@@ -1704,7 +1748,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
             if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
                 // copy src1 to device
                 events.emplace_back();
-                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, events.data() + ev_idx++));
+                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, events.data() + ev_idx++));
 
                 // compute
                 const size_t global = ne01 * CL_DMMV_BLOCK_SIZE;
@@ -1725,7 +1769,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
                 CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
 
                 // copy src1 to device
-                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));
+                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
 
                 events.emplace_back();
 
@@ -1749,7 +1793,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
             }
 
             // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
             CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &events[events.size() - 1], NULL));
             for (auto *event : events) {
                 clReleaseEvent(event);
@@ -1788,7 +1832,7 @@ bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
     return false;
 }
 
-bool ggml_cl_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
+static bool ggml_cl_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
     // If device doesn't support FP16
     if (!fp16_support) {
         return false;
index 2a193849949c8164c814ae74fb54971ce0fde4c5..b7206908783481b548d96d138d557cb1a03b1551 100644 (file)
@@ -89,7 +89,9 @@ static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(vo
 
 static int pthread_join(pthread_t thread, void * unused) {
     (void) unused;
-    return (int) WaitForSingleObject(thread, INFINITE);
+    int ret = (int) WaitForSingleObject(thread, INFINITE);
+    CloseHandle(thread);
+    return ret;
 }
 
 static int sched_yield (void) {
@@ -134,6 +136,7 @@ typedef void * thread_ret_t;
 
 #define GGML_SOFT_MAX_UNROLL 4
 #define GGML_VEC_DOT_UNROLL  2
+#define GGML_VEC_MAD_UNROLL  32
 
 //
 // logging
@@ -242,18 +245,18 @@ inline static void * ggml_aligned_malloc(size_t size) {
 //
 
 #define GGML_TENSOR_UNARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb); \
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb);
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 
 #define GGML_TENSOR_BINARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb); \
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb); \
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne); \
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb);
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 
 #if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
@@ -283,7 +286,7 @@ typedef double ggml_float;
 // 16-bit float
 // on Arm, we use __fp16
 // on x86, we use uint16_t
-#ifdef __ARM_NEON
+#if defined(__ARM_NEON) && !defined(_MSC_VER)
 
 // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
 //
@@ -1029,8 +1032,8 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
             y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
 
             // get the 5-th bit and store it in qh at the right position
-            qh |= ((xi0 & 0x10) >> 4) << (j + 0);
-            qh |= ((xi1 & 0x10) >> 4) << (j + qk/2);
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
         }
 
         memcpy(&y[i].qh, &qh, sizeof(qh));
@@ -1077,8 +1080,8 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r
             y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
 
             // get the 5-th bit and store it in qh at the right position
-            qh |= ((xi0 & 0x10) >> 4) << (j + 0);
-            qh |= ((xi1 & 0x10) >> 4) << (j + qk/2);
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
         }
 
         memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
@@ -1269,6 +1272,33 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
 #endif
     }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_0);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+    }
 #else
     // scalar
     quantize_row_q8_0_reference(x, y, k);
@@ -1487,6 +1517,41 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
         _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
 #endif
     }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_1);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d  = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+
+        // compute sum for y[i].s
+        vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
+        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
+
+        // set y[i].s
+        int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
+        y[i].s = sum*d;
+    }
 #else
     // scalar
     quantize_row_q8_1_reference(x, y, k);
@@ -1863,7 +1928,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
     #define GGML_F16x8_ADD          vaddq_f16
     #define GGML_F16x8_MUL          vmulq_f16
     #define GGML_F16x8_REDUCE(res, x)                             \
-    {                                                             \
+    do {                                                          \
         int offset = GGML_F16_ARR >> 1;                           \
         for (int i = 0; i < offset; ++i) {                        \
             x[i] = vaddq_f16(x[i], x[offset+i]);                  \
@@ -1879,7 +1944,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
         const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
         const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
         res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \
-    }
+    } while (0)
 
     #define GGML_F16_VEC                GGML_F16x8
     #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
@@ -1940,7 +2005,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
 #define GGML_F32x8_ADD     _mm256_add_ps
 #define GGML_F32x8_MUL     _mm256_mul_ps
 #define GGML_F32x8_REDUCE(res, x)                                 \
-{                                                                 \
+do {                                                              \
     int offset = GGML_F32_ARR >> 1;                               \
     for (int i = 0; i < offset; ++i) {                            \
         x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
@@ -1957,7 +2022,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
                                  _mm256_extractf128_ps(x[0], 1)); \
     const __m128 t1 = _mm_hadd_ps(t0, t0);                        \
     res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));                     \
-}
+} while (0)
 // TODO: is this optimal ?
 
 #define GGML_F32_VEC        GGML_F32x8
@@ -2659,30 +2724,32 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
     for (int i = 0; i < nb; i++) {
-        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
 
-        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
-        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
 
-        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
-        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
 
-        vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
-        vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
 
-        vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
-        vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);
+        // subtract offset
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
 
-        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
-        vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
 
         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
 
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
 
-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
-        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
         sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
     }
@@ -2820,27 +2887,28 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
     size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
     for (int i = 0; i < nb; i++) {
-        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
 
-        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
-        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
 
-        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
-        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
 
-        vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
-        vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
 
-        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
-        vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
 
         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
 
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
 
-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
-        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
         sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
     }
@@ -3085,66 +3153,61 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
 
     uint32_t qh;
 
-    // These temp values are for masking and shift operations
-    uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
-    uint32_t temp_2[16] = {0x1,   0x2,   0x4,   0x8,   0x10,   0x20,   0x40,   0x80,
-                         0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000};
-
     size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
+    // These tempory registers are for masking and shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
+
+    vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
+    vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
     for (int i = 0; i < nb; i++) {
         memcpy(&qh, x[i].qh, sizeof(uint32_t));
 
-        // temporary registers
-        vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl);
-        vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl);
-        vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl);
-        vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl);
-
         // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-        vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(vt_1, qh, vl);
-        vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(xha_0, vt_2, vl);
-        vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
 
         // ((qh & (1u << (j + 16))) >> (j + 12));
-        vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(vt_3, qh, vl);
-        vuint32m4_t xhl_1 = __riscv_vsrl_vv_u32m4(xha_1, vt_4, vl);
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
+        vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
 
         // narrowing
-        vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xhl_0, vl);
-        vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
 
-        vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xhl_1, vl);
-        vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
 
         // load
-        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
 
-        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
-        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
 
-        vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
-        vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
 
-        vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
-        vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
 
-        vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
-        vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
 
-        vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 16, vl);
-        vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 16, vl);
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
 
-        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
-        vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
 
         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
 
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
 
-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
-        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
         sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
     }
@@ -3411,62 +3474,58 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
 
     uint32_t qh;
 
-    // These temp values are for shift operations
-    uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
-
     size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
+    // temporary registers for shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
     for (int i = 0; i < nb; i++) {
         memcpy(&qh, x[i].qh, sizeof(uint32_t));
 
-        // temporary registers
-        vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl);
-        vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl);
-
         // load qh
-        vuint32m4_t vqh = __riscv_vmv_v_x_u32m4(qh, vl);
+        vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
 
         // ((qh >> (j +  0)) << 4) & 0x10;
-        vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(vqh, vt_1, vl);
-        vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
-        vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(xhl_0, 0x10, vl);
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
 
         // ((qh >> (j + 12))     ) & 0x10;
-        vuint32m4_t xhr_1 = __riscv_vsrl_vv_u32m4(vqh, vt_2, vl);
-        vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(xhr_1, 0x10, vl);
+        vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
 
         // narrowing
-        vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xha_0, vl);
-        vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
 
-        vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xha_1, vl);
-        vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
 
         // load
-        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
 
-        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
-        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
 
-        vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
-        vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
 
-        vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
-        vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
 
-        vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
-        vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
 
-        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
-        vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
 
         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
 
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
 
-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
-        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
         sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
     }
@@ -3707,6 +3766,58 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
 #endif
 }
 
+// xs and vs are byte strides of x and v
+inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
+
+    const float * restrict x[GGML_VEC_MAD_UNROLL];
+    const float * restrict v[GGML_VEC_MAD_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
+        x[i] = (const float *) ((const char *) xv + i*xs);
+        v[i] = (const float *) ((const char *) vv + i*vs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
+
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        vx[k] = GGML_F32_VEC_SET1(v[k][0]);
+    }
+
+    GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+            for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+                ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
+                ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
+            }
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = np; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#else
+    // scalar
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = 0; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#endif
+}
+
 //inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
 inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
 #if defined(GGML_USE_ACCELERATE)
@@ -4403,10 +4514,9 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
 static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return
-        (t0->ne[1] == t1->ne[1])  &&
-        (t0->ne[2] == t1->ne[2])  &&
-        (t0->ne[3] == t1->ne[3]);
+    return (t0->ne[1] == t1->ne[1])   &&
+           (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+           (t1->ne[3]%t0->ne[3] == 0);
 }
 
 enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
@@ -5076,43 +5186,78 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
     return tensor;
 }
 
+void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
+    const int64_t ne2 = tensor->ne[2];
+    const int64_t ne1 = tensor->ne[1];
+    const int64_t ne0 = tensor->ne[0];
+
+    const int64_t i3_ = (i/(ne2*ne1*ne0));
+    const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
+    const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
+    const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);
+
+    if (i0) {
+        * i0 = i0_;
+    }
+    if (i1) {
+        * i1 = i1_;
+    }
+    if (i2) {
+        * i2 = i2_;
+    }
+    if (i3) {
+        * i3 = i3_;
+    }
+}
+
 int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
+    }
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
                 return ((int8_t *)(tensor->data))[i];
-            } break;
+            }
         case GGML_TYPE_I16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
                 return ((int16_t *)(tensor->data))[i];
-            } break;
+            }
         case GGML_TYPE_I32:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
                 return ((int32_t *)(tensor->data))[i];
-            } break;
+            }
         case GGML_TYPE_F16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
                 return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
-            } break;
+            }
         case GGML_TYPE_F32:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 return ((float *)(tensor->data))[i];
-            } break;
+            }
         default:
             {
                 GGML_ASSERT(false);
-            } break;
+            }
     }
 
     return 0.0f;
 }
 
 void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
+        return;
+    }
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
@@ -5146,43 +5291,104 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
     }
 }
 
+int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            return ((int8_t *) data)[0];
+        case GGML_TYPE_I16:
+            return ((int16_t *) data)[0];
+        case GGML_TYPE_I32:
+            return ((int32_t *) data)[0];
+        case GGML_TYPE_F16:
+            return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+        case GGML_TYPE_F32:
+            return ((float *) data)[0];
+        default:
+            GGML_ASSERT(false);
+    }
+
+    return 0.0f;
+}
+
+void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
+    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                ((int8_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                ((int16_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                ((int32_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ((float *)(data))[0] = value;
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
+    }
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
                 return ((int8_t *)(tensor->data))[i];
-            } break;
+            }
         case GGML_TYPE_I16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
                 return ((int16_t *)(tensor->data))[i];
-            } break;
+            }
         case GGML_TYPE_I32:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
                 return ((int32_t *)(tensor->data))[i];
-            } break;
+            }
         case GGML_TYPE_F16:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
                 return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
-            } break;
+            }
         case GGML_TYPE_F32:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 return ((float *)(tensor->data))[i];
-            } break;
+            }
         default:
             {
                 GGML_ASSERT(false);
-            } break;
+            }
     }
 
     return 0.0f;
 }
 
 void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
+        return;
+    }
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
@@ -5216,6 +5422,56 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
     }
 }
 
+float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            return ((int8_t *) data)[0];
+        case GGML_TYPE_I16:
+            return ((int16_t *) data)[0];
+        case GGML_TYPE_I32:
+            return ((int32_t *) data)[0];
+        case GGML_TYPE_F16:
+            return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+        case GGML_TYPE_F32:
+            return ((float *) data)[0];
+        default:
+            GGML_ASSERT(false);
+    }
+
+    return 0.0f;
+}
+
+void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
+    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                ((int8_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                ((int16_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                ((int32_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ((float *)(data))[0] = value;
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 void * ggml_get_data(const struct ggml_tensor * tensor) {
     return tensor->data;
 }
@@ -5358,6 +5614,44 @@ struct ggml_tensor * ggml_add_inplace(
     return ggml_add_impl(ctx, a, b, true);
 }
 
+// ggml_add_cast
+
+static struct ggml_tensor * ggml_add_cast_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        enum   ggml_type     type) {
+    // TODO: support less-strict constraint
+    //       GGML_ASSERT(ggml_can_repeat(b, a));
+    GGML_ASSERT(ggml_can_repeat_rows(b, a));
+    GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        // TODO: support backward pass for broadcasting
+        GGML_ASSERT(ggml_are_same_shape(a, b));
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, type, a->n_dims, a->ne);
+
+    result->op   = GGML_OP_ADD;
+    result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_add_cast(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        enum   ggml_type     type) {
+    return ggml_add_cast_impl(ctx, a, b, type);
+}
+
 // ggml_add1
 
 static struct ggml_tensor * ggml_add1_impl(
@@ -5794,7 +6088,6 @@ struct ggml_tensor * ggml_repeat(
     result->op   = GGML_OP_REPEAT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -5822,7 +6115,6 @@ struct ggml_tensor * ggml_repeat_back(
     result->op   = GGML_OP_REPEAT_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -6197,8 +6489,9 @@ struct ggml_tensor * ggml_out_prod(
         is_node = true;
     }
 
-    const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+    // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
+    const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
 
     result->op   = GGML_OP_OUT_PROD;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6417,45 +6710,92 @@ struct ggml_tensor * ggml_cont_inplace(
     return ggml_cont_impl(ctx, a, true);
 }
 
-// ggml_reshape
 
-struct ggml_tensor * ggml_reshape(
+// make contiguous, with new shape
+GGML_API struct ggml_tensor * ggml_cont_1d(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b) {
-    GGML_ASSERT(ggml_is_contiguous(a));
-    GGML_ASSERT(ggml_is_contiguous(b));
-    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
-
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
-    if (b->grad) {
-        // gradient propagation is not supported
-        //GGML_ASSERT(false);
-    }
-
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a, 0);
-    ggml_format_name(result, "%s (reshaped)", a->name);
-
-    result->op   = GGML_OP_RESHAPE;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
+        struct ggml_tensor  * a,
+        int64_t               ne0) {
+    return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
 }
 
-struct ggml_tensor * ggml_reshape_1d(
+GGML_API struct ggml_tensor * ggml_cont_2d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int64_t               ne0) {
-    GGML_ASSERT(ggml_is_contiguous(a));
-    GGML_ASSERT(ggml_nelements(a) == ne0);
+        int64_t               ne0,
+        int64_t               ne1) {
+    return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
+}
 
-    bool is_node = false;
+GGML_API struct ggml_tensor * ggml_cont_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2) {
+    return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
+}
+
+struct ggml_tensor * ggml_cont_4d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        int64_t               ne3) {
+    GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
+
+    bool is_node = false;
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
+    ggml_format_name(result, "%s (cont)", a->name);
+
+    result->op   = GGML_OP_CONT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_reshape
+
+struct ggml_tensor * ggml_reshape(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.
+    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    if (b->grad) {
+        // gradient propagation is not supported
+        //GGML_ASSERT(false);
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a, 0);
+    ggml_format_name(result, "%s (reshaped)", a->name);
+
+    result->op   = GGML_OP_RESHAPE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_reshape_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_nelements(a) == ne0);
+
+    bool is_node = false;
 
     if (a->grad) {
         is_node = true;
@@ -6797,7 +7137,6 @@ struct ggml_tensor * ggml_get_rows_back(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = c;
 
     return result;
 }
@@ -6979,7 +7318,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
 static struct ggml_tensor * ggml_rope_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   n_past,
+        struct ggml_tensor  * b,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6988,7 +7327,10 @@ static struct ggml_tensor * ggml_rope_impl(
         float                 xpos_base,
         bool                  xpos_down,
         bool                  inplace) {
-    GGML_ASSERT(n_past >= 0);
+    GGML_ASSERT(ggml_is_vector(b));
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[2] == b->ne[0]);
+
     bool is_node = false;
 
     if (a->grad) {
@@ -6997,7 +7339,7 @@ static struct ggml_tensor * ggml_rope_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    int32_t params[8] = { n_past, n_dims, mode, n_ctx };
+    int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
     memcpy(params + 4, &freq_base,  sizeof(float));
     memcpy(params + 5, &freq_scale, sizeof(float));
     memcpy(params + 6, &xpos_base,  sizeof(float));
@@ -7007,6 +7349,7 @@ static struct ggml_tensor * ggml_rope_impl(
     result->op   = GGML_OP_ROPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
+    result->src[1] = b;
 
     return result;
 }
@@ -7014,55 +7357,55 @@ static struct ggml_tensor * ggml_rope_impl(
 struct ggml_tensor * ggml_rope(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   n_past,
+        struct ggml_tensor  * b,
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
+    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   n_past,
+        struct ggml_tensor  * b,
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
+    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
 }
 
 struct ggml_tensor * ggml_rope_custom(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   n_past,
+        struct ggml_tensor  * b,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
         float                 freq_base,
         float                 freq_scale) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
+    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
 }
 
 struct ggml_tensor * ggml_rope_custom_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   n_past,
+        struct ggml_tensor  * b,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
         float                 freq_base,
         float                 freq_scale) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
+    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
 }
 
 struct ggml_tensor * ggml_rope_xpos_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   n_past,
+        struct ggml_tensor  * b,
         int                   n_dims,
         float                 base,
         bool                  down) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
+    return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
 }
 
 // ggml_rope_back
@@ -7070,7 +7413,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
 struct ggml_tensor * ggml_rope_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   n_past,
+        struct ggml_tensor  * b,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -7078,7 +7421,10 @@ struct ggml_tensor * ggml_rope_back(
         float                 freq_scale,
         float                 xpos_base,
         bool                  xpos_down) {
-    GGML_ASSERT(n_past >= 0);
+    GGML_ASSERT(ggml_is_vector(b));
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[2] == b->ne[0]);
+
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
     bool is_node = false;
@@ -7089,7 +7435,7 @@ struct ggml_tensor * ggml_rope_back(
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
-    int32_t params[8] = { n_past, n_dims, mode, n_ctx };
+    int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
     memcpy(params + 4, &freq_base,  sizeof(float));
     memcpy(params + 5, &freq_scale, sizeof(float));
     memcpy(params + 6, &xpos_base,  sizeof(float));
@@ -7099,6 +7445,7 @@ struct ggml_tensor * ggml_rope_back(
     result->op   = GGML_OP_ROPE_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
+    result->src[1] = b;
 
     return result;
 }
@@ -7161,7 +7508,7 @@ struct ggml_tensor * ggml_clamp(
     return result;
 }
 
-// ggml_conv_1d_stage_0
+// ggml_conv_1d
 
 static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
     return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
@@ -7627,27 +7974,30 @@ struct ggml_tensor * ggml_flash_attn_back(
 
     // d shape [D,N,ne2,ne3]
     // q shape [D,N,ne2,ne3]
-    // k shape [D,M,ne2,ne3]
-    // v shape [M,D,ne2,ne3]
+    // k shape [D,M,kvne2,ne3]
+    // v shape [M,D,kvne2,ne3]
 
-    const int64_t   D = q->ne[0];
-    const int64_t   N = q->ne[1];
-    const int64_t   M = k->ne[1];
-    const int64_t ne2 = q->ne[2];
-    const int64_t ne3 = q->ne[3];
+    const int64_t     D = q->ne[0];
+    const int64_t     N = q->ne[1];
+    const int64_t     M = k->ne[1];
+    const int64_t   ne2 = q->ne[2];
+    const int64_t   ne3 = q->ne[3];
+    const int64_t kvne2 = k->ne[2];
 
     GGML_ASSERT(k->ne[0] == D);
     GGML_ASSERT(v->ne[0] == M);
     GGML_ASSERT(v->ne[1] == D);
     GGML_ASSERT(d->ne[0] == D);
     GGML_ASSERT(d->ne[1] == N);
-    GGML_ASSERT(k->ne[2] == ne2);
+    GGML_ASSERT(k->ne[2] == kvne2);
     GGML_ASSERT(k->ne[3] == ne3);
-    GGML_ASSERT(v->ne[2] == ne2);
+    GGML_ASSERT(v->ne[2] == kvne2);
     GGML_ASSERT(v->ne[3] == ne3);
     GGML_ASSERT(d->ne[2] == ne2);
     GGML_ASSERT(d->ne[3] == ne3);
 
+    GGML_ASSERT(ne2 % kvne2 == 0);
+
     bool is_node = false;
 
     if (q->grad || k->grad || v->grad) {
@@ -7657,14 +8007,23 @@ struct ggml_tensor * ggml_flash_attn_back(
     }
 
     // store gradients of q, k and v as continuous tensors concatenated in result.
-    // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
-    // gradq->data = result->data
-    // gradk->data = result->data + nb0*D*N*ne2*ne3
-    // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
     // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
-    int64_t ne[4] = {D,M+N+M,ne2,ne3};
+    const int64_t elem_q = ggml_nelements(q);
+    const int64_t elem_k = ggml_nelements(k);
+    const int64_t elem_v = ggml_nelements(v);
 
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+    enum ggml_type result_type = GGML_TYPE_F32;
+    GGML_ASSERT(ggml_blck_size(result_type) == 1);
+    const size_t tsize = ggml_type_size(result_type);
+
+    const size_t offs_q = 0;
+    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+    const size_t end    = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
+
+    const size_t nelements = (end + tsize - 1)/tsize;
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
 
     int32_t masked_i = masked ? 1 : 0;
     ggml_set_op_params(result, &masked_i, sizeof(masked_i));
@@ -8357,7 +8716,7 @@ static void ggml_compute_forward_dup_f16(
         return;
     }
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
@@ -8628,7 +8987,7 @@ static void ggml_compute_forward_dup_f32(
         return;
     }
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
@@ -8909,7 +9268,7 @@ static void ggml_compute_forward_add_f32(
 
     const int nr  = ggml_nrows(src0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -8941,8 +9300,6 @@ static void ggml_compute_forward_add_f32(
 #else
             ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
 #endif
-                // }
-            // }
         }
     } else {
         // src1 is not contiguous
@@ -8984,7 +9341,7 @@ static void ggml_compute_forward_add_f16_f32(
 
     const int nr  = ggml_nrows(src0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -9038,7 +9395,7 @@ static void ggml_compute_forward_add_f16_f16(
 
     const int nr  = ggml_nrows(src0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -9089,14 +9446,15 @@ static void ggml_compute_forward_add_q_f32(
 
     const int nr  = ggml_nrows(src0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
     const int nth = params->nth;
 
     const enum ggml_type type = src0->type;
+    const enum ggml_type dtype = dst->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
-    ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
+    ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == ggml_type_size(type));
@@ -9108,7 +9466,6 @@ static void ggml_compute_forward_add_q_f32(
     GGML_ASSERT(nb2 <= nb3);
 
     GGML_ASSERT(ggml_is_quantized(src0->type));
-    GGML_ASSERT(dst->type == src0->type);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
     // rows per thread
@@ -9146,7 +9503,11 @@ static void ggml_compute_forward_add_q_f32(
         // add src1
         ggml_vec_acc_f32(ne00, wdata, src1_row);
         // quantize row to dst
-        quantize_row_q(wdata, dst_row, ne00);
+        if (quantize_row_q != NULL) {
+            quantize_row_q(wdata, dst_row, ne00);
+        } else {
+            memcpy(dst_row, wdata, ne0*nb0);
+        }
     }
 }
 
@@ -9211,7 +9572,7 @@ static void ggml_compute_forward_add1_f32(
 
     const int nr  = ggml_nrows(src0);
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9266,7 +9627,7 @@ static void ggml_compute_forward_add1_f16_f32(
 
     const int nr  = ggml_nrows(src0);
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -9316,7 +9677,7 @@ static void ggml_compute_forward_add1_f16_f16(
 
     const int nr  = ggml_nrows(src0);
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -9366,7 +9727,7 @@ static void ggml_compute_forward_add1_q_f32(
 
     const int nr  = ggml_nrows(src0);
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     const enum ggml_type type = src0->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
@@ -9494,8 +9855,8 @@ static void ggml_compute_forward_acc_f32(
     const int nr = ggml_nrows(src1);
     const int nc = src1->ne[0];
 
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb);
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
 
     // src0 and dst as viewed during acc
     const size_t nb0 = ggml_element_size(src0);
@@ -9584,7 +9945,7 @@ static void ggml_compute_forward_sub_f32(
 
     const int nr  = ggml_nrows(src0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9674,7 +10035,7 @@ static void ggml_compute_forward_mul_f32(
 
     const int64_t nr = ggml_nrows(src0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9765,7 +10126,7 @@ static void ggml_compute_forward_div_f32(
 
     const int nr  = ggml_nrows(src0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9974,8 +10335,8 @@ static void ggml_compute_forward_sum_f32(
     assert(ggml_is_scalar(dst));
     assert(src0->nb[0] == sizeof(float));
 
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb);
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
 
     ggml_float sum     = 0;
     ggml_float row_sum = 0;
@@ -10006,8 +10367,8 @@ static void ggml_compute_forward_sum_f16(
 
     assert(src0->nb[0] == sizeof(ggml_fp16_t));
 
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb);
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
 
     float sum = 0;
     float row_sum = 0;
@@ -10060,7 +10421,7 @@ static void ggml_compute_forward_sum_rows_f32(
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(dst->nb[0] == sizeof(float));
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     GGML_ASSERT(ne0 == 1);
     GGML_ASSERT(ne1 == ne01);
@@ -10110,7 +10471,7 @@ static void ggml_compute_forward_mean_f32(
 
     assert(src0->nb[0] == sizeof(float));
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     assert(ne0 == 1);
     assert(ne1 == ne01);
@@ -10210,7 +10571,7 @@ static void ggml_compute_forward_repeat_f32(
         return;
     }
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
     const int nr0 = (int)(ne0/ne00);
@@ -10242,11 +10603,61 @@ static void ggml_compute_forward_repeat_f32(
     }
 }
 
+static void ggml_compute_forward_repeat_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // TODO: maybe this is not optimal?
+    for                         (int i3 = 0; i3 < nr3;  i3++) {
+        for                     (int k3 = 0; k3 < ne03; k3++) {
+            for                 (int i2 = 0; i2 < nr2;  i2++) {
+                for             (int k2 = 0; k2 < ne02; k2++) {
+                    for         (int i1 = 0; i1 < nr1;  i1++) {
+                        for     (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0;  i0++) {
+                                ggml_fp16_t * y = (ggml_fp16_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0);
+                                ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01);
+                                // ggml_vec_cpy_f16(ne00, y, x)
+                                for (int i = 0; i < ne00; ++i) {
+                                    y[i]  = x[i];
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
 static void ggml_compute_forward_repeat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
     switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_repeat_f16(params, src0, dst);
+            } break;
         case GGML_TYPE_F32:
             {
                 ggml_compute_forward_repeat_f32(params, src0, dst);
@@ -10271,7 +10682,7 @@ static void ggml_compute_forward_repeat_back_f32(
         return;
     }
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
     const int nr0 = (int)(ne00/ne0);
@@ -10349,7 +10760,7 @@ static void ggml_compute_forward_concat_f32(
 
     const int ith = params->ith;
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     // TODO: support for transposed / permuted tensors
     GGML_ASSERT(nb0  == sizeof(float));
@@ -10951,7 +11362,7 @@ static void ggml_compute_forward_norm_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
@@ -11020,7 +11431,7 @@ static void ggml_compute_forward_rms_norm_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
@@ -11085,7 +11496,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
@@ -11260,7 +11671,7 @@ static void ggml_compute_forward_group_norm_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     const float eps = 1e-6f; // TODO: make this a parameter
 
@@ -11371,7 +11782,7 @@ static void ggml_compute_forward_mul_mat(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -11408,11 +11819,6 @@ static void ggml_compute_forward_mul_mat(
 
 #if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
-        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
-        //       ref: https://github.com/ggerganov/ggml/pull/224
-        GGML_ASSERT(ne02 == ne12);
-        GGML_ASSERT(ne03 == ne13);
-
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
@@ -11586,10 +11992,10 @@ static void ggml_compute_forward_out_prod_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
+    // int64_t t0 = ggml_perf_time_us();
+    // UNUSED(t0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -11628,6 +12034,146 @@ static void ggml_compute_forward_out_prod_f32(
         return;
     }
 
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // block-tiling attempt
+    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
+    const int64_t blck_1 = 16;
+
+    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
+        const int64_t bir1 = MIN(bir + blck_1, ir1);
+        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
+            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
+            for (int64_t ir = bir; ir < bir1; ++ir) {
+                // dst indices
+                const int64_t i3 = ir/(ne2*ne1);
+                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                const int64_t i02 = i2;
+                const int64_t i03 = i3;
+
+                //const int64_t i10 = i1;
+                const int64_t i12 = i2;
+                const int64_t i13 = i3;
+
+#if GGML_VEC_MAD_UNROLL > 2
+                const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
+                for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
+                }
+                for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32(ne0, d, s0, *s1);
+                }
+#else
+                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32(ne0, d, s0, *s1);
+                }
+#endif
+            }
+        }
+    }
+
+
+    //int64_t t1 = ggml_perf_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_out_prod_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    // int64_t t0 = ggml_perf_time_us();
+    // UNUSED(t0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const enum ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // we don't support permuted src0 dim0
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+
+    // dst dim0 cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
+    // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+
+    if (params->type == GGML_TASK_INIT) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
     // parallelize by last three dimensions
 
     // total rows in dst
@@ -11647,6 +12193,8 @@ static void ggml_compute_forward_out_prod_f32(
     //       for i0:
     //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
 
+    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
+
     for (int64_t ir = ir0; ir < ir1; ++ir) {
         // dst indices
         const int64_t i3 = ir/(ne2*ne1);
@@ -11667,10 +12215,8 @@ static void ggml_compute_forward_out_prod_f32(
             float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
             float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
 
-            ggml_vec_mad_f32(ne0, d, s0, *s1);
-            // for (int64_t i0 = 0; i0 < ne0; ++i0) {
-            //     d[i0] += s0[i0] * s1[i1];
-            // }
+            dequantize_row_q(s0, wdata, ne0);
+            ggml_vec_mad_f32(ne0, d, wdata, *s1);
         }
     }
 
@@ -11699,10 +12245,13 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
-                GGML_ASSERT(false); // todo
-                // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
+                ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F16:
             {
@@ -11820,8 +12369,8 @@ static void ggml_compute_forward_set_f32(
     const int nr = ggml_nrows(src1);
     const int nc = src1->ne[0];
 
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb);
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
 
     // src0 and dst as viewed during set
     const size_t nb0 = ggml_element_size(src0);
@@ -12090,14 +12639,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_are_same_shape(opt0, dst));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
-    ggml_compute_forward_dup_same_cont(params, opt0, dst);
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12123,11 +12673,8 @@ static void ggml_compute_forward_get_rows_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_are_same_shape(opt0, dst));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
     // ggml_compute_forward_dup_same_cont(params, opt0, dst);
@@ -12161,16 +12708,15 @@ static void ggml_compute_forward_get_rows_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst);
+                ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst);
+                ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -12211,7 +12757,7 @@ static void ggml_compute_forward_diag_f32(
 
     // TODO: handle transposed/permuted matrices
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     GGML_ASSERT(ne00 == ne0);
     GGML_ASSERT(ne00 == ne1);
@@ -12599,13 +13145,11 @@ static void ggml_compute_forward_alibi_f16(
         return;
     }
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
-    assert(n_past >= 0);
-
     const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int ne1 = src0->ne[1]; // seq_len_without_past
     const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -12620,7 +13164,7 @@ static void ggml_compute_forward_alibi_f16(
     //const int nb3 = src0->nb[3];
 
     GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
+    //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
     GGML_ASSERT(n_head == ne2);
 
     // add alibi to src0 (KQ_scaled)
@@ -12766,8 +13310,8 @@ static void ggml_compute_forward_clamp(
 static void ggml_compute_forward_rope_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
@@ -12777,9 +13321,9 @@ static void ggml_compute_forward_rope_f32(
 
     // these two only relevant for xPos RoPE:
     float xpos_base;
-    bool xpos_down;
+    bool  xpos_down;
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
     const int n_ctx  = ((int32_t *) dst->op_params)[3];
@@ -12788,9 +13332,7 @@ static void ggml_compute_forward_rope_f32(
     memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float));
     memcpy(&xpos_down,  (int32_t *) dst->op_params + 7, sizeof(bool));
 
-    assert(n_past >= 0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12820,9 +13362,11 @@ static void ggml_compute_forward_rope_f32(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    const int32_t * pos = (const int32_t *) src1->data;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            const int64_t p = pos[i2];
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
@@ -12859,7 +13403,7 @@ static void ggml_compute_forward_rope_f32(
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
                         // zeta scaling for xPos only:
-                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
+                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
                         if (xpos_down) zeta = 1.0f / zeta;
 
                         theta *= theta_scale;
@@ -12904,8 +13448,8 @@ static void ggml_compute_forward_rope_f32(
 static void ggml_compute_forward_rope_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
@@ -12913,16 +13457,14 @@ static void ggml_compute_forward_rope_f16(
     float freq_base;
     float freq_scale;
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
     const int n_ctx  = ((int32_t *) dst->op_params)[3];
     memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
-    assert(n_past >= 0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12952,9 +13494,11 @@ static void ggml_compute_forward_rope_f16(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    const int32_t * pos = (const int32_t *) src1->data;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            const int64_t p = pos[i2];
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
@@ -13033,15 +13577,16 @@ static void ggml_compute_forward_rope_f16(
 static void ggml_compute_forward_rope(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_rope_f16(params, src0, dst);
+                ggml_compute_forward_rope_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rope_f32(params, src0, dst);
+                ggml_compute_forward_rope_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -13055,6 +13600,7 @@ static void ggml_compute_forward_rope(
 static void ggml_compute_forward_rope_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -13072,7 +13618,7 @@ static void ggml_compute_forward_rope_back_f32(
     float xpos_base;
     bool xpos_down;
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
     const int n_ctx  = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
@@ -13081,9 +13627,7 @@ static void ggml_compute_forward_rope_back_f32(
     memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float));
     memcpy(&xpos_down,  (int32_t *) dst->op_params + 7, sizeof(bool));
 
-    assert(n_past >= 0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -13109,9 +13653,11 @@ static void ggml_compute_forward_rope_back_f32(
 
     const bool is_neox = mode & 2;
 
+    const int32_t * pos = (const int32_t *) src1->data;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            const int64_t p = pos[i2];
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
@@ -13123,7 +13669,7 @@ static void ggml_compute_forward_rope_back_f32(
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
                         // zeta scaling for xPos only:
-                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
+                        float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
                         if (xpos_down) zeta = 1.0f / zeta;
 
                         theta *= theta_scale;
@@ -13166,6 +13712,7 @@ static void ggml_compute_forward_rope_back_f32(
 static void ggml_compute_forward_rope_back_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -13176,13 +13723,11 @@ static void ggml_compute_forward_rope_back_f16(
     // dx = rope_back(dy, src1)
     // src0 is dy, src1 contains options
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
 
-    assert(n_past >= 0);
-
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -13208,9 +13753,11 @@ static void ggml_compute_forward_rope_back_f16(
 
     const bool is_neox = mode & 2;
 
+    const int32_t * pos = (const int32_t *) src1->data;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            const int64_t p = pos[i2];
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
@@ -13262,15 +13809,16 @@ static void ggml_compute_forward_rope_back_f16(
 static void ggml_compute_forward_rope_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_rope_back_f16(params, src0, dst);
+                ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rope_back_f32(params, src0, dst);
+                ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -13293,7 +13841,7 @@ static void ggml_compute_forward_conv_1d_f16_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -13374,7 +13922,7 @@ static void ggml_compute_forward_conv_1d_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -13702,7 +14250,7 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -13793,7 +14341,7 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -13907,7 +14455,7 @@ static void ggml_compute_forward_conv_2d_f16_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -14029,7 +14577,7 @@ static void ggml_compute_forward_conv_transpose_2d(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -14288,7 +14836,7 @@ static void ggml_compute_forward_upscale_f32(
 
     const int ith = params->ith;
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     const int scale_factor = dst->op_params[0];
 
@@ -14340,14 +14888,14 @@ static void ggml_compute_forward_flash_attn_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne);
-    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb);
-    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne);
-    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb);
-    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne);
-    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb);
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb);
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -14417,10 +14965,11 @@ static void ggml_compute_forward_flash_attn_f32(
             S[i] = -INFINITY;
         }
 
-        for (int64_t ic = 0; ic < nek1; ++ic) {
+        const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
+        for (int64_t ic = 0; ic < masked_begin; ++ic) {
             // k indices
             const int ik3 = iq3;
-            const int ik2 = iq2;
+            const int ik2 = iq2 % nek2;
             const int ik1 = ic;
 
             // S indices
@@ -14433,20 +14982,18 @@ static void ggml_compute_forward_flash_attn_f32(
         }
 
         // scale
-        ggml_vec_scale_f32(nek1, S, scale);
+        ggml_vec_scale_f32(masked_begin, S, scale);
 
-        if (masked) {
-            for (int64_t i = P; i < M; i++) {
-                if (i > P + iq1) {
-                    S[i] = -INFINITY;
-                }
-            }
+        for (int64_t i = masked_begin; i < M; i++) {
+            S[i] = -INFINITY;
         }
 
         // softmax
+        // exclude known -INF S[..] values from max and loop
+        // dont forget to set their SW values to zero
         {
             float max = -INFINITY;
-            ggml_vec_max_f32(M, &max, S);
+            ggml_vec_max_f32(masked_begin, &max, S);
 
             ggml_float sum = 0.0;
             {
@@ -14460,10 +15007,15 @@ static void ggml_compute_forward_flash_attn_f32(
                 ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
 
                 for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                    if (i >= masked_begin) {
+                        break;
+                    }
                     float * SS = S + i;
 
                     for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
-                        if (SS[j] == -INFINITY) {
+                        if (i + j >= masked_begin) {
+                            break;
+                        } else if (SS[j] == -INFINITY) {
                             SS[j] = 0.0f;
                         } else {
 #ifndef GGML_FLASH_ATTN_EXP_FP16
@@ -14488,10 +15040,10 @@ static void ggml_compute_forward_flash_attn_f32(
             assert(sum > 0.0);
 
             sum = 1.0/sum;
-            ggml_vec_scale_f32(M, S, sum);
+            ggml_vec_scale_f32(masked_begin, S, sum);
 
 #ifndef NDEBUG
-            for (int i = 0; i < M; ++i) {
+            for (int i = 0; i < masked_begin; ++i) {
                 assert(!isnan(S[i]));
                 assert(!isinf(S[i]));
             }
@@ -14504,9 +15056,13 @@ static void ggml_compute_forward_flash_attn_f32(
             const int i2 = iq2;
             const int i3 = iq3;
 
-            ggml_vec_dot_f32(nek1,
-                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
-                    (float *) ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
+            // v indices
+            const int iv2 = iq2 % nev2;
+            const int iv3 = iq3;
+
+            ggml_vec_dot_f32(masked_begin,
+                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2   + i3*nb3)),
+                    (float *) ((char *) v->data   + (         ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
                     S);
         }
     }
@@ -14522,14 +15078,14 @@ static void ggml_compute_forward_flash_attn_f16(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne);
-    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb);
-    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne);
-    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb);
-    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne);
-    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb);
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb);
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -14603,7 +15159,7 @@ static void ggml_compute_forward_flash_attn_f16(
             for (int64_t ic = 0; ic < nek1; ++ic) {
                 // k indices
                 const int ik3 = iq3;
-                const int ik2 = iq2;
+                const int ik2 = iq2 % nek2;
                 const int ik1 = ic;
 
                 // S indices
@@ -14618,7 +15174,7 @@ static void ggml_compute_forward_flash_attn_f16(
             for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
                 // k indices
                 const int ik3 = iq3;
-                const int ik2 = iq2;
+                const int ik2 = iq2 % nek2;
                 const int ik1 = ic;
 
                 // S indices
@@ -14643,6 +15199,8 @@ static void ggml_compute_forward_flash_attn_f16(
         }
 
         // softmax
+        // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
+        // dont forget to set their S values to zero
         {
             float max = -INFINITY;
             ggml_vec_max_f32(M, &max, S);
@@ -14699,6 +15257,7 @@ static void ggml_compute_forward_flash_attn_f16(
             S16[i] = GGML_FP32_TO_FP16(S[i]);
         }
 
+        // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
         if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
             for (int64_t ic = 0; ic < nev1; ++ic) {
                 // dst indices
@@ -14706,9 +15265,13 @@ static void ggml_compute_forward_flash_attn_f16(
                 const int i2 = iq2;
                 const int i3 = iq3;
 
-                ggml_vec_dot_f16(nek1,
-                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
-                        (ggml_fp16_t *) ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                // v indices
+                const int iv2 = iq2 % nev2;
+                const int iv3 = iq3;
+
+                ggml_vec_dot_f16(nev0,
+                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2   + i3*nb3)),
+                        (ggml_fp16_t *) ((char *) v->data   + (         ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
                         S16);
             }
         } else {
@@ -14718,9 +15281,13 @@ static void ggml_compute_forward_flash_attn_f16(
                 const int i2 = iq2;
                 const int i3 = iq3;
 
-                ggml_vec_dot_f16_unroll(nek1, nbv1,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
-                        ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                // v indices
+                const int iv2 = iq2 % nev2;
+                const int iv3 = iq3;
+
+                ggml_vec_dot_f16_unroll(nev0, nbv1,
+                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2   + i3*nb3)),
+                        ((char *)             v->data + (         ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
                         S16);
             }
         }
@@ -14763,18 +15330,18 @@ static void ggml_compute_forward_flash_ff_f16(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    GGML_TENSOR_LOCALS(int64_t, nea,  a,   ne);
-    GGML_TENSOR_LOCALS(size_t,  nba,  a,   nb);
-    GGML_TENSOR_LOCALS(int64_t, neb0, b0,  ne);
-    GGML_TENSOR_LOCALS(size_t,  nbb0, b0,  nb);
-    GGML_TENSOR_LOCALS(int64_t, neb1, b1,  ne);
-    GGML_TENSOR_LOCALS(size_t,  nbb1, b1,  nb);
-    GGML_TENSOR_LOCALS(int64_t, nec0, c0,  ne);
-    GGML_TENSOR_LOCALS(size_t,  nbc0, c0,  nb);
-    GGML_TENSOR_LOCALS(int64_t, nec1, c1,  ne);
-    GGML_TENSOR_LOCALS(size_t,  nbc1, c1,  nb);
-    GGML_TENSOR_LOCALS(int64_t, ne,   dst, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb,   dst, nb);
+    GGML_TENSOR_LOCALS(int64_t, nea,  a,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nba,  a,   nb)
+    GGML_TENSOR_LOCALS(int64_t, neb0, b0,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nbb0, b0,  nb)
+    GGML_TENSOR_LOCALS(int64_t, neb1, b1,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nbb1, b1,  nb)
+    GGML_TENSOR_LOCALS(int64_t, nec0, c0,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nbc0, c0,  nb)
+    GGML_TENSOR_LOCALS(int64_t, nec1, c1,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nbc1, c1,  nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,   dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,   dst, nb)
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -14922,16 +15489,16 @@ static void ggml_compute_forward_flash_attn_back_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne);
-    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb);
-    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne);
-    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb);
-    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne);
-    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb);
-    GGML_TENSOR_LOCALS(int64_t, ned, d,   ne);
-    GGML_TENSOR_LOCALS(size_t,  nbd, d,   nb);
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne);
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb);
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ned, d,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbd, d,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -14979,10 +15546,37 @@ static void ggml_compute_forward_flash_attn_back_f32(
         return;
     }
 
-    // parallelize by q rows using ggml_vec_dot_f32
+    const int64_t elem_q = ggml_nelements(q);
+    const int64_t elem_k = ggml_nelements(k);
 
-    // total rows in q
-    const int nr = neq2*neq3;
+    enum ggml_type result_type = dst->type;
+    GGML_ASSERT(ggml_blck_size(result_type) == 1);
+    const size_t tsize = ggml_type_size(result_type);
+
+    const size_t offs_q = 0;
+    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+
+    void * grad_q = (char *) dst->data;
+    void * grad_k = (char *) dst->data + offs_k;
+    void * grad_v = (char *) dst->data + offs_v;
+
+    const size_t nbgq1 = nb0*neq0;
+    const size_t nbgq2 = nb0*neq0*neq1;
+    const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+    const size_t nbgk1 = nb0*nek0;
+    const size_t nbgk2 = nb0*nek0*nek1;
+    const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+    const size_t nbgv1 = nb0*nev0;
+    const size_t nbgv2 = nb0*nev0*nev1;
+    const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+    // parallelize by k rows using ggml_vec_dot_f32
+
+    // total rows in k
+    const int nr = nek2*nek3;
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -14995,268 +15589,243 @@ static void ggml_compute_forward_flash_attn_back_f32(
 
     //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
 
+    // how often k2 (and v2) is repeated in q2
+    int nrep = neq2/nek2;
+
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
-        const int iq3 = ir/(neq2);
-        const int iq2 = ir - iq3*neq2;
-        for ( int iq1 = 0; iq1 < neq1; ++iq1) {
+        const int ik3 = ir/(nek2);
+        const int ik2 = ir - ik3*nek2;
 
+        const int iq3 = ik3;
+        const int id3 = ik3;
+        const int iv3 = ik3;
+        const int iv2 = ik2;
 
-            // not sure about CACHE_LINE_SIZE_F32..
-            // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
-            float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
-            float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+        for (int irep = 0; irep < nrep; ++irep) {
+            const int iq2 = ik2 + irep*nek2;
+            const int id2 = iq2;
 
-            for (int i = M; i < Mup; ++i) {
-                S[i] = -INFINITY;
-            }
+            // (ik2 + irep*nek2) % nek2 == ik2
+            for (int iq1 = 0; iq1 < neq1; ++iq1) {
+                const int id1 = iq1;
 
-            for (int64_t ic = 0; ic < nek1; ++ic) {
-                // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2;
-                const int ik1 = ic;
+                // not sure about CACHE_LINE_SIZE_F32..
+                // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
+                float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+                float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
 
-                // S indices
-                const int i1 = ik1;
+                for (int i = M; i < Mup; ++i) {
+                    S[i] = -INFINITY;
+                }
 
-                ggml_vec_dot_f32(neq0,
-                        S + i1,
-                        (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                        (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
-            }
+                const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
+                for (int64_t ic = 0; ic < masked_begin; ++ic) {
+                    // k indices
+                    const int ik1 = ic;
 
-            // scale
-            ggml_vec_scale_f32(nek1, S, scale);
+                    // S indices
+                    const int i1 = ik1;
 
-            if (masked) {
-                for (int64_t i = P; i < M; i++) {
-                    if (i > P + iq1) {
-                        S[i] = -INFINITY;
-                    }
+                    ggml_vec_dot_f32(neq0,
+                            S + i1,
+                            (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                            (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
                 }
-            }
 
-            // softmax
-            {
-                float max = -INFINITY;
-                ggml_vec_max_f32(M, &max, S);
+                // scale
+                ggml_vec_scale_f32(masked_begin, S, scale);
 
-                ggml_float sum = 0.0;
+                for (int64_t i = masked_begin; i < M; i++) {
+                    S[i] = -INFINITY;
+                }
+
+                // softmax
+                // exclude known -INF S[..] values from max and loop
+                // dont forget to set their SM values to zero
                 {
+                    float max = -INFINITY;
+                    ggml_vec_max_f32(masked_begin, &max, S);
+
+                    ggml_float sum = 0.0;
+                    {
 #ifdef GGML_SOFT_MAX_ACCELERATE
-                    max = -max;
-                    vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
-                    vvexpf(SM, SM, &Mup);
-                    ggml_vec_sum_f32(Mup, &sum, SM);
+                        max = -max;
+                        vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+                        vvexpf(SM, SM, &Mup);
+                        ggml_vec_sum_f32(Mup, &sum, SM);
 #else
-                    uint16_t   scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
-                    ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
-
-                    for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
-                        float * SR =  S + i;
-                        float * SW = SM + i;
+                        uint16_t   scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
+                        ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
 
-                        for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
-                            if (SR[j] == -INFINITY) {
-                                SW[j] = 0.0f;
-                            } else {
+                        for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                            if (i >= masked_begin) {
+                                break;
+                            }
+                            float * SR =  S + i;
+                            float * SW = SM + i;
+
+                            for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+                                if (i + j >= masked_begin) {
+                                    break;
+                                } else if (SR[j] == -INFINITY) {
+                                    SW[j] = 0.0f;
+                                } else {
 #ifndef GGML_FLASH_ATTN_EXP_FP16
-                                const float val = expf(SR[j] - max);
+                                    const float val = expf(SR[j] - max);
 #else
-                                ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
-                                memcpy(&scvt[j], &s, sizeof(uint16_t));
-                                const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                                    ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
+                                    memcpy(&scvt[j], &s, sizeof(uint16_t));
+                                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
 #endif
-                                sump[j] += (ggml_float)val;
-                                SW[j] = val;
+                                    sump[j] += (ggml_float)val;
+                                    SW[j] = val;
+                                }
                             }
                         }
-                    }
 
-                    for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
-                        sum += sump[i];
-                    }
+                        for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+                            sum += sump[i];
+                        }
 #endif
-                }
-
-                assert(sum > 0.0);
-
-                sum = 1.0/sum;
-                ggml_vec_scale_f32(M, SM, sum);
-
-            }
-
-            // step-by-step explanation
-            {
-                // forward-process                   shape      grads from backward process
-                // parallel_for iq2,iq3:
-                //  k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,iq2,iq3]  += grad[kcur]
-                //  q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
-                //  v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iq2,iq3]  += grad[vcur]
-                //  for iq1:
-                //   kcur   = k[:D,:M,iq2,iq3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
-                //   qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
-                //   vcur   = v[:M,:D,iq2,iq3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
-                //   S0     = -Inf                   [D,1,1,1]
-                //  ~S1[i]  = dot(kcur[:D,i], qcur)
-                //   S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
-                //   S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
-                //   S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
-                //   S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
-                //  ~S5[i]  = dot(vcur[:,i], S4)
-                //   S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,iq1,iq2,iq3]
-                //  ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
-                //   dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
-                // dst                               backward-/ grad[dst]                 = d
-                //
-                // output gradients with their dependencies:
-                //
-                // grad[kcur] = grad[S1].T @ qcur
-                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
-                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
-                // grad[S4]   = grad[S5] @ vcur
-                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
-                // grad[qcur] = grad[S1]   @ kcur
-                // grad[vcur] = grad[S5].T @ S4
-                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
-                //
-                // in post-order:
-                //
-                // S1         = qcur @ kcur.T
-                // S2         = S1 * scale
-                // S3         = diag_mask_inf(S2, P)
-                // S4         = softmax(S3)
-                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
-                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
-                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
-                // grad[qcur] = grad[S1]   @ kcur
-                // grad[kcur] = grad[S1].T @ qcur
-                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
-                //
-                // using less variables (SM=S4):
-                //
-                // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
-                // SM            = softmax(S)
-                // S             = d[:D,iq1,iq2,iq3] @ vcur
-                // dot_SM_gradSM = dot(SM, S)
-                // S             = SM * (S - dot(SM, S))
-                // S             = diag_mask_zero(S, P) * scale
-                //
-                // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
-                // grad[k][:D,:M,iq2,iq3]  += S.T @ qcur
-                // grad[v][:M,:D,iq2,iq3]  += d[:D,iq1,iq2,iq3].T @ SM
-            }
-
-            // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
-            // S = d[:D,iq1,iq2,iq3] @ vcur
-            // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
-            ggml_vec_set_f32(M, S, 0);
-            for (int64_t ic = 0; ic < D; ++ic) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
+                    }
 
-                ggml_vec_mad_f32(M,
-                        S,
-                         (float *) ((char *) v->data + (          ic*nbv1 + i2*nbv2 + i3*nbv3)),
-                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
-            }
+                    assert(sum > 0.0);
 
-            // S = SM * (S - dot(SM, S))
-            float dot_SM_gradSM = 0;
-            ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
-            ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
-            ggml_vec_mul_f32 (M, S, S, SM);
+                    sum = 1.0/sum;
+                    ggml_vec_scale_f32(masked_begin, SM, sum);
 
-            // S = diag_mask_zero(S, P) * scale
-            if (masked) {
-                // for (int64_t i = P + iq1 + 1; i < M; i++) {
-                //     S[i] = 0;
-                // }
-                for (int64_t i = P; i < M; i++) {
-                    if (i > P + iq1) {
-                        S[i] = 0;
-                    }
                 }
-            }
-            ggml_vec_scale_f32(M, S, scale);
-
-            void * grad_q = (char *) dst->data;
-            void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
-            void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
-
-            const size_t nbgq1 = nb0*neq0;
-            const size_t nbgq2 = nb0*neq0*neq1;
-            const size_t nbgq3 = nb0*neq0*neq1*neq2;
-
-            const size_t nbgk1 = nb0*nek0;
-            const size_t nbgk2 = nb0*nek0*nek1;
-            const size_t nbgk3 = nb0*nek0*nek1*neq2;
-
-            const size_t nbgv1 = nb0*nev0;
-            const size_t nbgv2 = nb0*nev0*nev1;
-            const size_t nbgv3 = nb0*nev0*nev1*neq2;
-
-            // S    shape [M,1]
-            // SM   shape [M,1]
-            // kcur shape [D,M]
-            // qcur shape [D,1]
-            // vcur shape [M,D]
-            //
-            // grad[q][:D,iq1,iq2,iq3] += S @ kcur
-            // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
-            // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
-            //
-            //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
-            //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
-            for (int64_t ic = 0; ic < M; ++ic) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
 
-                ggml_vec_mad_f32(D,
-                        (float *) ((char *) grad_q  + (i1*nbgq1  + i2*nbgq2  + i3*nbgq3)),
-                        (float *) ((char *) k->data + (ic*nbk1   + i2*nbk2   + i3*nbk3)),
-                        S[ic]);
-            }
+                // step-by-step explanation
+                {
+                    // forward-process                    shape      grads from backward process
+                    // parallel_for ik2,ik3:
+                    //  for irep:
+                    //   iq2 = ik2 + irep*nek2
+                    //   k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,ik2,ik3]  += grad[kcur]
+                    //   q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+                    //   v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iv2,iv3]  += grad[vcur]
+                    //   for iq1:
+                    //    kcur   = k[:D,:M,ik2,ik3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
+                    //    qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
+                    //    vcur   = v[:M,:D,iv2,iv3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
+                    //    S0     = -Inf                   [D,1,1,1]
+                    //   ~S1[i]  = dot(kcur[:D,i], qcur)
+                    //    S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
+                    //    S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
+                    //    S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                    //    S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
+                    //   ~S5[i]  = dot(vcur[:,i], S4)
+                    //    S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,id1,id2,id3]
+                    //   ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
+                    //    dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
+                    // dst                               backward-/ grad[dst]                 = d
+                    //
+                    // output gradients with their dependencies:
+                    //
+                    // grad[kcur] = grad[S1].T @ qcur
+                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                    // grad[S4]   = grad[S5] @ vcur
+                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
+                    // grad[qcur] = grad[S1]   @ kcur
+                    // grad[vcur] = grad[S5].T @ S4
+                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
+                    //
+                    // in post-order:
+                    //
+                    // S1         = qcur @ kcur.T
+                    // S2         = S1 * scale
+                    // S3         = diag_mask_inf(S2, P)
+                    // S4         = softmax(S3)
+                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
+                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                    // grad[qcur] = grad[S1]   @ kcur
+                    // grad[kcur] = grad[S1].T @ qcur
+                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
+                    //
+                    // using less variables (SM=S4):
+                    //
+                    // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
+                    // SM            = softmax(S)
+                    // S             = d[:D,iq1,iq2,iq3] @ vcur
+                    // dot_SM_gradSM = dot(SM, S)
+                    // S             = SM * (S - dot(SM, S))
+                    // S             = diag_mask_zero(S, P) * scale
+                    //
+                    // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
+                    // grad[k][:D,:M,ik2,ik3]  += S.T @ qcur
+                    // grad[v][:M,:D,iv2,iv3]  += d[:D,id1,id2,id3].T @ SM
+                }
 
-            // grad[k][:D,:M,iq2,iq3] += S.T       @ qcur
-            // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
-            // grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
-            for (int64_t ic = 0; ic < M; ++ic) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
+                // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
+                // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
+                // for ic:
+                //   S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
+                // exclude known future zero S[..] values from operation
+                ggml_vec_set_f32(masked_begin, S, 0);
+                for (int64_t ic = 0; ic < D; ++ic) {
+                    ggml_vec_mad_f32(masked_begin,
+                            S,
+                             (float *) ((char *) v->data + (          ic*nbv1  + iv2*nbv2 + iv3*nbv3)),
+                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
+                }
 
-                // ggml_vec_set_f32(D,
-                //         (float *) ((char *) grad_k  + (ic*nbgk1  + i2*nbgk2  + i3*nbgk3)),
-                //         0);
-                ggml_vec_mad_f32(D,
-                        (float *) ((char *) grad_k  + (ic*nbgk1  + i2*nbgk2  + i3*nbgk3)),
-                        (float *) ((char *) q->data + (i1*nbq1   + i2*nbq2   + i3*nbq3)),
-                        S[ic]);
-            }
+                // S = SM * (S - dot(SM, S))
+                float dot_SM_gradSM = 0;
+                ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
+                ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+                ggml_vec_mul_f32 (masked_begin, S, S, SM);
+
+                // S = diag_mask_zero(S, P) * scale
+                // already done by above ggml_vec_set_f32
+
+                // exclude known zero S[..] values from operation
+                ggml_vec_scale_f32(masked_begin, S, scale);
+
+                // S    shape [M,1]
+                // SM   shape [M,1]
+                // kcur shape [D,M]
+                // qcur shape [D,1]
+                // vcur shape [M,D]
+
+                // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+                // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+                // for ic:
+                //  grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
+                // exclude known zero S[..] values from loop
+                for (int64_t ic = 0; ic < masked_begin; ++ic) {
+                    ggml_vec_mad_f32(D,
+                            (float *) ((char *) grad_q  + (iq1*nbgq1 + iq2*nbgq2  + iq3*nbgq3)),
+                            (float *) ((char *) k->data + (ic*nbk1   + ik2*nbk2   + ik3*nbk3)),
+                            S[ic]);
+                }
 
-            // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T       @ SM
-            // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
-            // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3]         * SM[:M]
-            for (int64_t ic = 0; ic < D; ++ic) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
+                // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
+                // for ic:
+                //  grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+                //  grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
+                // exclude known zero S[..] values from loop
+                for (int64_t ic = 0; ic < masked_begin; ++ic) {
+                    ggml_vec_mad_f32(D,
+                            (float *) ((char *) grad_k  + (ic*nbgk1  + ik2*nbgk2  + ik3*nbgk3)),
+                            (float *) ((char *) q->data + (iq1*nbq1  + iq2*nbq2   + iq3*nbq3)),
+                            S[ic]);
+                }
 
-                // ggml_vec_set_f32(M,
-                //         (float *) ((char *) grad_v   + (          ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
-                //         0);
-                ggml_vec_mad_f32(M,
-                        (float *) ((char *) grad_v   + (          ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
-                        SM,
-                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1  + i2*nbd2  + i3*nbd3)));
+                // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T       @ SM
+                // for ic:
+                //  grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
+                //  grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3]         * SM[:M]
+                // exclude known zero SM[..] values from mad
+                for (int64_t ic = 0; ic < D; ++ic) {
+                    ggml_vec_mad_f32(masked_begin,
+                            (float *) ((char *) grad_v   + (          ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
+                            SM,
+                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2  + id3*nbd3)));
+                }
             }
         }
     }
@@ -15292,8 +15861,8 @@ static void ggml_compute_forward_win_part_f32(
         return;
     }
 
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne);
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 
     const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
     const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
@@ -15354,8 +15923,8 @@ static void ggml_compute_forward_win_unpart_f32(
         return;
     }
 
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne);
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 
     const int32_t w = ((const int32_t *)(dst->op_params))[0];
 
@@ -15472,7 +16041,7 @@ static void ggml_compute_forward_get_rel_pos_f16(
 
     // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     const int64_t w = ne1;
 
@@ -16170,7 +16739,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_GET_ROWS_BACK:
             {
-                ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_DIAG:
             {
@@ -16194,11 +16763,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_ROPE:
             {
-                ggml_compute_forward_rope(params, tensor->src[0], tensor);
+                ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_ROPE_BACK:
             {
-                ggml_compute_forward_rope_back(params, tensor->src[0], tensor);
+                ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_ALIBI:
             {
@@ -16355,7 +16924,218 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
 
 ////////////////////////////////////////////////////////////////////////////////
 
-static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
+static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
+
+static size_t hash(void * p) {
+    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+}
+
+static size_t hash_find(void * hash_table[], void * p) {
+    size_t h = hash(p);
+
+    // linear probing
+    size_t i = h;
+    while (hash_table[i] != NULL && hash_table[i] != p) {
+        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+        if (i == h) {
+            // visited all hash table entries -> not found
+            return GGML_GRAPH_HASHTABLE_SIZE;
+        }
+    }
+    return i;
+}
+
+static bool hash_insert(void * hash_table[], void * p) {
+    size_t i = hash_find(hash_table, p);
+
+    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+
+    if (hash_table[i] == p) {
+        return true;
+    }
+
+    // insert
+    GGML_ASSERT(hash_table[i] == NULL);
+    hash_table[i] = p;
+    return false;
+}
+
+static bool hash_contains(void * hash_table[], void * p) {
+    size_t i = hash_find(hash_table, p);
+    return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
+}
+
+struct hash_map {
+    void * keys[GGML_GRAPH_HASHTABLE_SIZE];
+    void * vals[GGML_GRAPH_HASHTABLE_SIZE];
+};
+
+static struct hash_map * new_hash_map(void) {
+    struct hash_map * result = malloc(sizeof(struct hash_map));
+    for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
+        result->keys[i] = NULL;
+        result->vals[i] = NULL;
+    }
+    return result;
+}
+
+static void free_hash_map(struct hash_map * map) {
+    free(map);
+}
+
+// gradient checkpointing
+
+static struct ggml_tensor * ggml_recompute_graph_node(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * graph,
+        struct hash_map     * replacements,
+        struct ggml_tensor  * node) {
+
+    if (node == NULL) {
+        return NULL;
+    }
+
+    if (node->is_param) {
+        return node;
+    }
+
+    if (!hash_contains(graph->visited_hash_table, node)) {
+        return node;
+    }
+
+    int count_children = 0;
+    for (int k = 0; k < GGML_MAX_SRC; ++k) {
+        if (node->src[k]) {
+            ++count_children;
+        }
+    }
+
+    if (count_children == 0) {
+        return node;
+    }
+
+    size_t i = hash_find(replacements->keys, node);
+    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+    if (replacements->keys[i] == node) {
+        return (struct ggml_tensor *) replacements->vals[i];
+    }
+
+    struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
+
+    // insert clone into replacements
+    GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
+    replacements->keys[i] = node;
+    replacements->vals[i] = clone;
+
+    clone->op       = node->op;
+    clone->grad     = node->grad;
+    clone->is_param = node->is_param;
+    clone->extra    = node->extra;
+    for (int k = 0; k < GGML_MAX_DIMS; ++k) {
+        clone->nb[k] = node->nb[k];
+    }
+    for (int k = 0; k < GGML_MAX_SRC; ++k) {
+        clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
+    }
+    if (node->view_src != NULL) {
+        clone->data = (node->view_src->data == NULL)
+                        ? NULL // view_src not yet allocated
+                        : (char *) node->view_src->data // view_src already allocated
+                                 + node->view_offs;
+        clone->view_src  = node->view_src;
+        clone->view_offs = node->view_offs;
+    }
+
+    GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
+    GGML_ASSERT(sizeof(node->name)      == GGML_MAX_NAME);
+    memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
+    ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
+
+    return clone;
+}
+
+void ggml_build_backward_gradient_checkpointing(
+        struct ggml_context   * ctx,
+        struct ggml_cgraph    * gf,
+        struct ggml_cgraph    * gb,
+        struct ggml_cgraph    * gb_tmp,
+        struct ggml_tensor  * * checkpoints,
+        int                     n_checkpoints) {
+    *gb_tmp = *gf;
+    ggml_build_backward_expand(ctx, gf, gb_tmp, true);
+
+    if (n_checkpoints <= 0) {
+        *gb = *gb_tmp;
+        return;
+    }
+
+    struct hash_map * replacements = new_hash_map();
+
+    // insert checkpoints in replacements
+    for (int i = 0; i < n_checkpoints; ++i) {
+        size_t k = hash_find(replacements->keys, checkpoints[i]);
+        GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+        GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
+        replacements->keys[k] = checkpoints[i];
+        replacements->vals[k] = checkpoints[i];
+    }
+
+    *gb = *gf;
+    // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
+    // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
+    // by recomputing them from checkpoints
+    for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
+        struct ggml_tensor * node = gb_tmp->nodes[i];
+        for (int k = 0; k < GGML_MAX_SRC; ++k) {
+            // insert new tensors recomputing src, reusing already made replacements,
+            // remember replacements: remember new tensors with mapping from corresponding gf nodes
+            // recurse for input tensors,
+            // unless (i.e. terminating when) input tensors are replacments (like checkpoints)
+            node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
+        }
+        // insert rewritten backward node with replacements made into resulting backward graph gb
+        ggml_build_forward_expand(gb, node);
+    }
+
+    free_hash_map(replacements);
+}
+
+// functions to change gradients considering the case that input a might be initial gradient with zero value
+
+static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
+    if (hash_contains(zero_table, a)) {
+        return b;
+    } else {
+        return ggml_add_impl(ctx, a, b, false);
+    }
+}
+
+static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) {
+    if (hash_contains(zero_table, a)) {
+        struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
+        return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
+    } else {
+        return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
+    }
+}
+
+static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
+    if (hash_contains(zero_table, a)) {
+        return ggml_repeat(ctx, b, a);
+    } else {
+        return ggml_add1_impl(ctx, a, b, false);
+    }
+}
+
+static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
+    if (hash_contains(zero_table, a)) {
+        return ggml_neg(ctx, b);
+    } else {
+        return ggml_sub_impl(ctx, a, b, false);
+    }
+}
+
+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, void * zero_table[]) {
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
 
@@ -16363,34 +17143,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_DUP:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
             } break;
         case GGML_OP_ADD:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
-                    src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
+                    src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
                 }
             } break;
         case GGML_OP_ADD1:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
-                    src1->grad = ggml_add_impl(ctx,
+                    src1->grad = ggml_add_or_set(ctx,
                         src1->grad,
                         ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
-                        inplace);
+                        zero_table);
                 }
             } break;
         case GGML_OP_ACC:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
                     const size_t nb1     = ((int32_t *) tensor->op_params)[0];
@@ -16407,117 +17187,117 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         nb1, nb2, nb3, offset);
 
                     src1->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                             src1->grad,
                             ggml_reshape(ctx,
                                 ggml_cont(ctx, tensor_grad_view),
                                 src1->grad),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_SUB:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
-                    src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace);
+                    src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
                 }
             } break;
         case GGML_OP_MUL:
             {
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_mul(ctx, src1, tensor->grad),
-                                inplace);
+                                zero_table);
                 }
                 if (src1->grad) {
                     src1->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src1->grad,
                                 ggml_mul(ctx, src0, tensor->grad),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_DIV:
             {
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_div(ctx, tensor->grad, src1),
-                                inplace);
+                                zero_table);
                 }
                 if (src1->grad) {
                     src1->grad =
-                        ggml_sub_impl(ctx,
+                        ggml_sub_or_set(ctx,
                                 src1->grad,
                                 ggml_mul(ctx,
                                     tensor->grad,
                                     ggml_div(ctx, tensor, src1)),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_SQR:
             {
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_scale(ctx,
                                     ggml_mul(ctx, src0, tensor->grad),
                                     ggml_new_f32(ctx, 2.0f)),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_SQRT:
             {
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_scale(ctx,
                                     ggml_div(ctx,
                                         tensor->grad,
                                         tensor),
                                     ggml_new_f32(ctx, 0.5f)),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_LOG:
             {
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_div(ctx,
                                     tensor->grad,
                                     src0),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_SUM:
             {
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add1_impl(ctx,
+                        ggml_add1_or_set(ctx,
                                 src0->grad,
                                 tensor->grad,
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_SUM_ROWS:
             {
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_repeat(ctx,
                                     tensor->grad,
                                     src0->grad),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_MEAN:
@@ -16529,20 +17309,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx,
+                    src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_repeat_back(ctx, tensor->grad, src0->grad),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_REPEAT_BACK:
             {
                 if (src0->grad) {
                     // TODO: test this
-                    src0->grad = ggml_add_impl(ctx,
+                    src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_repeat(ctx, tensor->grad, src0->grad),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_CONCAT:
@@ -16564,10 +17344,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     float eps;
                     memcpy(&eps, tensor->op_params, sizeof(float));
 
-                    src0->grad = ggml_add_impl(ctx,
+                    src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_RMS_NORM_BACK:
@@ -16591,37 +17371,49 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
                 // ds1 = t.T.dot(dt)
 
-                // tensor.shape [m,p]
-                // src0.shape   [n,m]
-                // src1.shape   [n,p]
+                // tensor.shape [m,p,qq,rr]
+                // src0.shape   [n,m,q1,r1]
+                // src1.shape   [n,p,qq,rr]
 
                 // necessary for llama
                 if (src0->grad) {
+                    struct ggml_tensor * s1_tg =
+                        ggml_out_prod(ctx, // [n,m,qq,rr]
+                            src1,          // [n,p,qq,rr]
+                            tensor->grad); // [m,p,qq,rr]
+                    const int64_t qq = s1_tg->ne[2];
+                    const int64_t rr = s1_tg->ne[3];
+                    const int64_t q1 = src0->ne[2];
+                    const int64_t r1 = src0->ne[3];
+                    const bool ne2_broadcasted = qq > q1;
+                    const bool ne3_broadcasted = rr > r1;
+                    if (ne2_broadcasted || ne3_broadcasted) {
+                        // sum broadcast repetitions of s1_tg into shape of src0
+                        s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+                    }
                     src0->grad =
-                        ggml_add_impl(ctx,
-                                src0->grad,
-                                ggml_out_prod(ctx, // [n,m]
-                                    src1,          // [n,p]
-                                    tensor->grad), // [m,p]
-                                inplace);
+                        ggml_add_or_set(ctx,
+                                src0->grad, // [n,m,q1,r1]
+                                s1_tg,      // [n,m,q1,r1]
+                                zero_table);
                 }
                 if (src1->grad) {
                     src1->grad =
-                        ggml_add_impl(ctx,
-                                src1->grad,
-                                // ggml_mul_mat(ctx,                   // [n,p]
-                                //     ggml_cont(ctx,                  // [m,n]
-                                //         ggml_transpose(ctx, src0)), // [m,n]
-                                //     tensor->grad),                  // [m,p]
+                        ggml_add_or_set(ctx,
+                                src1->grad,                            // [n,p,qq,rr]
+                                // ggml_mul_mat(ctx,                   // [n,p,qq,rr]
+                                //     ggml_cont(ctx,                  // [m,n,q1,r1]
+                                //         ggml_transpose(ctx, src0)), // [m,n,q1,r1]
+                                //     tensor->grad),                  // [m,p,qq,rr]
 
                                 // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
                                 // // avoid transpose of src0, rather transpose smaller tensor->grad
                                 // // and then use ggml_out_prod
-                                ggml_out_prod(ctx,                  // [n,p]
-                                    src0,                           // [n,m]
-                                    ggml_transpose(ctx,             // [p,m]
-                                        tensor->grad)),             // [m,p]
-                                inplace);
+                                ggml_out_prod(ctx,                  // [n,p,qq,rr]
+                                    src0,                           // [n,m,q1,r1]
+                                    ggml_transpose(ctx,             // [p,m,qq,rr]
+                                        tensor->grad)),             // [m,p,qq,rr]
+                                zero_table);
                 }
             } break;
         case GGML_OP_OUT_PROD:
@@ -16633,17 +17425,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_scale_impl(ctx, tensor->grad, src1, false),
-                            inplace);
+                            zero_table);
                 }
                 if (src1->grad) {
                     src1->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                             src1->grad,
                             ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_SET:
@@ -16670,23 +17462,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 }
 
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx,
+                    src0->grad = ggml_add_or_set(ctx,
                         src0->grad,
                         ggml_acc_impl(ctx,
                             tensor->grad,
                             ggml_neg(ctx, tensor_grad_view),
                             nb1, nb2, nb3, offset, false),
-                        inplace);
+                        zero_table);
                 }
 
                 if (src1->grad) {
                     src1->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                             src1->grad,
                             ggml_reshape(ctx,
                                 ggml_cont(ctx, tensor_grad_view),
                                 src1->grad),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_CPY:
@@ -16697,7 +17489,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // tensor = src0 * 1 + src1 * 0
                 if (src0->grad) {
                     // dsrc0 = dtensor * 1
-                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
                     // dsrc1 = dtensor * 0 -> noop
@@ -16709,7 +17501,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad) {
                     GGML_ASSERT(ggml_is_contiguous(src0->grad));
                     GGML_ASSERT(ggml_is_contiguous(tensor->grad));
-                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
             } break;
         case GGML_OP_RESHAPE:
@@ -16717,9 +17509,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx, src0->grad,
-                            ggml_reshape(ctx, tensor->grad, src0->grad),
-                        inplace);
+                        ggml_add_or_set(ctx, src0->grad,
+                            ggml_reshape(ctx,
+                                ggml_is_contiguous(tensor->grad)
+                                    ? tensor->grad
+                                    : ggml_cont(ctx, tensor->grad),
+                                src0->grad),
+                        zero_table);
                 }
             } break;
         case GGML_OP_VIEW:
@@ -16748,7 +17544,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         nb3 = (nb3 / n0) * ng;
                     }
 
-                    src0->grad = ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace);
+                    src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
                 }
             } break;
         case GGML_OP_PERMUTE:
@@ -16766,14 +17562,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     axes_backward[axis2] = 2;
                     axes_backward[axis3] = 3;
                     src0->grad =
-                        ggml_add_impl(ctx, src0->grad,
+                        ggml_add_or_set(ctx, src0->grad,
                             ggml_permute(ctx,
                                 tensor->grad,
                                 axes_backward[0],
                                 axes_backward[1],
                                 axes_backward[2],
                                 axes_backward[3]),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_TRANSPOSE:
@@ -16781,9 +17577,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx, src0->grad,
+                        ggml_add_or_set(ctx, src0->grad,
                             ggml_transpose(ctx, tensor->grad),
-                        inplace);
+                        zero_table);
                 }
             } break;
         case GGML_OP_GET_ROWS:
@@ -16791,9 +17587,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama (only for tokenizer)
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx, src0->grad,
+                        ggml_add_or_set(ctx, src0->grad,
+                            // last ggml_get_rows_back argument src0->grad is only
+                            // necessary to setup correct output shape
                             ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
-                        inplace);
+                        zero_table);
                 }
                 if (src1->grad) {
                     // noop
@@ -16813,9 +17611,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad) {
                     const int n_past = ((int32_t *) tensor->op_params)[0];
                     src0->grad =
-                        ggml_add_impl(ctx, src0->grad,
+                        ggml_add_or_set(ctx, src0->grad,
                             ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
-                        inplace);
+                        zero_table);
                 }
             } break;
         case GGML_OP_DIAG_MASK_ZERO:
@@ -16824,9 +17622,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad) {
                     const int n_past = ((int32_t *) tensor->op_params)[0];
                     src0->grad =
-                        ggml_add_impl(ctx, src0->grad,
+                        ggml_add_or_set(ctx, src0->grad,
                             ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
-                        inplace);
+                        zero_table);
                 }
             } break;
         case GGML_OP_SOFT_MAX:
@@ -16834,9 +17632,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx, src0->grad,
+                        ggml_add_or_set(ctx, src0->grad,
                             ggml_soft_max_back(ctx, tensor->grad, tensor),
-                        inplace);
+                        zero_table);
                 }
 
             } break;
@@ -16848,7 +17646,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    const int n_past = ((int32_t *) tensor->op_params)[0];
+                    //const int n_past = ((int32_t *) tensor->op_params)[0];
                     const int n_dims = ((int32_t *) tensor->op_params)[1];
                     const int mode   = ((int32_t *) tensor->op_params)[2];
                     const int n_ctx  = ((int32_t *) tensor->op_params)[3];
@@ -16861,11 +17659,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
                     memcpy(&xpos_down,  (int32_t *) tensor->op_params + 7, sizeof(bool));
 
-                    src0->grad = ggml_add_impl(ctx,
+                    src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_rope_back(ctx,
                                 tensor->grad,
-                                n_past,
+                                src1,
                                 n_dims,
                                 mode,
                                 n_ctx,
@@ -16873,13 +17671,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 freq_scale,
                                 xpos_base,
                                 xpos_down),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_ROPE_BACK:
             {
                 if (src0->grad) {
-                    const int n_past = ((int32_t *) tensor->op_params)[0];
+                    //const int n_past = ((int32_t *) tensor->op_params)[0];
                     const int n_dims = ((int32_t *) tensor->op_params)[1];
                     const int mode   = ((int32_t *) tensor->op_params)[2];
                     const int n_ctx  = ((int32_t *) tensor->op_params)[3];
@@ -16892,11 +17690,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
                     memcpy(&xpos_down,  (int32_t *) tensor->op_params + 7, sizeof(bool));
 
-                    src0->grad = ggml_add_impl(ctx,
+                    src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_rope_impl(ctx,
                                 tensor->grad,
-                                n_past,
+                                src1,
                                 n_dims,
                                 mode,
                                 n_ctx,
@@ -16905,7 +17703,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 xpos_base,
                                 xpos_down,
                                 false),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_ALIBI:
@@ -16968,145 +17766,42 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             masked);
                 }
 
-                if (src0->grad) {
-                    struct ggml_tensor * grad_q = NULL;
-                    const size_t nb0    = flash_grad->nb[0];
-                    const size_t offset = 0;
-                    switch(src0->n_dims) {
-                        case 2:
-                            {
-                                grad_q = ggml_view_2d(ctx,
-                                    flash_grad,
-                                    src0->ne[0],
-                                    src0->ne[1],
-                                    nb0*src0->ne[0],
-                                    offset);
-                            } break;
-                        case 3:
-                            {
-                                grad_q = ggml_view_3d(ctx,
-                                    flash_grad,
-                                    src0->ne[0],
-                                    src0->ne[1],
-                                    src0->ne[2],
-                                    nb0*src0->ne[0],
-                                    nb0*src0->ne[0]*src0->ne[1],
-                                    offset);
-                            } break;
-                        case 4:
-                            {
-                                grad_q = ggml_view_4d(ctx,
-                                    flash_grad,
-                                    src0->ne[0],
-                                    src0->ne[1],
-                                    src0->ne[2],
-                                    src0->ne[3],
-                                    nb0*src0->ne[0],
-                                    nb0*src0->ne[0]*src0->ne[1],
-                                    nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
-                                    offset);
-                            } break;
-                    }
+                struct ggml_tensor * src2 = tensor->src[2];
+                const int64_t elem_q = ggml_nelements(src0);
+                const int64_t elem_k = ggml_nelements(src1);
+                const int64_t elem_v = ggml_nelements(src2);
+
+                enum ggml_type result_type = flash_grad->type;
+                GGML_ASSERT(ggml_blck_size(result_type) == 1);
+                const size_t tsize = ggml_type_size(result_type);
 
-                    src0->grad = ggml_add_impl(ctx,
+                const size_t offs_q = 0;
+                const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+                const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+
+                if (src0->grad) {
+                    struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
+                    struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
+                    src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             grad_q,
-                            inplace);
+                            zero_table);
                 }
-
                 if (src1->grad) {
-                    struct ggml_tensor * grad_k = NULL;
-                    const size_t nb0    = flash_grad->nb[0];
-                    const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
-                    switch(src1->n_dims) {
-                        case 2:
-                            {
-                                grad_k = ggml_view_2d(ctx,
-                                    flash_grad,
-                                    src1->ne[0],
-                                    src1->ne[1],
-                                    nb0*src1->ne[0],
-                                    offset);
-                            } break;
-                        case 3:
-                            {
-                                grad_k = ggml_view_3d(ctx,
-                                    flash_grad,
-                                    src1->ne[0],
-                                    src1->ne[1],
-                                    src1->ne[2],
-                                    nb0*src1->ne[0],
-                                    nb0*src1->ne[0]*src1->ne[1],
-                                    offset);
-                            } break;
-                        case 4:
-                            {
-                                grad_k = ggml_view_4d(ctx,
-                                    flash_grad,
-                                    src1->ne[0],
-                                    src1->ne[1],
-                                    src1->ne[2],
-                                    src1->ne[3],
-                                    nb0*src1->ne[0],
-                                    nb0*src1->ne[0]*src1->ne[1],
-                                    nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
-                                    offset);
-                            } break;
-                    }
-
-                    src1->grad = ggml_add_impl(ctx,
+                    struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
+                    struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
+                    src1->grad = ggml_add_or_set(ctx,
                             src1->grad,
                             grad_k,
-                            inplace);
+                            zero_table);
                 }
-
-                struct ggml_tensor * opt0 = tensor->src[2];
-
-                if (opt0->grad) {
-                    struct ggml_tensor * grad_v = NULL;
-                    const size_t nb0    = flash_grad->nb[0];
-                    const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
-                                        + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
-                    switch(opt0->n_dims) {
-                        case 2:
-                            {
-                                grad_v = ggml_view_2d(ctx,
-                                    flash_grad,
-                                    opt0->ne[0],
-                                    opt0->ne[1],
-                                    nb0*opt0->ne[0],
-                                    offset);
-                            } break;
-                        case 3:
-                            {
-                                grad_v = ggml_view_3d(ctx,
-                                    flash_grad,
-                                    opt0->ne[0],
-                                    opt0->ne[1],
-                                    opt0->ne[2],
-                                    nb0*opt0->ne[0],
-                                    nb0*opt0->ne[0]*opt0->ne[1],
-                                    offset);
-                            } break;
-                        case 4:
-                            {
-                                grad_v = ggml_view_4d(ctx,
-                                    flash_grad,
-                                    opt0->ne[0],
-                                    opt0->ne[1],
-                                    opt0->ne[2],
-                                    opt0->ne[3],
-                                    nb0*opt0->ne[0],
-                                    nb0*opt0->ne[0]*opt0->ne[1],
-                                    nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
-                                    offset);
-                            } break;
-                    }
-
-                    opt0->grad = ggml_add_impl(ctx,
-                            opt0->grad,
+                if (src2->grad) {
+                    struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
+                    struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
+                    src2->grad = ggml_add_or_set(ctx,
+                            src2->grad,
                             grad_v,
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_FLASH_FF:
@@ -17126,12 +17821,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         {
                             if (src0->grad) {
                                 src0->grad =
-                                    ggml_add_impl(ctx,
+                                    ggml_add_or_set(ctx,
                                             src0->grad,
                                             ggml_mul(ctx,
                                                 ggml_sgn(ctx, src0),
                                                 tensor->grad),
-                                            inplace);
+                                            zero_table);
                             }
                         } break;
                     case GGML_UNARY_OP_SGN:
@@ -17143,7 +17838,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     case GGML_UNARY_OP_NEG:
                         {
                             if (src0->grad) {
-                                src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
+                                src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
                             }
                         } break;
                     case GGML_UNARY_OP_STEP:
@@ -17163,12 +17858,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     case GGML_UNARY_OP_RELU:
                         {
                             if (src0->grad) {
-                                src0->grad = ggml_add_impl(ctx,
+                                src0->grad = ggml_add_or_set(ctx,
                                         src0->grad,
                                         ggml_mul(ctx,
                                             ggml_step(ctx, src0),
                                             tensor->grad),
-                                        inplace);
+                                        zero_table);
                             }
                         } break;
                     case GGML_UNARY_OP_GELU:
@@ -17183,10 +17878,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         {
                             // necessary for llama
                             if (src0->grad) {
-                                src0->grad = ggml_add_impl(ctx,
+                                src0->grad = ggml_add_or_set(ctx,
                                         src0->grad,
                                         ggml_silu_back(ctx, src0, tensor->grad),
-                                        inplace);
+                                        zero_table);
                             }
                         } break;
                     default:
@@ -17209,13 +17904,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_CROSS_ENTROPY_LOSS:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx,
+                    src0->grad = ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_cross_entropy_loss_back(ctx,
                                     src0,
                                     src1,
                                     tensor->grad),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
@@ -17231,34 +17926,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 GGML_ASSERT(false);
             } break;
     }
-}
-
-static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
-
-static size_t hash(void * p) {
-    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
-}
 
-static bool hash_insert(void * hash_table[], void * p) {
-    size_t h = hash(p);
-
-    // linear probing
-    size_t i = h;
-    while (hash_table[i] != NULL && hash_table[i] != p) {
-        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
-        if (i == h) {
-            // hash table is full
-            GGML_ASSERT(false);
+    for (int i = 0; i < GGML_MAX_SRC; ++i) {
+        if (tensor->src[i] && tensor->src[i]->grad) {
+            GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
         }
     }
-
-    if (hash_table[i] == p) {
-        return true;
-    }
-
-    // insert
-    hash_table[i] = p;
-    return false;
 }
 
 static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
@@ -17276,8 +17949,12 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
     }
 
     for (int i = 0; i < GGML_MAX_SRC; ++i) {
-        if (node->src[i]) {
-            ggml_visit_parents(cgraph, node->src[i]);
+        const int k =
+            (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
+            (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
+            /* unknown order, just fall back to using i*/ i;
+        if (node->src[k]) {
+            ggml_visit_parents(cgraph, node->src[k]);
         }
     }
 
@@ -17336,6 +18013,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
         /*.grads        =*/ { NULL },
         /*.leafs        =*/ { NULL },
         /*.hash_table   =*/ { NULL },
+        /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
@@ -17361,12 +18039,22 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
         }
     }
 
+    // remember original gradients which start with zero values
+    void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE);
+    memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE);
+    for (int i = 0; i < gf->n_nodes; i++) {
+        if (gf->grads[i]) {
+            hash_insert(zero_table, gf->grads[i]);
+        }
+    }
+
     for (int i = gf->n_nodes - 1; i >= 0; i--) {
         struct ggml_tensor * node = gf->nodes[i];
 
-        // because we detached the grad nodes from the original graph, we can afford inplace operations
+        // inplace operations to add gradients are not created by ggml_compute_backward
+        // use allocator to automatically make inplace operations
         if (node->grad) {
-            ggml_compute_backward(ctx, node, keep);
+            ggml_compute_backward(ctx, node, zero_table);
         }
     }
 
@@ -17378,6 +18066,8 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
             ggml_build_forward_expand(gb, node->grad);
         }
     }
+
+    free(zero_table);
 }
 
 struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
@@ -17397,6 +18087,7 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
         /*.grads        =*/ { NULL },
         /*.leafs        =*/ { NULL },
         /*.hash_table   =*/ { NULL },
+        /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
@@ -17787,7 +18478,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_CONCAT:
             case GGML_OP_MUL_MAT:
-            case GGML_OP_OUT_PROD:
                 {
                     n_tasks = n_threads;
 
@@ -17829,6 +18519,18 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         cur = 0;
                     }
 
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_OUT_PROD:
+                {
+                    n_tasks = n_threads;
+
+                    size_t cur = 0;
+
+                    if (ggml_is_quantized(node->src[0]->type)) {
+                        cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
+                    }
+
                     work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_SCALE:
@@ -18969,7 +19671,7 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float *
 }
 
 static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
-    int i = 0;
+    int64_t i = 0;
     for (int p = 0; p < np; ++p) {
         const int64_t ne = ggml_nelements(ps[p]) ;
         // TODO: add function to get all elements at once
@@ -18979,6 +19681,17 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
     }
 }
 
+static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) {
+    int64_t i = 0;
+    for (int p = 0; p < np; ++p) {
+        const int64_t ne = ggml_nelements(ps[p]) ;
+        // TODO: add function to get all elements at once
+        for (int64_t j = 0; j < ne; ++j) {
+            g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale;
+        }
+    }
+}
+
 //
 // ADAM
 //
@@ -19027,26 +19740,40 @@ static enum ggml_opt_result ggml_opt_adam(
     const float eps   = params.adam.eps;
     const float gclip = params.adam.gclip;
     const int decay_min_ndim = params.adam.decay_min_ndim;
+    const int n_accum = MAX(1, params.n_gradient_accumulation);
+    const float accum_norm = 1.0f / (float) n_accum;
 
+    float * g  = opt->adam.g->data;  // gradients
     float * m  = opt->adam.m->data;  // first moment
     float * v  = opt->adam.v->data;  // second moment
 
     float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
 
-    if (callback) {
-        callback(callback_data, &sched);
-    }
-
-    // compute the function value
-    ggml_graph_reset  (gf);
-    ggml_set_f32      (f->grad, 1.0f);
-
     struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
     struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
     cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
-    ggml_graph_compute(gb, &cplan);
 
-    opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
+    bool cancel = false;
+
+    // compute the function value
+    float fx = 0;
+    ggml_set_zero(opt->adam.g);
+    for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+        if (callback) {
+            callback(callback_data, accum_step, &sched, &cancel);
+            if (cancel) {
+                return GGML_OPT_CANCEL;
+            }
+        }
+        // ggml_graph_reset  (gf);
+        ggml_set_f32      (f->grad, 1.0f);
+        ggml_graph_compute(gb, &cplan);
+        ggml_opt_acc_grad(np, ps, g, accum_norm);
+        fx += ggml_get_f32_1d(f, 0);
+    }
+    fx *= accum_norm;
+
+    opt->adam.fx_prev = fx;
     opt->adam.fx_best = opt->adam.fx_prev;
     if (pf) {
         pf[opt->iter % params.past] = opt->adam.fx_prev;
@@ -19091,12 +19818,8 @@ static enum ggml_opt_result ggml_opt_adam(
             if (gclip > 0.0f) {
                 // gradient clipping
                 ggml_float sum = 0.0;
-                for (int p = 0; p < np; ++p) {
-                    const int64_t ne = ggml_nelements(ps[p]);
-                    for (int64_t j = 0; j < ne; ++j) {
-                        float g = ggml_get_f32_1d(ps[p]->grad, j);
-                        sum += (ggml_float)(g*g);
-                    }
+                for (int64_t i = 0; i < nx; ++i) {
+                    sum += (ggml_float)(g[i]*g[i]);
                 }
                 ggml_float norm = sqrt(sum);
                 if (norm > (ggml_float) gclip) {
@@ -19110,10 +19833,10 @@ static enum ggml_opt_result ggml_opt_adam(
                 const int64_t ne = ggml_nelements(ps[p]);
                 const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched;
                 for (int64_t j = 0; j < ne; ++j) {
-                    float x = ggml_get_f32_1d(ps[p], j);
-                    float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
-                    m[i] = m[i]*beta1 +   g*(1.0f - beta1);
-                    v[i] = v[i]*beta2 + g*g*(1.0f - beta2);
+                    float x  = ggml_get_f32_1d(ps[p], j);
+                    float g_ = g[i]*gnorm;
+                    m[i] = m[i]*beta1 +    g_*(1.0f - beta1);
+                    v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2);
                     float mh = m[i]*beta1h;
                     float vh = v[i]*beta2h;
                     vh = sqrtf(vh) + eps;
@@ -19124,16 +19847,23 @@ static enum ggml_opt_result ggml_opt_adam(
             }
         }
 
-        if (callback) {
-            callback(callback_data, &sched);
+        fx = 0;
+        ggml_set_zero(opt->adam.g);
+        for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+            if (callback) {
+                callback(callback_data, accum_step, &sched, &cancel);
+                if (cancel) {
+                    return GGML_OPT_CANCEL;;
+                }
+            }
+            // ggml_graph_reset  (gf);
+            ggml_set_f32      (f->grad, 1.0f);
+            ggml_graph_compute(gb, &cplan);
+            ggml_opt_acc_grad(np, ps, g, accum_norm);
+            fx += ggml_get_f32_1d(f, 0);
         }
+        fx *= accum_norm;
 
-        ggml_graph_reset  (gf);
-        ggml_set_f32      (f->grad, 1.0f);
-
-        ggml_graph_compute(gb, &cplan);
-
-        const float fx = ggml_get_f32_1d(f, 0);
         opt->loss_after = fx;
 
 
@@ -19213,11 +19943,11 @@ static enum ggml_opt_result linesearch_backtracking(
         float * step,
         const float * xp,
         struct ggml_tensor * f,
-        struct ggml_cgraph * gf,
         struct ggml_cgraph * gb,
         struct ggml_cplan  * cplan,
         const int np,
         struct ggml_tensor * ps[],
+        bool * cancel,
         ggml_opt_callback callback,
         void * callback_data) {
     int count = 0;
@@ -19231,6 +19961,9 @@ static enum ggml_opt_result linesearch_backtracking(
     const float dec = 0.5f;
     const float inc = 2.1f;
 
+    const int n_accum = MAX(1, params->n_gradient_accumulation);
+    const float accum_norm = 1.0f / (float) n_accum;
+
     if (*step <= 0.f) {
         return GGML_LINESEARCH_INVALID_PARAMETERS;
     }
@@ -19248,12 +19981,6 @@ static enum ggml_opt_result linesearch_backtracking(
     dgtest = params->lbfgs.ftol*dginit;
 
     while (true) {
-        if (callback) {
-            // LBFG-S does not support learning rate -> ignore learning schedule
-            float sched = 0;
-            callback(callback_data, &sched);
-        }
-
         ggml_vec_cpy_f32(nx, x, xp);
         ggml_vec_mad_f32(nx, x, d, *step);
 
@@ -19261,14 +19988,25 @@ static enum ggml_opt_result linesearch_backtracking(
         {
             ggml_opt_set_params(np, ps, x);
 
-            ggml_graph_reset  (gf);
-            ggml_set_f32      (f->grad, 1.0f);
-
-            ggml_graph_compute(gb, cplan);
-
-            ggml_opt_get_grad(np, ps, g);
+            *fx = 0;
+            memset(g, 0, sizeof(float)*nx);
+            for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+                if (callback) {
+                    // LBFG-S does not support learning rate -> ignore learning schedule
+                    float sched = 0;
+                    callback(callback_data, accum_step, &sched, cancel);
+                    if (*cancel) {
+                        return GGML_OPT_CANCEL;
+                    }
+                }
+                // ggml_graph_reset  (gf);
+                ggml_set_f32      (f->grad, 1.0f);
+                ggml_graph_compute(gb, cplan);
+                ggml_opt_acc_grad(np, ps, g, accum_norm);
+                *fx += ggml_get_f32_1d(f, 0);
+            }
+            *fx *= accum_norm;
 
-            *fx = ggml_get_f32_1d(f, 0);
         }
 
         ++count;
@@ -19314,7 +20052,7 @@ static enum ggml_opt_result linesearch_backtracking(
         (*step) *= width;
     }
 
-    return GGML_LINESEARCH_FAIL;
+    GGML_UNREACHABLE();
 }
 
 static enum ggml_opt_result ggml_opt_lbfgs(
@@ -19369,6 +20107,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
 
     float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
 
+    const int n_accum = MAX(1, params.n_gradient_accumulation);
+    const float accum_norm = 1.0f / (float) n_accum;
+
     float fx    = 0.0f; // cost function value
     float xnorm = 0.0f; // ||x||
     float gnorm = 0.0f; // ||g||
@@ -19382,24 +20123,30 @@ static enum ggml_opt_result ggml_opt_lbfgs(
     float * lm_s     = opt->lbfgs.lms->data;
     float * lm_y     = opt->lbfgs.lmy->data;
 
-    if (callback) {
-        // LBFG-S does not support learning rate -> ignore learning schedule
-        float sched = 0;
-        callback(callback_data, &sched);
-    }
+    bool cancel = false;
 
     // evaluate the function value and its gradient
     {
         ggml_opt_set_params(np, ps, x);
 
-        ggml_graph_reset  (gf);
-        ggml_set_f32      (f->grad, 1.0f);
-
-        ggml_graph_compute(gb, &cplan);
-
-        ggml_opt_get_grad(np, ps, g);
-
-        fx = ggml_get_f32_1d(f, 0);
+        fx = 0;
+        memset(g, 0, sizeof(float)*nx);
+        for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+            if (callback) {
+                // LBFG-S does not support learning rate -> ignore learning schedule
+                float sched = 0;
+                callback(callback_data, accum_step, &sched, &cancel);
+                if (cancel) {
+                    return GGML_OPT_CANCEL;
+                }
+            }
+            // ggml_graph_reset  (gf);
+            ggml_set_f32      (f->grad, 1.0f);
+            ggml_graph_compute(gb, &cplan);
+            ggml_opt_acc_grad(np, ps, g, accum_norm);
+            fx += ggml_get_f32_1d(f, 0);
+        }
+        fx *= accum_norm;
 
         opt->loss_before = fx;
         opt->loss_after  = fx;
@@ -19457,7 +20204,14 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);
 
-        ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gf, gb, &cplan, np, ps, callback, callback_data);
+        // TODO: instead of passing &cancel here, use the return code of the linesearch
+        //       to determine if the optimization should be cancelled
+        //       this is a simple change, but not doing this atm, since I don't have a nice
+        //       way to test and don't want to break something with so many changes lined up
+        ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
+        if (cancel) {
+            return GGML_OPT_CANCEL;
+        }
 
         if (ls < 0) {
             // linesearch failed - go back to the previous point and return
@@ -19566,7 +20320,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         step[0] = 1.0;
     }
 
-    return GGML_OPT_DID_NOT_CONVERGE;
+    GGML_UNREACHABLE();
 }
 
 struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
@@ -19586,6 +20340,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
                     .print_forward_graph  = true,
                     .print_backward_graph = true,
 
+                    .n_gradient_accumulation = 1,
+
                     .adam = {
                         .n_iter = 10000,
                         .sched  = 1.000f,
@@ -19614,6 +20370,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
                     .print_forward_graph  = true,
                     .print_backward_graph = true,
 
+                    .n_gradient_accumulation = 1,
+
                     .lbfgs = {
                         .m              = 6,
                         .n_iter         = 100,
@@ -19644,13 +20402,32 @@ GGML_API void ggml_opt_init(
     opt->iter = 0;
     opt->nx = nx;
     opt->just_initialized = true;
+    if (opt->ctx == NULL) {
+        struct ggml_init_params ctx_opt_params;
+        if (opt->params.type == GGML_OPT_ADAM) {
+            ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3;
+            if (opt->params.past > 0) {
+                ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
+            }
+        } else if (opt->params.type == GGML_OPT_LBFGS) {
+            ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2);
+            if (opt->params.past > 0) {
+                ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
+            }
+        }
+        ctx_opt_params.mem_buffer = NULL;
+        ctx_opt_params.no_alloc   = false;
+
+        opt->ctx = ggml_init(ctx_opt_params);
+    }
     switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
-                opt->adam.m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
-                opt->adam.v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.g  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+                opt->adam.m  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+                opt->adam.v  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
                 opt->adam.pf = params.past > 0
-                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
                     : NULL;
                 ggml_set_zero(opt->adam.m);
                 ggml_set_zero(opt->adam.v);
@@ -19660,18 +20437,18 @@ GGML_API void ggml_opt_init(
             } break;
         case GGML_OPT_LBFGS:
             {
-                opt->lbfgs.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
-                opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
-                opt->lbfgs.g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
-                opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
-                opt->lbfgs.d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.x  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.g  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.d  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
                 opt->lbfgs.pf = params.past > 0
-                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
                     : NULL;
-                opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
-                opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
-                opt->lbfgs.lms  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
-                opt->lbfgs.lmy  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lms  = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                opt->lbfgs.lmy  = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
                 ggml_set_zero(opt->lbfgs.x);
                 ggml_set_zero(opt->lbfgs.xp);
                 ggml_set_zero(opt->lbfgs.g);
@@ -20277,10 +21054,10 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                                 } break;
                             case GGUF_TYPE_ARRAY:
                             case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break;
-                        };
+                        }
                     } break;
                 case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
-            };
+            }
 
             if (!ok) {
                 break;
@@ -20556,78 +21333,94 @@ int gguf_find_key(const struct gguf_context * ctx, const char * key) {
     return keyfound;
 }
 
-const char * gguf_get_key(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].key.data;
+const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
+    return ctx->kv[key_id].key.data;
 }
 
-enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].type;
+enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
+    return ctx->kv[key_id].type;
 }
 
-enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.arr.type;
+enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
+    return ctx->kv[key_id].value.arr.type;
 }
 
-const void * gguf_get_arr_data(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.arr.data;
+const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
+    return ctx->kv[key_id].value.arr.data;
 }
 
 const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
     struct gguf_kv * kv = &ctx->kv[key_id];
     struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
     return str->data;
 }
 
-int gguf_get_arr_n(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.arr.n;
+int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
+    return ctx->kv[key_id].value.arr.n;
 }
 
-uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.uint8;
+uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
+    return ctx->kv[key_id].value.uint8;
 }
 
-int8_t gguf_get_val_i8(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.int8;
+int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
+    return ctx->kv[key_id].value.int8;
 }
 
-uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.uint16;
+uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
+    return ctx->kv[key_id].value.uint16;
 }
 
-int16_t gguf_get_val_i16(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.int16;
+int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
+    return ctx->kv[key_id].value.int16;
 }
 
-uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.uint32;
+uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
+    return ctx->kv[key_id].value.uint32;
 }
 
-int32_t gguf_get_val_i32(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.int32;
+int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
+    return ctx->kv[key_id].value.int32;
 }
 
-float gguf_get_val_f32(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.float32;
+float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
+    return ctx->kv[key_id].value.float32;
 }
 
-uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.uint64;
+uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
+    return ctx->kv[key_id].value.uint64;
 }
 
-int64_t gguf_get_val_i64(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.int64;
+int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
+    return ctx->kv[key_id].value.int64;
 }
 
-double gguf_get_val_f64(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.float64;
+double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
+    return ctx->kv[key_id].value.float64;
 }
 
-bool gguf_get_val_bool(const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.bool_;
+bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
+    return ctx->kv[key_id].value.bool_;
 }
 
-const char * gguf_get_val_str (const struct gguf_context * ctx, int i) {
-    return ctx->kv[i].value.str.data;
+const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
+    return ctx->kv[key_id].value.str.data;
 }
 
 int gguf_get_n_tensors(const struct gguf_context * ctx) {
@@ -20992,10 +21785,10 @@ static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf *
                             } break;
                         case GGUF_TYPE_ARRAY:
                         case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break;
-                    };
+                    }
                 } break;
             case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type");
-        };
+        }
     }
 
     // write tensor infos
index 0977d3ef83770746352820753f8abcc501c81101..cd0cc5ffd2df45bfff578c76ecab2fd6e4f30d26 100644 (file)
@@ -14,7 +14,7 @@
 
 #include <Accelerate/Accelerate.h>
 
-uint64_t get_time_us() {
+uint64_t get_time_us(void) {
     struct timeval tv;
     gettimeofday(&tv, NULL);
     return tv.tv_sec * 1000000 + tv.tv_usec;
index 468cde66adc65e93b5eea422361a979bedf1b682..0a559b27ab370e7af3dac9d111bc60a48b812782 100644 (file)
@@ -107,7 +107,7 @@ static struct ggml_tensor * get_random_tensor_f32(
             break;
         default:
             assert(false);
-    };
+    }
 
     return result;
 }
@@ -155,7 +155,7 @@ static struct ggml_tensor * get_random_tensor_f16(
             break;
         default:
             assert(false);
-    };
+    }
 
     return result;
 }
@@ -203,29 +203,9 @@ static struct ggml_tensor * get_random_tensor_i32(
             break;
         default:
             assert(false);
-    };
-
-    return result;
-}
-
-static void print_elements(const char* label, const struct ggml_tensor * t) {
-    if (!t) {
-        printf("%s: %s = null\n", __func__, label);
-        return;
     }
-    const int nelements = ggml_nelements(t);
-    printf("%s: %s = [", __func__, label);
-    for (int k = 0; k < nelements; ++k) {
-        if (k > 0) { printf(", "); }
-        printf("%.5f", ggml_get_f32_1d(t, k));
-    }
-    printf("] shape: [");
-    for (int k = 0; k < t->n_dims; ++k) {
-        if (k > 0) { printf(", "); }
-        printf("%d", (int)t->ne[k]);
-    }
-    printf("]\n");
 
+    return result;
 }
 
 static bool check_gradient(
@@ -251,18 +231,20 @@ static bool check_gradient(
         printf("GGML_N_THREADS = %d\n", n_threads);
     }
 
-    struct ggml_cgraph gf = ggml_build_forward (f);
-    struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+    struct ggml_cgraph * gf = ggml_build_forward_ctx(ctx0, f);
+    struct ggml_cgraph * gb = ggml_new_graph(ctx0);
+    *gb = *gf;
+    ggml_build_backward_expand(ctx0, gf, gb, false);
 
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
-    ggml_graph_reset  (&gf);
+    ggml_graph_reset  (gf);
     ggml_set_f32      (f->grad, 1.0f);
 
-    ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+    ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
-    // ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
-    // ggml_graph_dump_dot(&gb, &gf,  "test-grad0-backward.dot");
+    // ggml_graph_dump_dot(gf, NULL, "test-grad0-forward.dot");
+    // ggml_graph_dump_dot(gb, gf,  "test-grad0-backward.dot");
 
     for (int i = 0; i < nargs; ++i) {
         const int nelements = ggml_nelements(x[i]);
@@ -273,13 +255,13 @@ static bool check_gradient(
             const float xp = x0 + eps;
             ggml_set_f32_1d(x[i], k, xp);
 
-            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+            ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
             const double f0 = ggml_get_f32_1d(f, 0);
 
             ggml_set_f32_1d(x[i], k, xm);
 
-            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+            ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
             const double f1 = ggml_get_f32_1d(f, 0);
             const double g0 = (f0 - f1)/(2.0*(double) eps);
@@ -287,10 +269,10 @@ static bool check_gradient(
             ggml_set_f32_1d(x[i], k, x0);
 
             // compute gradient using backward graph
-            ggml_graph_reset  (&gf);
+            ggml_graph_reset  (gf);
             ggml_set_f32      (f->grad, 1.0f);
 
-            ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
+            ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
             const double g1 = ggml_get_f32_1d(x[i]->grad, k);
 
@@ -373,7 +355,7 @@ static bool check_mat_mul(
 
 int main(int argc, const char ** argv) {
     struct ggml_init_params params = {
-        /* .mem_size   = */ 128*1024*1024,
+        /* .mem_size   = */ 256*1024*1024,
         /* .mem_buffer = */ NULL,
         /* .no_alloc   = */ false,
     };
@@ -405,6 +387,7 @@ int main(int argc, const char ** argv) {
         }
     }
 
+    unsigned seed_iter = 1;
 
     // original loop: 1000
     int niter = 4;
@@ -416,6 +399,10 @@ int main(int argc, const char ** argv) {
         niter = atoi(argv[1]);
     }
     for (int iter = 0; iter < niter; ++iter) {
+        srand(seed_iter);
+        seed_iter = rand();
+        unsigned seed = rand();
+
         printf("test-grad0: iter:%d/%d\n", iter, niter);
         struct ggml_context * ctx0 = ggml_init(params);
 
@@ -425,6 +412,7 @@ int main(int argc, const char ** argv) {
 
         // add f32
         {
+            srand(seed);
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -441,6 +429,7 @@ int main(int argc, const char ** argv) {
 
         // add f16
         {
+            srand(seed);
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -457,6 +446,7 @@ int main(int argc, const char ** argv) {
 
         // sub
         {
+            srand(seed);
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -473,6 +463,7 @@ int main(int argc, const char ** argv) {
 
         // mul
         {
+            srand(seed);
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -489,6 +480,7 @@ int main(int argc, const char ** argv) {
 
         // div
         {
+            srand(seed);
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -505,6 +497,7 @@ int main(int argc, const char ** argv) {
 
         // sqr
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -521,6 +514,7 @@ int main(int argc, const char ** argv) {
 
         // sqrt
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -537,6 +531,7 @@ int main(int argc, const char ** argv) {
 
         // log
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -553,6 +548,7 @@ int main(int argc, const char ** argv) {
 
         // sum
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -570,6 +566,7 @@ int main(int argc, const char ** argv) {
 
         // sum_rows
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -587,6 +584,7 @@ int main(int argc, const char ** argv) {
         // mean, not yet fully implemented
         if(0)
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -604,6 +602,7 @@ int main(int argc, const char ** argv) {
         // argmax
         if (0)
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -620,6 +619,7 @@ int main(int argc, const char ** argv) {
 
         // repeat
         {
+            srand(seed);
             int64_t ne2[4];
             get_random_dims(ne2, 4);
 
@@ -642,6 +642,7 @@ int main(int argc, const char ** argv) {
 
         // repeat back
         {
+            srand(seed);
             int64_t ne2[4];
             get_random_dims(ne2, 4);
 
@@ -680,6 +681,7 @@ int main(int argc, const char ** argv) {
 
         // sgn
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -696,6 +698,7 @@ int main(int argc, const char ** argv) {
 
         // neg
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -712,6 +715,7 @@ int main(int argc, const char ** argv) {
 
         // step
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -729,6 +733,7 @@ int main(int argc, const char ** argv) {
         // tanh, not yet fully implemented
         if(0)
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -745,33 +750,45 @@ int main(int argc, const char ** argv) {
 
         // mul_mat
         {
+            srand(seed);
             const int nargs = 2;
 
-            for (int ndims = 2; ndims <= 2; ++ndims) {
+            for (int ndims = 2; ndims <= 4; ++ndims) {
+                int max_nrep = (ndims >= 3) ? 2 : 1;
                 x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                {
-                    int64_t ne2[4];
-                    get_random_dims(ne2, 4);
-                    ne2[0] = ne[0];
-                    x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                }
+                for (int nrep2 = 1; nrep2 < max_nrep; ++nrep2) {
+                    for (int nrep3 = 1; nrep3 < max_nrep; ++nrep3) {
+                        {
+                            int64_t ne2[4];
+                            get_random_dims(ne2, 4);
+                            ne2[0] = ne[0];
+                            ne2[2] = nrep2 * ne[2];
+                            ne2[3] = nrep3 * ne[3];
+                            x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+                        }
 
-                ggml_set_param(ctx0, x[0]);
-                ggml_set_param(ctx0, x[1]);
+                        ggml_set_param(ctx0, x[0]);
+                        ggml_set_param(ctx0, x[1]);
 
-                struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
-                struct ggml_tensor * f = ggml_sum(ctx0, m);
+                        struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
+                        struct ggml_tensor * f = ggml_sum(ctx0, m);
 
-                GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
+                        GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
 
-                check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
-                check_mat_mul(m, x[1], x[0]);
+                        check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                        if (ndims == 2) {
+                            // check_mat_mul does not support ndims > 2
+                            check_mat_mul(m, x[1], x[0]);
+                        }
+                    }
+                }
             }
         }
 
         // elu, not yet fully implemented
         if(0)
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -788,6 +805,7 @@ int main(int argc, const char ** argv) {
 
         // relu
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -805,6 +823,7 @@ int main(int argc, const char ** argv) {
         // gelu, not yet fully implemented
         if(0)
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
@@ -821,6 +840,7 @@ int main(int argc, const char ** argv) {
 
         // silu
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -842,6 +862,7 @@ int main(int argc, const char ** argv) {
 
         // rms_norm
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -858,6 +879,7 @@ int main(int argc, const char ** argv) {
 
         // scale
         {
+            srand(seed);
             const int nargs = 2;
 
             int64_t ne2[4];
@@ -878,6 +900,7 @@ int main(int argc, const char ** argv) {
 
         // cpy f32
         {
+            srand(seed);
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -895,6 +918,7 @@ int main(int argc, const char ** argv) {
 
         // cpy f16
         {
+            srand(seed);
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -912,6 +936,7 @@ int main(int argc, const char ** argv) {
 
         // reshape (1d->nd)
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -935,6 +960,7 @@ int main(int argc, const char ** argv) {
 
         // reshape (nd->1d)
         {
+            srand(seed);
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
@@ -958,6 +984,7 @@ int main(int argc, const char ** argv) {
 
         // acc 1d
         {
+            srand(seed);
             int64_t ne2[4] = { 1, 1, 1, 1 };
 
             const int nargs = 2;
@@ -985,6 +1012,7 @@ int main(int argc, const char ** argv) {
 
         // acc 2d
         {
+            srand(seed);
             int64_t ne2[4]         = { 1, 1, 1, 1 };
             int64_t max_offsets[4] = { 0, 0, 0, 0 };
             int64_t offsets[4]     = { 0, 0, 0, 0 };
@@ -1017,6 +1045,7 @@ int main(int argc, const char ** argv) {
 
         // acc 3d
         {
+            srand(seed);
             int64_t ne2[4]         = { 1, 1, 1, 1 };
             int64_t max_offsets[4] = { 0, 0, 0, 0 };
             int64_t offsets[4]     = { 0, 0, 0, 0 };
@@ -1051,6 +1080,7 @@ int main(int argc, const char ** argv) {
 
         // acc 4d
         {
+            srand(seed);
             int64_t ne2[4]         = { 1, 1, 1, 1 };
             int64_t max_offsets[4] = { 0, 0, 0, 0 };
             int64_t offsets[4]     = { 0, 0, 0, 0 };
@@ -1087,6 +1117,7 @@ int main(int argc, const char ** argv) {
 
         // set_1d
         {
+            srand(seed);
             int64_t ne2[4];
 
             const int nargs = 2;
@@ -1114,6 +1145,7 @@ int main(int argc, const char ** argv) {
 
         // set_2d
         {
+            srand(seed);
             int64_t ne2[4];
             int64_t max_offsets[4] = { 0, 0, 0, 0 };
             int64_t offsets[4]     = { 0, 0, 0, 0 };
@@ -1146,6 +1178,7 @@ int main(int argc, const char ** argv) {
 
         // view_1d
         {
+            srand(seed);
             const int nargs = 1;
             for (int ndims = 1; ndims <= 4; ++ndims) {
 
@@ -1169,6 +1202,7 @@ int main(int argc, const char ** argv) {
 
         // view_2d
         {
+            srand(seed);
             int64_t ne2[4];
             int64_t nb2[4];
 
@@ -1199,6 +1233,7 @@ int main(int argc, const char ** argv) {
 
         // view_3d
         {
+            srand(seed);
             int64_t ne2[4] = {1,1,1,1};
             int64_t nb2[4] = {0,0,0,0};
 
@@ -1230,6 +1265,7 @@ int main(int argc, const char ** argv) {
 
         // permute
         {
+            srand(seed);
             int64_t ne2[4];
 
             const int nargs = 1;
@@ -1263,6 +1299,7 @@ int main(int argc, const char ** argv) {
 
         // transpose
         {
+            srand(seed);
             int64_t ne2[4];
 
             const int nargs = 1;
@@ -1290,6 +1327,7 @@ int main(int argc, const char ** argv) {
 
         // get_rows
         {
+            srand(seed);
             int64_t ne2[4] = {ne[0], ne[1], 1, 1};
             int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1};
             const int nargs = 1;
@@ -1306,6 +1344,7 @@ int main(int argc, const char ** argv) {
 
         // diag_mask_inf
         {
+            srand(seed);
             const int nargs = 1;
             const int ndims = 2;
 
@@ -1321,6 +1360,7 @@ int main(int argc, const char ** argv) {
 
         // diag_mask_zero
         {
+            srand(seed);
             const int nargs = 1;
             const int ndims = 2;
 
@@ -1336,6 +1376,7 @@ int main(int argc, const char ** argv) {
 
         // softmax
         {
+            srand(seed);
             const int nargs = 1;
 
             int64_t ne2[4];
@@ -1357,11 +1398,16 @@ int main(int argc, const char ** argv) {
                                                     ggml_new_f32(ctx0, eps))));
 
                 check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY);
+                // NOTE: softmax forward is computed using f16 table lookup instead of using actual expf, but backward assumes actual expf.
+                // this may result in different gradients too finite differences.
+                // when this test reports errors, first try to replace the table lookup with actual expf and test again to see if just that was the cause.
+                // if only the table lookup causes gradients to differ this is acceptable.
             }
         }
 
         // cross_entropy_loss
         {
+            srand(seed);
             const int nargs = 1;
 
             int64_t ne2[4];
@@ -1392,6 +1438,7 @@ int main(int argc, const char ** argv) {
 
         // rope f32
         {
+            srand(seed);
             const int nargs = 1;
 
             int64_t ne2[4];
@@ -1404,6 +1451,11 @@ int main(int argc, const char ** argv) {
                     for (int n_past = 1; n_past < ne2[2]; ++n_past) {
                         x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
 
+                        struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]);
+                        for (int i = 0; i < ne2[2]; ++i) {
+                            ((int32_t *) p->data)[i] = n_past + i;
+                        }
+
                         ggml_set_param(ctx0, x[0]);
 
                         const bool skip_past = (mode & 1);
@@ -1415,7 +1467,7 @@ int main(int argc, const char ** argv) {
                             continue;
                         }
 
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
+                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0));
 
                         GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
                         check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
@@ -1426,6 +1478,7 @@ int main(int argc, const char ** argv) {
 
         // rope f16
         {
+            srand(seed);
             const int nargs = 1;
 
             int64_t ne2[4];
@@ -1438,6 +1491,11 @@ int main(int argc, const char ** argv) {
                     for (int n_past = 1; n_past < ne2[2]; ++n_past) {
                         x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f);
 
+                        struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]);
+                        for (int i = 0; i < ne2[2]; ++i) {
+                            ((int32_t *) p->data)[i] = n_past + i;
+                        }
+
                         ggml_set_param(ctx0, x[0]);
 
                         const bool skip_past = (mode & 1);
@@ -1449,7 +1507,7 @@ int main(int argc, const char ** argv) {
                             continue;
                         }
 
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
+                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0));
 
                         GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
                         check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
@@ -1460,6 +1518,7 @@ int main(int argc, const char ** argv) {
 
         // flash_attn f32
         {
+            srand(seed);
             const int nargs = 3;
 
             int64_t ne2[4];
@@ -1472,28 +1531,31 @@ int main(int argc, const char ** argv) {
 
             for (int masked = 0; masked <= 1; ++masked) {
                 for (int ndims = 2; ndims <= 4; ++ndims) {
-                    int64_t neq[4] = { D, N, B, ne[3] };
-                    int64_t nek[4] = { D, M, B, ne[3] };
-                    int64_t nev[4] = { M, D, B, ne[3] };
-                    if (ndims == 2) {
-                        neq[2] = 1; neq[3] = 1;
-                        nek[2] = 1; nek[3] = 1;
-                        nev[2] = 1; nev[3] = 1;
-                    } else if (ndims == 3) {
-                        neq[3] = 1;
-                        nek[3] = 1;
-                        nev[3] = 1;
-                    }
-                    x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
-                    x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
-                    x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
-                    ggml_set_param(ctx0, x[0]);
-                    ggml_set_param(ctx0, x[1]);
-                    ggml_set_param(ctx0, x[2]);
+                    int max_nrep = (ndims >= 3) ? 2 : 1;
+                    for (int nrep = 1; nrep < max_nrep; ++nrep) {
+                        int64_t neq[4] = { D, N, B*nrep, ne[3] };
+                        int64_t nek[4] = { D, M, B, ne[3] };
+                        int64_t nev[4] = { M, D, B, ne[3] };
+                        if (ndims == 2) {
+                            neq[2] = 1; neq[3] = 1;
+                            nek[2] = 1; nek[3] = 1;
+                            nev[2] = 1; nev[3] = 1;
+                        } else if (ndims == 3) {
+                            neq[3] = 1;
+                            nek[3] = 1;
+                            nev[3] = 1;
+                        }
+                        x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
+                        x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
+                        x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
+                        ggml_set_param(ctx0, x[0]);
+                        ggml_set_param(ctx0, x[1]);
+                        ggml_set_param(ctx0, x[2]);
 
-                    struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
 
-                    check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
+                        check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
+                    }
                 }
             }
         }
@@ -1501,6 +1563,7 @@ int main(int argc, const char ** argv) {
         // flash_attn f16, not yet fully implemented
         if(0)
         {
+            srand(seed);
             const int nargs = 3;
 
             int64_t ne2[4];
index fc240778cfe6cb8b8903b1ecbdef6200f3d67146..b725a5872c7f4c7910961a60c8fa58e1506bda65 100644 (file)
@@ -16,7 +16,7 @@ const int M = 1280;
 const int N = 1536;
 const int K = 1280;
 
-uint64_t get_time_us() {
+uint64_t get_time_us(void) {
     struct timeval tv;
     gettimeofday(&tv, NULL);
     return tv.tv_sec * 1000000 + tv.tv_usec;
index 8ab240202a5850da56f10ee468ea5b0b081a0a22..bb8af59620b146fc9c2e93d048cdb6a162c6937f 100644 (file)
 #define GGML_PRINT(...) printf(__VA_ARGS__)
 
 
-float frand(void) {
+static float frand(void) {
     return (float)rand()/(float)RAND_MAX;
 }
 
-int irand(int n) {
-    return rand()%n;
-}
-
-void get_random_dims(int64_t * dims, int ndims) {
-    dims[0] = dims[1] = dims[2] = dims[3] = 1;
-
-    for (int i = 0; i < ndims; i++) {
-        dims[i] = 1 + irand(4);
-    }
-}
-
-void get_random_dims_minmax(int64_t * dims, int ndims, int min, int max) {
-    dims[0] = dims[1] = dims[2] = dims[3] = 1;
-
-    for (int i = 0; i < ndims; i++) {
-        dims[i] = min + irand(max-min);
-    }
-}
-
-
-struct ggml_tensor * get_random_tensor(
-        struct ggml_context * ctx0,
-        int ndims,
-        int64_t ne[],
-        float fmin,
-        float fmax) {
+static struct ggml_tensor * get_random_tensor(
+    struct ggml_context * ctx0, int ndims, int64_t ne[], float fmin, float fmax
+) {
     struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
 
     switch (ndims) {
@@ -104,19 +80,11 @@ struct ggml_tensor * get_random_tensor(
             break;
         default:
             assert(false);
-    };
+    }
 
     return result;
 }
 
-float get_element(const struct ggml_tensor * t, int idx) {
-    return ((float *)t->data)[idx];
-}
-
-void set_element(struct ggml_tensor * t, int idx, float value) {
-    ((float *)t->data)[idx] = value;
-}
-
 int main(void) {
     struct ggml_init_params params = {
         /* .mem_size   = */ 1024*1024*1024,
@@ -127,7 +95,7 @@ int main(void) {
     struct ggml_context * ctx = ggml_init(params);
 
     int64_t ne1[4] = {4, 128, 1, 1};
-    int64_t ne2[4] = {4, 256, 1, 1};;
+    int64_t ne2[4] = {4, 256, 1, 1};
     int64_t ne3[4] = {128, 256, 1, 1};
 
     struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1);
index 8d3c162d2bfa04bdb29a40131368dd9efda7523d..884af40548fb7912cd2e80c3c7e503bba938c06b 100644 (file)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-const float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f;
-const float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f;
-const float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
-const float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
-const float MAX_DOT_PRODUCT_ERROR = 0.02f;
+constexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f;
+constexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f;
+constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
+constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
+constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
 
-const char* RESULT_STR[] = {"ok", "FAILED"};
+static const char* RESULT_STR[] = {"ok", "FAILED"};
 
 
 // Generate synthetic data
-void generate_data(float offset, size_t n, float * dst) {
+static void generate_data(float offset, size_t n, float * dst) {
     for (size_t i = 0; i < n; i++) {
         dst[i] = 0.1 + 2*cosf(i + offset);
     }
 }
 
 // Calculate RMSE between two float arrays
-float array_rmse(const float * a1, const float * a2, size_t n) {
+static float array_rmse(const float * a1, const float * a2, size_t n) {
     double sum = 0;
     for (size_t i = 0; i < n; i++) {
         double diff = a1[i] - a2[i];
@@ -40,7 +40,7 @@ float array_rmse(const float * a1, const float * a2, size_t n) {
 }
 
 // Total quantization error on test data
-float total_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
+static float total_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
     std::vector<uint8_t> tmp_q(2*test_size);
     std::vector<float> tmp_out(test_size);
 
@@ -50,7 +50,7 @@ float total_quantization_error(ggml_type_traits_t & qfns, size_t test_size, cons
 }
 
 // Total quantization error on test data
-float reference_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
+static float reference_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
     std::vector<uint8_t> tmp_q(2*test_size);
     std::vector<float> tmp_out(test_size);
     std::vector<float> tmp_out_ref(test_size);
@@ -64,7 +64,7 @@ float reference_quantization_error(ggml_type_traits_t & qfns, size_t test_size,
     return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
 }
 
-float dot_product(const float * a1, const float * a2, size_t test_size) {
+static float dot_product(const float * a1, const float * a2, size_t test_size) {
     double sum = 0;
     for (size_t i = 0; i < test_size; i++) {
         sum += a1[i] * a2[i];
@@ -73,7 +73,9 @@ float dot_product(const float * a1, const float * a2, size_t test_size) {
 }
 
 // Total dot product error
-float dot_product_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data1, const float *test_data2) {
+static float dot_product_error(
+    ggml_type_traits_t & qfns, size_t test_size, const float * test_data1, const float *test_data2
+) {
     std::vector<uint8_t> tmp_q1(2*test_size);
     std::vector<uint8_t> tmp_q2(2*test_size);
 
index cbea7d4525ca4c7479c48d73fc936aa030a2b537..88fac0e23106bc4f11c4899da9c12479a29c352a 100644 (file)
@@ -61,37 +61,36 @@ inline int64_t cpu_cycles() {
 
 
 // Generate synthetic data
-void generate_data(float offset, size_t n, float * dst) {
+static void generate_data(float offset, size_t n, float * dst) {
     for (size_t i = 0; i < n; i++) {
         dst[i] = 0.1 + 2*cosf(i + offset);
     }
 }
 
-float gigabytes_per_second(size_t bytes, int64_t usecs) {
+static float gigabytes_per_second(size_t bytes, int64_t usecs) {
     return bytes / (float) usecs * 1000000 / (1024*1024*1024);
 }
 
-void * align_with_offset(void * ptr, int offset) {
+static void * align_with_offset(void * ptr, int offset) {
     size_t dummy_size = MAX_ALIGNMENT * 4;
     return (char *) std::align(MAX_ALIGNMENT, MAX_ALIGNMENT, ptr, dummy_size) + offset;
 }
 
-void benchmark_function(size_t size, size_t q_size, int64_t iterations, const std::function<size_t(void)> & function) {
+static void benchmark_function(size_t size, size_t q_size, int64_t iterations, const std::function<float(void)> & func) {
     int64_t min_time_us = INT64_MAX;
     int64_t total_time_us = 0;
     int64_t min_time_cycles = INT64_MAX;
     int64_t total_time_cycles = 0;
 
     for (int i = 0; i < WARMUP; i++) {
-        function();
+        func();
     }
 
-
     for (int i = 0; i < iterations; i++) {
         const int64_t start_time = ggml_time_us();
         const int64_t start_cycles = cpu_cycles();
 
-        function();
+        func();
 
         const int64_t end_cycles = cpu_cycles();
         const int64_t end_time = ggml_time_us();
@@ -108,7 +107,7 @@ void benchmark_function(size_t size, size_t q_size, int64_t iterations, const st
     printf("      quantized throughput : %9.2f GB/s\n",  gigabytes_per_second(q_size * iterations, total_time_us));
 }
 
-void usage(char * argv[]) {
+static void usage(char * argv[]) {
     printf("Benchmark quantization specific functions on synthetic data\n");
     printf("\n");
     printf("usage: %s [options]\n", argv[0]);
@@ -245,15 +244,15 @@ int main(int argc, char * argv[]) {
 
     std::vector<uint8_t> test_data1_v(largest*4 + MAX_ALIGNMENT*2);
     std::vector<uint8_t> test_data2_v(largest*4 + MAX_ALIGNMENT*2);
-    std::vector<uint8_t> test_q1_v(largest*4 + MAX_ALIGNMENT*2);
-    std::vector<uint8_t> test_q2_v(largest*4 + MAX_ALIGNMENT*2);
-    std::vector<uint8_t> test_out_v(largest*4 + MAX_ALIGNMENT*2);
+    std::vector<uint8_t> test_q1_v   (largest*4 + MAX_ALIGNMENT*2);
+    std::vector<uint8_t> test_q2_v   (largest*4 + MAX_ALIGNMENT*2);
+    std::vector<uint8_t> test_out_v  (largest*4 + MAX_ALIGNMENT*2);
 
     float * test_data1 = (float *) align_with_offset(test_data1_v.data(), params.alignment_offset);
     float * test_data2 = (float *) align_with_offset(test_data2_v.data(), params.alignment_offset);
-    float * test_q1 = (float *) align_with_offset(test_q1_v.data(), params.alignment_offset);
-    float * test_q2 = (float *) align_with_offset(test_q2_v.data(), params.alignment_offset);
-    float * test_out = (float *) align_with_offset(test_out_v.data(), params.alignment_offset);
+    float * test_q1    = (float *) align_with_offset(test_q1_v.data(),    params.alignment_offset);
+    float * test_q2    = (float *) align_with_offset(test_q2_v.data(),    params.alignment_offset);
+    float * test_out   = (float *) align_with_offset(test_out_v.data(),   params.alignment_offset);
 
     generate_data(0, largest, test_data1);
     generate_data(1, largest, test_data2);
@@ -283,7 +282,7 @@ int main(int argc, char * argv[]) {
                 printf("  quantize_row_q_reference\n");
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
-                    auto quantize_fn = [&](void ) {
+                    auto quantize_fn = [&](void) -> float {
                         qfns.from_float_reference(test_data1, test_q1, size);
                         return test_q1[0];
                     };
@@ -297,7 +296,7 @@ int main(int argc, char * argv[]) {
                 printf("  quantize_row_q\n");
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
-                    auto quantize_fn = [&](void ) {
+                    auto quantize_fn = [&](void) -> float {
                         qfns.from_float(test_data1, test_q1, size);
                         return test_q1[0];
                     };
@@ -312,7 +311,7 @@ int main(int argc, char * argv[]) {
                 qfns.from_float(test_data1, test_q1, largest);
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
-                    auto quantize_fn = [&](void ) {
+                    auto quantize_fn = [&](void) -> float {
                         qfns.to_float(test_q1, test_out, size);
                         return test_out[0];
                     };
@@ -326,7 +325,7 @@ int main(int argc, char * argv[]) {
                 printf("  quantize_row_q_dot\n");
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
-                    auto quantize_fn = [&](void ) {
+                    auto quantize_fn = [&](void) -> float {
                         auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type);
                         vdot.from_float(test_data1, test_q1, size);
                         return test_q1[0];
@@ -343,7 +342,7 @@ int main(int argc, char * argv[]) {
                 qfns.from_float(test_data2, test_q2, largest);
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
-                    auto quantize_fn = [&](void ) {
+                    auto quantize_fn = [&](void) -> float {
                         float result;
                         qfns.vec_dot(size, &result, test_q1, test_q2);
                         return result;
index 2295c9db9fba0f79803d16c164c022f0d88b53c3..8160bd3b9c7780197ba2c37177d0ae738d84a264 100644 (file)
@@ -15,7 +15,7 @@
 #include <Accelerate/Accelerate.h>
 #endif
 
-float frand() {
+float frand(void) {
     return (float) rand() / (float) RAND_MAX;
 }
 
index 465cf53dfe7d478e71329d362b92959885ec1803..4fa364ca5a5d322783c871b58dab8d74b10258fc 100644 (file)
@@ -170,7 +170,7 @@ void mul_mat_vec_f16_1(
     }
 }
 
-uint64_t get_time_us() {
+uint64_t get_time_us(void) {
     struct timeval tv;
     gettimeofday(&tv, NULL);
     return tv.tv_sec * 1000000 + tv.tv_usec;
index a8c64e556a325605344d577963a2b243fb3acd80..9db47d9bcc83204f2361470818d8a650f70b8e2f 100644 (file)
@@ -18,7 +18,7 @@ int main(int argc, char ** argv) {
         .mem_size   = 16*1024*1024,
         .mem_buffer = NULL,
     };
-   
+
     // memory allocation happens here
     struct ggml_context * ctx = ggml_init(params);
 
@@ -30,22 +30,28 @@ int main(int argc, char ** argv) {
         ((float*) K->data)[i] = 2.0f;
     }
 
-    struct ggml_tensor * Qx = ggml_rope_xpos_inplace(ctx, Q, 1, n_embd_head, 512.0f, false);
-    struct ggml_tensor * Kx = ggml_rope_xpos_inplace(ctx, K, 1, n_embd_head, 512.0f, true);
-   
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
+    int * data = (int *) KQ_pos->data;
+    for (int i = 0; i < N; ++i) {
+        data[i] = 1 + i;
+    }
+
+    struct ggml_tensor * Qx = ggml_rope_xpos_inplace(ctx, Q, KQ_pos, n_embd_head, 512.0f, false);
+    struct ggml_tensor * Kx = ggml_rope_xpos_inplace(ctx, K, KQ_pos, n_embd_head, 512.0f, true);
+
     struct ggml_cgraph gf = ggml_build_forward(Qx);
     ggml_build_forward_expand(&gf, Kx);
     ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
 
        // expected output for Qx:
-    // -0.6009  2.7568  1.9782  2.0182 
-    // -2.6379  0.9815  1.9562  2.0361 
-    // -2.2457 -1.6853  1.9341  2.0538 
-    //  0.2043 -2.7934  1.9118  2.0712 
-    //  2.4550 -1.3341  1.8894  2.0884 
-    //  2.4430  1.3417  1.8668  2.1054 
-    //  0.1905  2.7739  1.8440  2.1221 
-    // -2.2257  1.6550  1.8212  2.1386 
+    // -0.6009  2.7568  1.9782  2.0182
+    // -2.6379  0.9815  1.9562  2.0361
+    // -2.2457 -1.6853  1.9341  2.0538
+    //  0.2043 -2.7934  1.9118  2.0712
+    //  2.4550 -1.3341  1.8894  2.0884
+    //  2.4430  1.3417  1.8668  2.1054
+    //  0.1905  2.7739  1.8440  2.1221
+    // -2.2257  1.6550  1.8212  2.1386
 
     for (int i = 0; i < ggml_nelements(Q); i++) {
         if (((float*) Qx->data)[i] > 0) printf(" ");
@@ -60,15 +66,15 @@ int main(int argc, char ** argv) {
     GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 3],  2.1386f, 0.0001f));
 
     // expected output for Kx:
-       // -0.6038  2.7703  1.9816  2.0216 
-    // -2.6639  0.9911  1.9630  2.0431 
-    // -2.2789 -1.7103  1.9441  2.0644 
-    //  0.2083 -2.8486  1.9251  2.0856 
-    //  2.5158 -1.3671  1.9057  2.1065 
-    //  2.5158  1.3816  1.8862  2.1273 
-    //  0.1972  2.8705  1.8665  2.1479 
-    // -2.3146  1.7211  1.8465  2.1684 
-    
+       // -0.6038  2.7703  1.9816  2.0216
+    // -2.6639  0.9911  1.9630  2.0431
+    // -2.2789 -1.7103  1.9441  2.0644
+    //  0.2083 -2.8486  1.9251  2.0856
+    //  2.5158 -1.3671  1.9057  2.1065
+    //  2.5158  1.3816  1.8862  2.1273
+    //  0.1972  2.8705  1.8665  2.1479
+    // -2.3146  1.7211  1.8465  2.1684
+
     for (int i = 0; i < ggml_nelements(K); i++) {
         if (((float*) Kx->data)[i] > 0) printf(" ");
         printf("%.4f ", ((float*) Kx->data)[i]);