]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
sync : ggml (new ops, new backend, etc) (#1602)
authorGeorgi Gerganov <redacted>
Thu, 7 Dec 2023 20:27:19 +0000 (22:27 +0200)
committerGitHub <redacted>
Thu, 7 Dec 2023 20:27:19 +0000 (22:27 +0200)
* sync : ggml (new ops, new backend, etc)

* whisper : remove obsolete broadcasting code

* ggml : remove backend self-registers + fix ggml_concat + n_task logic

* metal : fix assert

* metal : print resource path

* whisper : fix bug if metal init fails

16 files changed:
ggml-alloc.c
ggml-alloc.h
ggml-backend-impl.h
ggml-backend.c
ggml-backend.h
ggml-cuda.cu
ggml-cuda.h
ggml-impl.h
ggml-metal.h
ggml-metal.m
ggml-metal.metal
ggml-opencl.cpp
ggml-quants.c
ggml.c
ggml.h
whisper.cpp

index cdfe4caf69613d6778f8a80af9b0d6ffc8e67652..d3049efb497a0a09a2c0b2b94f45629f8531e0a7 100644 (file)
@@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
 
 #ifdef GGML_ALLOCATOR_DEBUG
     add_allocated_tensor(alloc, tensor);
-    size_t cur_max = (char*)addr - (char*)alloc->data + size;
+    size_t cur_max = (char*)addr - (char*)alloc->base + size;
     if (cur_max > alloc->max_size) {
         printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
         for (int i = 0; i < 1024; i++) {
@@ -168,10 +168,6 @@ static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor *
     size = aligned_offset(NULL, size, alloc->alignment);
     AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
 
-    if (!alloc->measure) {
-        ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
-    }
-
 #ifdef GGML_ALLOCATOR_DEBUG
     remove_allocated_tensor(alloc, tensor);
 #endif
@@ -237,7 +233,7 @@ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
 }
 
 ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
-    struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
+    struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(data, size);
 
     ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
 
@@ -449,7 +445,6 @@ static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * n
 static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
     ggml_tallocr_t alloc = node_tallocr(galloc, view);
 
-    //printf("init_view: %s from src %s\n", view->name, view->view_src->name);
     GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
     if (update_backend) {
         view->backend = view->view_src->backend;
@@ -459,7 +454,7 @@ static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool upd
 
     // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
     // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
-    assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
+    assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
 
     if (!alloc->measure) {
         ggml_backend_buffer_init_tensor(alloc->buffer, view);
@@ -765,3 +760,43 @@ size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
 size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
     return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
 }
+
+// utils
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
+
+    size_t alignment = ggml_backend_buft_get_alignment(buft);
+
+    size_t nbytes = 0;
+    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        if (t->data == NULL && t->view_src == NULL) {
+            nbytes += GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
+        }
+    }
+
+    if (nbytes == 0) {
+        fprintf(stderr, "%s: no tensors to allocate\n", __func__);
+        return NULL;
+    }
+
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
+    ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
+
+    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        if (t->data == NULL) {
+            if (t->view_src == NULL) {
+                ggml_tallocr_alloc(tallocr, t);
+            } else {
+                ggml_backend_view_init(buffer, t);
+            }
+        }
+    }
+
+    ggml_tallocr_free(tallocr);
+
+    return buffer;
+}
+
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
+    return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
+}
index dde2a06bf803098a85d60c15a47953753e4a0bd8..ad87cebc8873f4584ad99e2c67e110a603f4a71c 100644 (file)
@@ -8,6 +8,7 @@ extern "C" {
 
 struct ggml_backend;
 struct ggml_backend_buffer;
+struct ggml_backend_buffer_type;
 
 //
 // Legacy API
@@ -80,6 +81,12 @@ GGML_API void   ggml_gallocr_alloc_graph_n(
                     struct ggml_hash_set hash_set,
                     ggml_tallocr_t * hash_node_talloc);
 
+
+// Utils
+// Create a buffer and allocate all the tensors in a ggml_context
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft);
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend);
+
 #ifdef  __cplusplus
 }
 #endif
index 211e3d4247387b2b598caca2e79d6863ffdf4c35..f588af60282650fa65296ebd462b7affbfae8e2e 100644 (file)
@@ -12,31 +12,50 @@ extern "C" {
     // Backend buffer
     //
 
+    // buffer type
+    typedef void * ggml_backend_buffer_type_context_t;
+
+    struct ggml_backend_buffer_type_i {
+        ggml_backend_buffer_t (*alloc_buffer)    (ggml_backend_buffer_type_t buft, size_t size);
+        size_t                (*get_alignment)   (ggml_backend_buffer_type_t buft); // tensor alignment
+        size_t                (*get_alloc_size)  (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
+        bool                  (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
+    };
+
+    struct ggml_backend_buffer_type {
+        struct ggml_backend_buffer_type_i  iface;
+        ggml_backend_buffer_type_context_t context;
+    };
+
+    // buffer
     typedef void * ggml_backend_buffer_context_t;
 
     struct ggml_backend_buffer_i {
-        void   (*free_buffer)   (ggml_backend_buffer_t buffer);
-        void * (*get_base)      (ggml_backend_buffer_t buffer); // get base pointer
-        size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
-        void   (*init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
-        void   (*free_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
+        void     (*free_buffer)(ggml_backend_buffer_t buffer);
+        //void     (*reset)      (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+        void *   (*get_base)   (ggml_backend_buffer_t buffer);
+        void     (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+        void     (*set_tensor) (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void     (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+        // (optional) copy tensor between different buffer-type, allow for single-copy tranfers
+        void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
+        void (*cpy_tensor_to)  (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
     };
 
     struct ggml_backend_buffer {
-        struct ggml_backend_buffer_i iface;
-
-        ggml_backend_t                backend;
+        struct ggml_backend_buffer_i  iface;
+        ggml_backend_buffer_type_t    buft;
         ggml_backend_buffer_context_t context;
-
         size_t size;
     };
 
-    GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
-            struct ggml_backend                  * backend,
+    ggml_backend_buffer_t ggml_backend_buffer_init(
+                   ggml_backend_buffer_type_t      buft,
             struct ggml_backend_buffer_i           iface,
                    ggml_backend_buffer_context_t   context,
                    size_t                          size);
 
+
     //
     // Backend
     //
@@ -49,20 +68,17 @@ extern "C" {
         void (*free)(ggml_backend_t backend);
 
         // buffer allocation
-        ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
+        ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
 
-        // get buffer alignment
-        size_t (*get_alignment)(ggml_backend_t backend);
-
-        // tensor data access
-        // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
+        // (optional) asynchroneous tensor data access
         void (*set_tensor_async)(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
         void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
-        void (*synchronize)     (ggml_backend_t backend);
 
-        // (optional) copy tensor between different backends, allow for single-copy tranfers
-        void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
-        void (*cpy_tensor_to)  (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+        // (optional) asynchroneous tensor copy
+        void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+        void (*cpy_tensor_to_async)  (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+        void (*synchronize)     (ggml_backend_t backend);
 
         // compute graph with a plan
         ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
@@ -82,6 +98,15 @@ extern "C" {
         ggml_backend_context_t context;
     };
 
+
+    //
+    // Backend registry
+    //
+
+    typedef ggml_backend_t (*ggml_backend_init_fn)(const char * params, void * user_data);
+
+    void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
+
 #ifdef  __cplusplus
 }
 #endif
index 856757463aae4f7779ddd7edff0c4ecf5beca772..3a22cd085eac0a4bbf5b96c602c49f6c6354b38a 100644 (file)
@@ -9,14 +9,36 @@
 #include <stdlib.h>
 #include <string.h>
 
-#define UNUSED GGML_UNUSED
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+
+// backend buffer type
+
+ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    return buft->iface.alloc_buffer(buft, size);
+}
+
+size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_alignment(buft);
+}
+
+size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
+    // get_alloc_size is optional, defaults to ggml_nbytes
+    if (buft->iface.get_alloc_size) {
+        return buft->iface.get_alloc_size(buft, tensor);
+    }
+    return ggml_nbytes(tensor);
+}
+
+bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return buft->iface.supports_backend(buft, backend);
+}
+
 // backend buffer
 
 ggml_backend_buffer_t ggml_backend_buffer_init(
-        struct ggml_backend                  * backend,
+               ggml_backend_buffer_type_t      buft,
         struct ggml_backend_buffer_i           iface,
                ggml_backend_buffer_context_t   context,
                size_t                          size) {
@@ -26,7 +48,7 @@ ggml_backend_buffer_t ggml_backend_buffer_init(
 
     (*buffer) = (struct ggml_backend_buffer) {
         /* .interface = */ iface,
-        /* .backend   = */ backend,
+        /* .buft      = */ buft,
         /* .context   = */ context,
         /* .size      = */ size,
     };
@@ -45,10 +67,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
     free(buffer);
 }
 
-size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) {
-    return ggml_backend_get_alignment(buffer->backend);
-}
-
 size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
     return buffer->size;
 }
@@ -61,14 +79,6 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
     return base;
 }
 
-size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
-    // get_alloc_size is optional, defaults to ggml_nbytes
-    if (buffer->iface.get_alloc_size) {
-        return buffer->iface.get_alloc_size(buffer, tensor);
-    }
-    return ggml_nbytes(tensor);
-}
-
 void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     // init_tensor is optional
     if (buffer->iface.init_tensor) {
@@ -76,19 +86,20 @@ void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_t
     }
 }
 
-void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
-    // free_tensor is optional
-    if (buffer->iface.free_tensor) {
-        buffer->iface.free_tensor(buffer, tensor);
-    }
+size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer) {
+    return ggml_backend_buft_get_alignment(ggml_backend_buffer_type(buffer));
 }
 
-// backend
+size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type(buffer), tensor);
+}
 
-ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor) {
-    return tensor->buffer ? tensor->buffer->backend : NULL;
+ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer) {
+    return buffer->buft;
 }
 
+// backend
+
 const char * ggml_backend_name(ggml_backend_t backend) {
     if (backend == NULL) {
         return "NULL";
@@ -104,43 +115,53 @@ void ggml_backend_free(ggml_backend_t backend) {
     backend->iface.free(backend);
 }
 
+ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
+    return backend->iface.get_default_buffer_type(backend);
+}
+
 ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) {
-    return backend->iface.alloc_buffer(backend, size);
+    return ggml_backend_buft_alloc_buffer(ggml_backend_get_default_buffer_type(backend), size);
 }
 
 size_t ggml_backend_get_alignment(ggml_backend_t backend) {
-    return backend->iface.get_alignment(backend);
+    return ggml_backend_buft_get_alignment(ggml_backend_get_default_buffer_type(backend));
 }
 
-void ggml_backend_tensor_set_async(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
+void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+    backend->iface.set_tensor_async(backend, tensor, data, offset, size);
 }
 
-void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
+void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+
+    backend->iface.get_tensor_async(backend, tensor, data, offset, size);
 }
 
 void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    ggml_backend_t backend = ggml_get_backend(tensor);
-
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-    GGML_ASSERT(backend != NULL && "tensor backend not set");
+    GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
 
-    backend->iface.set_tensor_async(backend, tensor, data, offset, size);
-    backend->iface.synchronize(backend);
+    tensor->buffer->iface.set_tensor(tensor->buffer, tensor, data, offset, size);
 }
 
 void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    ggml_backend_t backend = ggml_get_backend(tensor);
-
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-    GGML_ASSERT(backend != NULL && "tensor backend not set");
+    GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
 
-    backend->iface.get_tensor_async(backend, tensor, data, offset, size);
-    backend->iface.synchronize(backend);
+    tensor->buffer->iface.get_tensor(tensor->buffer, tensor, data, offset, size);
 }
 
 void ggml_backend_synchronize(ggml_backend_t backend) {
+    if (backend->iface.synchronize == NULL) {
+        return;
+    }
+
     backend->iface.synchronize(backend);
 }
 
@@ -154,10 +175,16 @@ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_pla
 
 void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
     backend->iface.graph_plan_compute(backend, plan);
+
+    // TODO: optional sync
+    ggml_backend_synchronize(backend);
 }
 
 void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     backend->iface.graph_compute(backend, cgraph);
+
+    // TODO: optional sync
+    ggml_backend_synchronize(backend);
 }
 
 bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
@@ -194,14 +221,15 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
 
     // TODO: allow backends to support copy to/from same backend
 
-    if (ggml_get_backend(dst)->iface.cpy_tensor_from != NULL) {
-        ggml_get_backend(dst)->iface.cpy_tensor_from(ggml_get_backend(dst)->context, src, dst);
-    } else if (ggml_get_backend(src)->iface.cpy_tensor_to != NULL) {
-        ggml_get_backend(src)->iface.cpy_tensor_to(ggml_get_backend(src)->context, src, dst);
+    if (dst->buffer->iface.cpy_tensor_from != NULL) {
+        dst->buffer->iface.cpy_tensor_from(dst->buffer, src, dst);
+    } else if (src->buffer->iface.cpy_tensor_to != NULL) {
+        src->buffer->iface.cpy_tensor_to(src->buffer, src, dst);
     } else {
         // shouldn't be hit when copying from/to CPU
         #ifndef NDEBUG
-        fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to are implemented for backends %s and %s, falling back to get/set\n", ggml_backend_name(src->buffer->backend), ggml_backend_name(dst->buffer->backend));
+        fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to "
+                        "are implemented for %s and %s, falling back to get/set\n", src->name, dst->name);
         #endif
         size_t nbytes = ggml_nbytes(src);
         void * data = malloc(nbytes);
@@ -211,101 +239,259 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
     }
 }
 
-// backend CPU
+// backend registry
 
-struct ggml_backend_cpu_context {
-    int n_threads;
-    void * work_data;
-    size_t work_size;
+#define GGML_MAX_BACKENDS_REG 16
+
+struct ggml_backend_reg {
+    char name[128];
+    ggml_backend_init_fn init_fn;
+    ggml_backend_buffer_type_t default_buffer_type;
+    void * user_data;
 };
 
-static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
-    return "CPU";
+static struct ggml_backend_reg ggml_backend_registry[GGML_MAX_BACKENDS_REG];
+static size_t ggml_backend_registry_count = 0;
+
+static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data);
+
+static void ggml_backend_registry_init(void) {
+    static bool initialized = false;
+
+    if (initialized) {
+        return;
+    }
+
+    initialized = true;
 
-    UNUSED(backend);
+    ggml_backend_register("CPU", ggml_backend_reg_cpu_init, ggml_backend_cpu_buffer_type(), NULL);
+
+    // add forward decls here to avoid including the backend headers
+#ifdef GGML_USE_CUBLAS
+    extern void ggml_backend_cuda_reg_devices(void);
+    ggml_backend_cuda_reg_devices();
+#endif
+
+#ifdef GGML_USE_METAL
+    extern ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data);
+    extern ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
+    ggml_backend_register("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL);
+#endif
 }
 
-static void ggml_backend_cpu_free(ggml_backend_t backend) {
-    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
-    free(cpu_ctx->work_data);
-    free(cpu_ctx);
-    free(backend);
+void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
+    GGML_ASSERT(ggml_backend_registry_count < GGML_MAX_BACKENDS_REG);
+
+    int id = ggml_backend_registry_count;
+
+    ggml_backend_registry[id] = (struct ggml_backend_reg) {
+        /* .name                = */ {0},
+        /* .fn                  = */ init_fn,
+        /* .default_buffer_type = */ default_buffer_type,
+        /* .user_data           = */ user_data,
+    };
+
+    snprintf(ggml_backend_registry[id].name, sizeof(ggml_backend_registry[id].name), "%s", name);
+
+#ifndef NDEBUG
+    fprintf(stderr, "%s: registered backend %s\n", __func__, name);
+#endif
+
+    ggml_backend_registry_count++;
+}
+
+size_t ggml_backend_reg_get_count(void) {
+    ggml_backend_registry_init();
+
+    return ggml_backend_registry_count;
+}
+
+size_t ggml_backend_reg_find_by_name(const char * name) {
+    ggml_backend_registry_init();
+
+    for (size_t i = 0; i < ggml_backend_registry_count; i++) {
+        // TODO: case insensitive in a portable way
+        if (strcmp(ggml_backend_registry[i].name, name) == 0) {
+            return i;
+        }
+    }
+    return SIZE_MAX;
+}
+
+// init from backend:params string
+ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str) {
+    ggml_backend_registry_init();
+
+    const char * params = strchr(backend_str, ':');
+    char backend_name[128];
+    if (params == NULL) {
+        strcpy(backend_name, backend_str);
+        params = "";
+    } else {
+        strncpy(backend_name, backend_str, params - backend_str);
+        backend_name[params - backend_str] = '\0';
+        params++;
+    }
+
+    size_t backend_i = ggml_backend_reg_find_by_name(backend_name);
+    if (backend_i == SIZE_MAX) {
+        fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
+        return NULL;
+    }
+
+    return ggml_backend_reg_init_backend(backend_i, params);
+}
+
+const char * ggml_backend_reg_get_name(size_t i) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].name;
+}
+
+ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].init_fn(params, ggml_backend_registry[i].user_data);
+}
+
+ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].default_buffer_type;
+}
+
+ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_buft_alloc_buffer(ggml_backend_registry[i].default_buffer_type, size);
 }
 
+// backend CPU
+
 static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
     return (void *)buffer->context;
 }
 
 static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     free(buffer->context);
-    UNUSED(buffer);
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy((char *)tensor->data + offset, data, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+
+    GGML_UNUSED(buffer);
 }
 
 static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
-    /* .free_buffer    = */ ggml_backend_cpu_buffer_free_buffer,
-    /* .get_base       = */ ggml_backend_cpu_buffer_get_base,
-    /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
-    /* .init_tensor    = */ NULL, // no initialization required
-    /* .free_tensor    = */ NULL, // no cleanup required
+    /* .free_buffer     = */ ggml_backend_cpu_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
+    /* .init_tensor     = */ NULL, // no initialization required
+    /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
+    /* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from,
+    /* .cpy_tensor_to   = */ ggml_backend_cpu_buffer_cpy_tensor_to,
 };
 
 // for buffers from ptr, free is not called
 static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
-    /* .free_buffer    = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
-    /* .get_base       = */ ggml_backend_cpu_buffer_get_base,
-    /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
-    /* .init_tensor    = */ NULL,
-    /* .free_tensor    = */ NULL,
+    /* .free_buffer     = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
+    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
+    /* .init_tensor     = */ NULL, // no initialization required
+    /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
+    /* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from,
+    /* .cpy_tensor_to   = */ ggml_backend_cpu_buffer_cpy_tensor_to,
 };
 
 static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
 
-static ggml_backend_buffer_t ggml_backend_cpu_alloc_buffer(ggml_backend_t backend, size_t size) {
+static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     size += TENSOR_ALIGNMENT;   // malloc may return an address that is not aligned
     void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
 
     GGML_ASSERT(data != NULL && "failed to allocate buffer");
 
-    return ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size);
+    return ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size);
 }
 
-static size_t ggml_backend_cpu_get_alignment(ggml_backend_t backend) {
+static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return TENSOR_ALIGNMENT;
-    UNUSED(backend);
-}
 
-static void ggml_backend_cpu_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_UNUSED(buft);
+}
 
-    memcpy((char *)tensor->data + offset, data, size);
+static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return ggml_backend_is_cpu(backend);
 
-    UNUSED(backend);
+    GGML_UNUSED(buft);
 }
 
-static void ggml_backend_cpu_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
-    memcpy(data, (const char *)tensor->data + offset, size);
+ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_cpu = {
+        /* .iface = */ {
+            /* .alloc_buffer     = */ ggml_backend_cpu_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend,
+        },
+        /* .context = */ NULL,
+    };
 
-    UNUSED(backend);
+    return &ggml_backend_buffer_type_cpu;
 }
 
-static void ggml_backend_cpu_synchronize(ggml_backend_t backend) {
-    UNUSED(backend);
-}
+struct ggml_backend_cpu_context {
+    int n_threads;
+    void * work_data;
+    size_t work_size;
+};
 
-static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
-    ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
+    return "CPU";
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 
-static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
-    ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+static void ggml_backend_cpu_free(ggml_backend_t backend) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+    free(cpu_ctx->work_data);
+    free(cpu_ctx);
+    free(backend);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cpu_get_default_buffer_type(ggml_backend_t backend) {
+    return ggml_backend_cpu_buffer_type();
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 
 struct ggml_backend_plan_cpu {
@@ -334,7 +520,7 @@ static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backen
     free(cpu_plan->cplan.work_data);
     free(cpu_plan);
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 
 static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
@@ -342,7 +528,7 @@ static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_bac
 
     ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 
 static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
@@ -363,25 +549,25 @@ static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_c
 
 static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
     return true;
-    UNUSED(backend);
-    UNUSED(op);
+
+    GGML_UNUSED(backend);
+    GGML_UNUSED(op);
 }
 
 static struct ggml_backend_i cpu_backend_i = {
-    /* .get_name            = */ ggml_backend_cpu_name,
-    /* .free                = */ ggml_backend_cpu_free,
-    /* .alloc_buffer        = */ ggml_backend_cpu_alloc_buffer,
-    /* .get_alignment       = */ ggml_backend_cpu_get_alignment,
-    /* .set_tensor_async    = */ ggml_backend_cpu_set_tensor_async,
-    /* .get_tensor_async    = */ ggml_backend_cpu_get_tensor_async,
-    /* .synchronize         = */ ggml_backend_cpu_synchronize,
-    /* .cpy_tensor_from     = */ ggml_backend_cpu_cpy_tensor_from,
-    /* .cpy_tensor_to       = */ ggml_backend_cpu_cpy_tensor_to,
-    /* .graph_plan_create   = */ ggml_backend_cpu_graph_plan_create,
-    /* .graph_plan_free     = */ ggml_backend_cpu_graph_plan_free,
-    /* .graph_plan_compute  = */ ggml_backend_cpu_graph_plan_compute,
-    /* .graph_compute       = */ ggml_backend_cpu_graph_compute,
-    /* .supports_op         = */ ggml_backend_cpu_supports_op,
+    /* .get_name                = */ ggml_backend_cpu_name,
+    /* .free                    = */ ggml_backend_cpu_free,
+    /* .get_default_buffer_type = */ ggml_backend_cpu_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_from_async   = */ NULL,
+    /* .cpy_tensor_to_async     = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ ggml_backend_cpu_graph_plan_create,
+    /* .graph_plan_free         = */ ggml_backend_cpu_graph_plan_free,
+    /* .graph_plan_compute      = */ ggml_backend_cpu_graph_plan_compute,
+    /* .graph_compute           = */ ggml_backend_cpu_graph_compute,
+    /* .supports_op             = */ ggml_backend_cpu_supports_op,
 };
 
 ggml_backend_t ggml_backend_cpu_init(void) {
@@ -411,10 +597,18 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
     ctx->n_threads = n_threads;
 }
 
-ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size) {
-    return ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size);
+ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
+    return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
+}
+
+static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data) {
+    return ggml_backend_cpu_init();
+
+    GGML_UNUSED(params);
+    GGML_UNUSED(user_data);
 }
 
+
 // scheduler
 
 #define GGML_MAX_BACKENDS 4
@@ -427,7 +621,7 @@ struct ggml_backend_sched_split {
     int i_end;
     struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS];
     int n_inputs;
-    struct ggml_cgraph graph;
+    struct ggml_cgraph graph;
 };
 
 struct ggml_backend_sched {
@@ -453,7 +647,7 @@ struct ggml_backend_sched {
     #else
     __attribute__((aligned(GGML_MEM_ALIGN)))
     #endif
-    char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + GGML_MAX_SPLITS*sizeof(struct ggml_cgraph)];
+    char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
 };
 
 #define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
@@ -482,23 +676,57 @@ static int sched_allocr_prio(ggml_backend_sched_t sched, ggml_tallocr_t allocr)
     return INT_MAX;
 }
 
+static ggml_backend_t get_buffer_backend(ggml_backend_sched_t sched, ggml_backend_buffer_t buffer) {
+    if (buffer == NULL) {
+        return NULL;
+    }
+    // find highest prio backend that supports the buffer type
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
+            return sched->backends[i];
+        }
+    }
+    GGML_ASSERT(false && "tensor buffer type not supported by any backend");
+}
+
+static ggml_backend_t get_allocr_backend(ggml_backend_sched_t sched, ggml_tallocr_t allocr) {
+    if (allocr == NULL) {
+        return NULL;
+    }
+    // find highest prio backend that supports the buffer type
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (sched->tallocs[i] == allocr) {
+            return sched->backends[i];
+        }
+    }
+    GGML_UNREACHABLE();
+}
+
+#if 0
+static char causes[GGML_DEFAULT_GRAPH_SIZE*8 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
+#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
+#define GET_CAUSE(node) causes[hash_id(node)]
+#else
+#define SET_CAUSE(node, ...)
+#define GET_CAUSE(node) ""
+#endif
+
 // returns the backend that should be used for the node based on the current locations
-char causes[GGML_DEFAULT_GRAPH_SIZE*4 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
 static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) {
     // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
     // ie. kv cache updates
     // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend.
     // dst
-    ggml_backend_t cur_backend = ggml_get_backend(node);
+    ggml_backend_t cur_backend = get_buffer_backend(sched, node->buffer);
     if (cur_backend != NULL) {
-        sprintf(causes[hash_id(node)], "1.dst");
+        SET_CAUSE(node, "1.dst");
         return cur_backend;
     }
 
     // view_src
-    if (node->view_src != NULL && ggml_get_backend(node->view_src) != NULL) {
-        sprintf(causes[hash_id(node)], "1.vsrc");
-        return ggml_get_backend(node->view_src);
+    if (node->view_src != NULL && get_buffer_backend(sched, node->view_src->buffer) != NULL) {
+        SET_CAUSE(node, "1.vsrc");
+        return get_buffer_backend(sched, node->view_src->buffer);
     }
 
     // src
@@ -510,7 +738,7 @@ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct
         if (src == NULL) {
             break;
         }
-        ggml_backend_t src_backend = ggml_get_backend(src);
+        ggml_backend_t src_backend = get_buffer_backend(sched, src->buffer);
         if (src_backend != NULL) {
             int src_prio = sched_backend_prio(sched, src_backend);
             size_t src_size = ggml_nbytes(src);
@@ -518,7 +746,7 @@ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct
                 cur_prio = src_prio;
                 cur_size = src_size;
                 cur_backend = src_backend;
-                sprintf(causes[hash_id(node)], "1.src%d", i);
+                SET_CAUSE(node, "1.src%d", i);
             }
         }
     }
@@ -539,10 +767,12 @@ static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgra
     int cur_split = 0;
     for (int i = 0; i < graph->n_nodes; i++) {
         if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
-            ggml_backend_t split_backend = ggml_tallocr_get_buffer(sched->splits[cur_split].tallocr)->backend;
-            fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs);
+            ggml_backend_t split_backend = get_allocr_backend(sched, sched->splits[cur_split].tallocr);
+            fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend),
+                sched->splits[cur_split].n_inputs);
             for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
-                fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
+                fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
+                    fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
             }
             fprintf(stderr, "\n");
             cur_split++;
@@ -552,16 +782,18 @@ static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgra
             continue;
         }
         ggml_tallocr_t node_allocr = node_allocr(node);
-        ggml_backend_t node_backend = node_allocr ? ggml_tallocr_get_buffer(node_allocr)->backend : NULL;
-        fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name, fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", causes[hash_id(node)]);
+        ggml_backend_t node_backend = node_allocr ? get_allocr_backend(sched, node_allocr) : NULL; // FIXME:
+        fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name,
+            fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", GET_CAUSE(node));
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             struct ggml_tensor * src = node->src[j];
             if (src == NULL) {
                 break;
             }
             ggml_tallocr_t src_allocr = node_allocr(src);
-            ggml_backend_t src_backend = src_allocr ? ggml_tallocr_get_buffer(src_allocr)->backend : NULL;
-            fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", causes[hash_id(src)]);
+            ggml_backend_t src_backend = src_allocr ? get_allocr_backend(sched, src_allocr) : NULL;
+            fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name,
+                fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
         }
         fprintf(stderr, "\n");
     }
@@ -587,9 +819,9 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
     sched->n_splits = 0;
 
     struct ggml_init_params params = {
-        /*.mem_size =   */ sizeof(sched->context_buffer),
-        /*.mem_buffer = */ sched->context_buffer,
-        /*.no_alloc =   */ true
+        /* .mem_size =   */ sizeof(sched->context_buffer),
+        /* .mem_buffer = */ sched->context_buffer,
+        /* .no_alloc =   */ true
     };
 
     if (sched->ctx != NULL) {
@@ -605,9 +837,9 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
             // do not overwrite user assignments
             continue;
         }
-        ggml_backend_t leaf_backend = ggml_get_backend(leaf);
+        ggml_backend_t leaf_backend = get_buffer_backend(sched, leaf->buffer);
         if (leaf_backend == NULL && leaf->view_src != NULL) {
-            leaf_backend = ggml_get_backend(leaf->view_src);
+            leaf_backend = get_buffer_backend(sched, leaf->view_src->buffer);
         }
         if (leaf_backend != NULL) {
             node_allocr(leaf) = ggml_backend_sched_get_tallocr(sched, leaf_backend);
@@ -649,7 +881,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
                         cur_prio = src_prio;
                         cur_size = src_size;
                         node_allocr = src_allocr;
-                        sprintf(causes[hash_id(node)], "2.src%d", j);
+                        SET_CAUSE(node, "2.src%d", j);
                     }
                 }
             }
@@ -733,7 +965,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
                     struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
                     sched->node_copies[id][cur_backend_id] = tensor_copy;
                     node_allocr(tensor_copy) = cur_allocr;
-                    ggml_backend_t backend = ggml_tallocr_get_buffer(cur_allocr)->backend;
+                    ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
                     ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
                 }
                 node->src[j] = sched->node_copies[id][cur_backend_id];
@@ -761,8 +993,8 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
             ggml_tallocr_t src_allocr = node_allocr(src);
             if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
                 fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
-                    node->name, node_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(node_allocr)->backend) : "NULL",
-                    j, src->name, src_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(src_allocr)->backend) : "NULL");
+                    node->name, node_allocr ? ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
+                    j, src->name, src_allocr ? ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL");
             }
         }
     }
@@ -773,7 +1005,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
     struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false);
     for (int i = 0; i < sched->n_splits; i++) {
         struct ggml_backend_sched_split * split = &sched->splits[i];
-        split->graph = ggml_graph_view(sched->ctx, graph, split->i_start, split->i_end);
+        split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
 
         // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
         for (int j = 0; j < split->n_inputs; j++) {
@@ -806,31 +1038,29 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
 
     for (int i = 0; i < sched->n_splits; i++) {
         struct ggml_backend_sched_split * split = &splits[i];
-        ggml_backend_t split_backend = ggml_tallocr_get_buffer(split->tallocr)->backend;
+        ggml_backend_t split_backend = get_allocr_backend(sched, split->tallocr);
         int split_backend_id = sched_backend_prio(sched, split_backend);
 
         // copy the input tensors to the split backend
         uint64_t copy_start_us = ggml_time_us();
         for (int j = 0; j < split->n_inputs; j++) {
-            struct ggml_tensor * input_cpy = sched->node_copies[hash_id(split->inputs[j])][sched_backend_prio(sched, split_backend)];
-            if (split->inputs[j]->buffer == NULL) {
-                if (split->inputs[j]->view_src == NULL) {
-                    fprintf(stderr, "input %s has no buffer and no view_src\n", split->inputs[j]->name);
+            struct ggml_tensor * input = split->inputs[j];
+            struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_backend_prio(sched, split_backend)];
+            if (input->buffer == NULL) {
+                if (input->view_src == NULL) {
+                    fprintf(stderr, "input %s has no buffer and no view_src\n", input->name);
                     exit(1);
                 }
-                struct ggml_tensor * view = split->inputs[j];
-                view->backend = view->view_src->backend;
-                view->buffer  = view->view_src->buffer;
-                view->data    = (char *)view->view_src->data + view->view_offs;
-                ggml_backend_buffer_init_tensor(ggml_backend_sched_get_buffer(sched, view->buffer->backend), view);
+                // FIXME: may need to use the sched buffer instead
+                ggml_backend_view_init(input->view_src->buffer, input);
             }
             if (input_cpy->buffer == NULL) {
                 fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
                 exit(1);
             }
-            GGML_ASSERT(split->inputs[j]->buffer->backend != input_cpy->buffer->backend);
-            GGML_ASSERT(input_cpy->buffer->backend == split_backend);
-            ggml_backend_tensor_copy(split->inputs[j], input_cpy);
+            //GGML_ASSERT(input->buffer->backend != input_cpy->buffer->backend);
+            //GGML_ASSERT(input_cpy->buffer->backend == split_backend);
+            ggml_backend_tensor_copy(input, input_cpy);
         }
         // ggml_backend_synchronize(split_backend);
         int64_t copy_end_us = ggml_time_us();
@@ -843,7 +1073,7 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
 #endif
 
         uint64_t compute_start_us = ggml_time_us();
-        ggml_backend_graph_compute(split_backend, split->graph);
+        ggml_backend_graph_compute(split_backend, &split->graph);
         // ggml_backend_synchronize(split_backend);
         uint64_t compute_end_us = ggml_time_us();
         compute_us[split_backend_id] += compute_end_us - compute_start_us;
@@ -872,8 +1102,6 @@ ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_bac
     struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched));
     memset(sched, 0, sizeof(struct ggml_backend_sched));
 
-    fprintf(stderr, "ggml_backend_sched size: %zu KB\n", sizeof(struct ggml_backend_sched)/1024);
-
     sched->n_backends = n_backends;
     for (int i = 0; i < n_backends; i++) {
         sched->backends[i] = backends[i];
@@ -948,3 +1176,182 @@ void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml
     GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
     node_allocr(node) = sched->tallocs[backend_index];
 }
+
+// utils
+void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->buffer == NULL);
+    GGML_ASSERT(tensor->data == NULL);
+    GGML_ASSERT(tensor->view_src != NULL);
+    GGML_ASSERT(tensor->view_src->buffer != NULL);
+    GGML_ASSERT(tensor->view_src->data != NULL);
+
+    tensor->buffer = buffer;
+    tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
+    tensor->backend = tensor->view_src->backend;
+    ggml_backend_buffer_init_tensor(buffer, tensor);
+}
+
+void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
+    GGML_ASSERT(tensor->buffer == NULL);
+    GGML_ASSERT(tensor->data == NULL);
+    GGML_ASSERT(tensor->view_src == NULL);
+    GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));
+    GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
+                (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
+
+    tensor->buffer = buffer;
+    tensor->data = addr;
+    ggml_backend_buffer_init_tensor(buffer, tensor);
+}
+
+static struct ggml_tensor * graph_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,
+    struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) {
+
+    GGML_ASSERT(src != NULL);
+    GGML_ASSERT(src->data && "graph must be allocated");
+
+    size_t id = ggml_hash_insert(hash_set, src);
+    if (id == GGML_HASHTABLE_ALREADY_EXISTS) {
+        return node_copies[ggml_hash_find(hash_set, src)];
+    }
+
+    struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);
+    if (src->view_src != NULL) {
+        dst->view_src = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);
+        dst->view_offs = src->view_offs;
+    }
+    dst->op = src->op;
+    memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
+    ggml_set_name(dst, src->name);
+
+    // copy src
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        struct ggml_tensor * s = src->src[i];
+        if (s == NULL) {
+            break;
+        }
+        dst->src[i] = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
+    }
+
+    node_copies[id] = dst;
+    return dst;
+}
+
+static void graph_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) {
+    size_t id = ggml_hash_find(hash_set, src);
+    if (node_init[id]) {
+        return;
+    }
+    node_init[id] = true;
+
+    struct ggml_tensor * dst = node_copies[id];
+    if (dst->view_src != NULL) {
+        ggml_backend_view_init(dst->view_src->buffer, dst);
+    }
+    else {
+        ggml_backend_tensor_copy(src, dst);
+    }
+
+    // init src
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        struct ggml_tensor * s = src->src[i];
+        if (s == NULL) {
+            break;
+        }
+        graph_init_tensor(hash_set, node_copies, node_init, s);
+    }
+}
+
+struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
+    struct ggml_hash_set hash_set = {
+        /* .size = */ graph->visited_hash_table.size,
+        /* .keys = */ calloc(sizeof(hash_set.keys[0]) * graph->visited_hash_table.size, 1)
+    };
+    struct ggml_tensor ** node_copies = calloc(sizeof(node_copies[0]) * hash_set.size, 1);
+    bool * node_init = calloc(sizeof(node_init[0]) * hash_set.size, 1);
+
+    struct ggml_init_params params = {
+        /* .mem_size   = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false),
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ true
+    };
+
+    struct ggml_context * ctx_allocated = ggml_init(params);
+    struct ggml_context * ctx_unallocated = ggml_init(params);
+
+    // dup nodes
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);
+    }
+
+    // allocate nodes
+    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
+
+    //printf("copy buffer size: %zu MB\n", ggml_backend_buffer_get_size(buffer) / 1024 / 1024);
+
+    // copy data and init views
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        graph_init_tensor(hash_set, node_copies, node_init, node);
+    }
+
+    // build graph copy
+    struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false);
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        struct ggml_tensor * node_copy = node_copies[ggml_hash_find(hash_set, node)];
+        graph_copy->nodes[i] = node_copy;
+    }
+    graph_copy->n_nodes = graph->n_nodes;
+
+    free(hash_set.keys);
+    free(node_copies);
+    free(node_init);
+
+    return (struct ggml_backend_graph_copy) {
+        /* .buffer           = */ buffer,
+        /* .ctx_allocated    = */ ctx_allocated,
+        /* .ctx_unallocated  = */ ctx_unallocated,
+        /* .graph            = */ graph_copy,
+    };
+}
+
+void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
+    ggml_backend_buffer_free(copy.buffer);
+    ggml_free(copy.ctx_allocated);
+    ggml_free(copy.ctx_unallocated);
+}
+
+void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
+    struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
+    struct ggml_cgraph * g1 = graph;
+    struct ggml_cgraph * g2 = copy.graph;
+
+    assert(g1->n_nodes == g2->n_nodes);
+
+    for (int i = 0; i < g1->n_nodes; i++) {
+        //printf("eval %d/%d\n", i, g1->n_nodes);
+        struct ggml_tensor * t1 = g1->nodes[i];
+        struct ggml_tensor * t2 = g2->nodes[i];
+
+        assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
+
+        struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
+        struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
+
+        ggml_backend_graph_compute(backend1, &g1v);
+        ggml_backend_graph_compute(backend2, &g2v);
+
+        if (ggml_is_view_op(t1->op)) {
+            continue;
+        }
+
+        // compare results, calculate rms etc
+        if (!callback(i, t1, t2, user_data)) {
+            break;
+        }
+    }
+
+    ggml_backend_graph_copy_free(copy);
+}
index 966687320ac96d971e9d635429d4ea1018a7ce57..58d5ccae6ed101ca80df9390386adc5329536722 100644 (file)
@@ -7,41 +7,44 @@
 extern "C" {
 #endif
 
+    typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+    typedef struct ggml_backend * ggml_backend_t;
+    typedef void * ggml_backend_graph_plan_t;
+
     //
     // Backend buffer
     //
 
-    struct ggml_backend_buffer;
-    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+    // buffer type
+    GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
+    GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
+    GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
+    GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
 
-    // backend buffer functions
+    // buffer
     GGML_API void   ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
-    GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
     GGML_API void * ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
     GGML_API size_t ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);
-    GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
     GGML_API void   ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-    GGML_API void   ggml_backend_buffer_free_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
+    GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer);
 
     //
     // Backend
     //
 
-    struct ggml_backend;
-    typedef struct ggml_backend * ggml_backend_t;
-    typedef void * ggml_backend_graph_plan_t;
-
-    GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
 
     GGML_API const char * ggml_backend_name(ggml_backend_t backend);
     GGML_API void         ggml_backend_free(ggml_backend_t backend);
 
-    GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
-
-    GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
+    GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
+    GGML_API ggml_backend_buffer_t      ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
+    GGML_API size_t                     ggml_backend_get_alignment(ggml_backend_t backend);
 
-    GGML_API void ggml_backend_tensor_set_async(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-    GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
 
     GGML_API void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
@@ -57,6 +60,7 @@ extern "C" {
 
     // tensor copy between different backends
     GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
+    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); // automatic fallback to sync copy
 
     //
     // CPU backend
@@ -68,8 +72,23 @@ extern "C" {
     GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
 
     // Create a backend buffer from an existing pointer
-    GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
+    GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
+
+    GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
 
+    //
+    // Backend registry
+    //
+
+    // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
+
+    GGML_API size_t                     ggml_backend_reg_get_count(void);
+    GGML_API size_t                     ggml_backend_reg_find_by_name(const char * name);
+    GGML_API ggml_backend_t             ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params]
+    GGML_API const char *               ggml_backend_reg_get_name(size_t i);
+    GGML_API ggml_backend_t             ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
+    GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
+    GGML_API ggml_backend_buffer_t      ggml_backend_reg_alloc_buffer(size_t i, size_t size);
 
     //
     // Backend scheduler
@@ -131,6 +150,32 @@ extern "C" {
             ggml_backend_sched_t sched,
             struct ggml_cgraph * graph);
 
+
+    //
+    // Utils
+    //
+
+    struct ggml_backend_graph_copy {
+        ggml_backend_buffer_t buffer;
+        struct ggml_context * ctx_allocated;
+        struct ggml_context * ctx_unallocated;
+        struct ggml_cgraph * graph;
+    };
+
+    // Copy a graph to a different backend
+    GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
+    GGML_API void                           ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
+
+    typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
+
+    // Compare the output of two backends
+    GGML_API void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
+
+    // Tensor initialization
+    GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
+    GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+
+
 #ifdef  __cplusplus
 }
 #endif
index e80b7a761bd1188ee2809ec284c316b971ece68b..85f7a293783bed95758dd729fb49fb6c97d27017 100644 (file)
@@ -1,7 +1,8 @@
 #include <algorithm>
-#include <cinttypes>
 #include <cstddef>
 #include <cstdint>
+#include <cinttypes>
+#include <float.h>
 #include <limits>
 #include <stdint.h>
 #include <stdio.h>
@@ -69,6 +70,7 @@
 #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
 #define cudaSetDevice hipSetDevice
 #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamFireAndForget hipStreamFireAndForget
 #define cudaStreamNonBlocking hipStreamNonBlocking
 #define cudaStreamSynchronize hipStreamSynchronize
 #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
@@ -190,7 +192,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
             fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                 cudaGetErrorString(err_));                                              \
             fprintf(stderr, "current device: %d\n", id);                                \
-            exit(1);                                                                    \
+            GGML_ASSERT(!"CUDA error");                                                 \
         }                                                                               \
     } while (0)
 
@@ -204,7 +206,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n",                         \
                     err_, __FILE__, __LINE__, cublasGetStatusString(err_));             \
             fprintf(stderr, "current device: %d\n", id);                                \
-            exit(1);                                                                    \
+            GGML_ASSERT(!"cuBLAS error");                                               \
         }                                                                               \
     } while (0)
 #else
@@ -216,7 +218,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
             cudaGetDevice(&id);                                                         \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);  \
             fprintf(stderr, "current device: %d\n", id);                                \
-            exit(1);                                                                    \
+            GGML_ASSERT(!"cuBLAS error");                                               \
         }                                                                               \
     } while (0)
 #endif // CUDART_VERSION >= 11
@@ -433,8 +435,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
-#define CUDA_ADD_BLOCK_SIZE 256
-#define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -501,40 +502,112 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
-static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= kx) {
-        return;
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
     }
-    dst[i] = x[i] + y[i%ky];
+    return x;
 }
 
-static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
 
-    if (i >= k) {
-        return;
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
     }
-    dst[i] = __hadd(x[i], __float2half(y[i]));
+    return x;
 }
 
-static __global__ void add_f16_f32_f32(const half * x, const float * y, float * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+static __device__ __forceinline__ float op_repeat(const float a, const float b) {
+    return b;
+}
 
-    if (i >= k) {
+static __device__ __forceinline__ float op_add(const float a, const float b) {
+    return a + b;
+}
+
+static __device__ __forceinline__ float op_mul(const float a, const float b) {
+    return a * b;
+}
+
+static __device__ __forceinline__ float op_div(const float a, const float b) {
+    return a / b;
+}
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s10,*/ int s11, int s12, int s13) {
+    const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
+    const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
+    const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
+    const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
+
+    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
         return;
     }
-    dst[i] = __half2float(x[i]) + y[i];
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  = i_src0;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
+        const int i10 = i0 % ne10;
+        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+    }
 }
 
-static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s10,*/ int s11, int s12, int s13) {
+
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i >= kx) {
+    const int i3 = i/(ne2*ne1*ne0);
+    const int i2 = (i/(ne1*ne0)) % ne2;
+    const int i1 = (i/ne0) % ne1;
+    const int i0 = i % ne0;
+
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
         return;
     }
-    dst[i] = x[i] * y[i%ky];
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  = i_src0;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    const int i10 = i0 % ne10;
+    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
 }
 
 static __global__ void gelu_f32(const float * x, float * dst, const int k) {
@@ -577,22 +650,11 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
 template <int block_size>
-static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
+static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
-    const float eps = 1e-5f;
-
     float2 mean_var = make_float2(0.f, 0.f);
 
     for (int col = tid; col < ncols; col += block_size) {
@@ -624,14 +686,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4550,6 +4604,116 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
+static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q8_0 * dsti = (block_q8_0 *) cdsti;
+
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = xi[j];
+        amax = fmaxf(amax, fabsf(v));
+    }
+
+    const float d = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = xi[j]*id;
+
+        dsti->qs[j] = roundf(x0);
+    }
+}
+
+static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_0 * dsti = (block_q4_0 *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK4_0; ++j) {
+        const float v = xi[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -8;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = xi[0       + j]*id;
+        const float x1 = xi[QK4_0/2 + j]*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
+
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
+}
+
+static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_1 * dsti = (block_q4_1 *) cdsti;
+
+    float vmin = FLT_MAX;
+    float vmax = -FLT_MAX;
+
+    for (int j = 0; j < QK4_1; ++j) {
+        const float v = xi[j];
+
+        if (v < vmin) vmin = v;
+        if (v > vmax) vmax = v;
+    }
+
+    const float d  = (vmax - vmin) / ((1 << 4) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->dm.x = d;
+    dsti->dm.y = vmin;
+
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const float x0 = (xi[0       + j] - vmin)*id;
+        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
+
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
+}
+
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+                                 const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const int i02 = i / (ne00*ne01);
+    const int i01 = (i - i02*ne01*ne00) / ne00;
+    const int i00 = (i - i02*ne01*ne00 - i01*ne00);
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
+
+    const int i12 = i / (ne10*ne11);
+    const int i11 = (i - i12*ne10*ne11) / ne10;
+    const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+
+    cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
@@ -4610,8 +4774,8 @@ static __global__ void rope(
 
 template<typename T, bool has_pos>
 static __global__ void rope_neox(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims
+    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
@@ -4620,23 +4784,25 @@ static __global__ void rope_neox(
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int i = row*ncols + col/2;
+    const int ib = col / n_dims;
+    const int ic = col % n_dims;
+
+    const int i = row*ncols + ib*n_dims + ic/2;
     const int i2 = row/p_delta_rows;
 
-    // simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
-    const float cur_rot = -float(col)/ncols;
+    float cur_rot = inv_ndims * ic - ib;
 
     const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*powf(freq_base, cur_rot);
+    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
 
     float cos_theta, sin_theta;
     rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
-    const float x1 = x[i + ncols/2];
+    const float x1 = x[i + n_dims/2];
 
-    dst[i + 0]       = x0*cos_theta - x1*sin_theta;
-    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
+    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
 static __global__ void rope_glm_f32(
@@ -4702,6 +4868,65 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
     dst[i] = col * m_k + x[i];
 }
 
+static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.y;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += blockDim.x) {
+        sum += x[row * ncols + i];
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    if (col == 0) {
+        dst[row] = sum;
+    }
+}
+
+template<typename T>
+static inline __device__ void swap(T & a, T & b) {
+    T tmp = a;
+    a = b;
+    b = tmp;
+}
+
+template<ggml_sort_order order>
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
+    // bitonic sort
+    int col = threadIdx.x;
+    int row = blockIdx.y;
+
+    if (col >= ncols) return;
+
+    const float * x_row = x + row * ncols;
+    int * dst_row = dst + row * ncols;
+
+    // initialize indices
+    if (col < ncols) {
+        dst_row[col] = col;
+    }
+    __syncthreads();
+
+    for (int k = 2; k <= ncols; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                        swap(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                        swap(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            __syncthreads();
+        }
+    }
+}
+
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     const int col = blockDim.y*blockIdx.y + threadIdx.y;
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
@@ -4711,49 +4936,79 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     }
 
     const int i = row*ncols + col;
-    // dst[i] = col > n_past + row ? -INFINITY : x[i];
-    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+    //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
+    //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = blockDim.x;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
 
     float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        max_val = max(max_val, x[i]);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
     // find the max value in the block
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
+        }
+        __syncthreads();
+
+        max_val = buf[lane_id];
+        max_val = warp_reduce_max(max_val);
     }
 
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        const float val = expf(x[i] - max_val);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
-        dst[i] = val;
+        dst[ix] = val;
     }
 
-    // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
+        }
+        __syncthreads();
+
+        tmp = buf[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
     const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
+        const int i = rowx*ncols + col;
         dst[i] *= inv_tmp;
     }
 }
@@ -4805,25 +5060,119 @@ static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const
     k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
 }
 
-static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
-    const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
-}
-
-static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-    add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
-}
-
-static void add_f16_f32_f32_cuda(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-    add_f16_f32_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
-}
+template<float (*bin_op)(const float, const float)>
+struct bin_bcast_cuda {
+    template<typename src0_t, typename src1_t, typename dst_t>
+    void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
+            const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+            cudaStream_t stream) {
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+
+        int nr0 = ne10/ne0;
+        int nr1 = ne11/ne1;
+        int nr2 = ne12/ne2;
+        int nr3 = ne13/ne3;
+
+        int nr[4] = { nr0, nr1, nr2, nr3 };
+
+        // collapse dimensions until first broadcast dimension
+        int64_t cne0[] = {ne0, ne1, ne2, ne3};
+        int64_t cne1[] = {ne10, ne11, ne12, ne13};
+        size_t cnb0[] = {nb0, nb1, nb2, nb3};
+        size_t cnb1[] = {nb10, nb11, nb12, nb13};
+        auto collapse = [](int64_t cne[]) {
+            cne[0] *= cne[1];
+            cne[1] = cne[2];
+            cne[2] = cne[3];
+            cne[3] = 1;
+        };
+
+        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+            cnb[1] *= cne[1];
+            cnb[2] *= cne[2];
+            cnb[3] *= cne[3];
+        };
+
+        for (int i = 0; i < 4; i++) {
+            if (nr[i] != 1) {
+                break;
+            }
+            if (i > 0) {
+                collapse_nb(cnb0, cne0);
+                collapse_nb(cnb1, cne1);
+                collapse(cne0);
+                collapse(cne1);
+            }
+        }
+        {
+            int64_t ne0 = cne0[0];
+            int64_t ne1 = cne0[1];
+            int64_t ne2 = cne0[2];
+            int64_t ne3 = cne0[3];
+
+            int64_t ne10 = cne1[0];
+            int64_t ne11 = cne1[1];
+            int64_t ne12 = cne1[2];
+            int64_t ne13 = cne1[3];
+
+            //size_t nb0 = cnb0[0];
+            size_t nb1 = cnb0[1];
+            size_t nb2 = cnb0[2];
+            size_t nb3 = cnb0[3];
+
+            //size_t nb10 = cnb1[0];
+            size_t nb11 = cnb1[1];
+            size_t nb12 = cnb1[2];
+            size_t nb13 = cnb1[3];
+
+            //size_t s0 = nb0 / sizeof(src1_t);
+            size_t s1 = nb1 / sizeof(src1_t);
+            size_t s2 = nb2 / sizeof(src1_t);
+            size_t s3 = nb3 / sizeof(src1_t);
+
+            //size_t s10 = nb10 / sizeof(src1_t);
+            size_t s11 = nb11 / sizeof(src1_t);
+            size_t s12 = nb12 / sizeof(src1_t);
+            size_t s13 = nb13 / sizeof(src1_t);
+
+
+            const int block_size = 128;
+
+            int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+            dim3 block_dims;
+            block_dims.x = std::min<unsigned int>(hne0, block_size);
+            block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
+            block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+            dim3 block_nums(
+                (hne0 + block_dims.x - 1) / block_dims.x,
+                (ne1 + block_dims.y - 1) / block_dims.y,
+                (ne2*ne3 + block_dims.z - 1) / block_dims.z
+            );
 
-static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
-    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
-    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
-}
+            if (block_nums.z > 65535) {
+                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
+                    src0_dd, src1_dd, dst_dd,
+                    ne0, ne1, ne2, ne3,
+                    ne10, ne11, ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s10, */ s11, s12, s13);
+            } else {
+                k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
+                    src0_dd, src1_dd, dst_dd,
+                    ne0, ne1, ne2, ne3,
+                    ne10, ne11, ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s10, */ s11, s12, s13);
+            }
+        }
+    }
+};
 
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
@@ -4845,14 +5194,14 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
     sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     }
 }
 
@@ -4874,34 +5223,10 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
     quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
 }
 
-template<typename dst_t>
-static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-template<typename dst_t>
-static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-template<typename dst_t>
-static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-template<typename dst_t>
-static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-template<typename dst_t>
-static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
 template<typename dst_t>
@@ -4950,6 +5275,64 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
 #endif
 }
 
+static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_F32:
+            return dequantize_block_cuda<1, 1, convert_f32>;
+        default:
+            return nullptr;
+    }
+}
+
+static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_F16:
+            return dequantize_block_cuda<1, 1, convert_f16>;
+        default:
+            return nullptr;
+    }
+}
+
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
@@ -5038,13 +5421,22 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
     dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
-static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % QK4_0 == 0);
+static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
-        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+    dequantize_mul_mat_vec<1, 1, convert_f16>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK4_0 == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
 static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -5128,83 +5520,6 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
-static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-    dequantize_block<1, 1, convert_f32><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<1, 1, convert_f16>
-        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
-}
-
-static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:
-            return dequantize_row_q4_0_cuda;
-        case GGML_TYPE_Q4_1:
-            return dequantize_row_q4_1_cuda;
-        case GGML_TYPE_Q5_0:
-            return dequantize_row_q5_0_cuda;
-        case GGML_TYPE_Q5_1:
-            return dequantize_row_q5_1_cuda;
-        case GGML_TYPE_Q8_0:
-            return dequantize_row_q8_0_cuda;
-        case GGML_TYPE_Q2_K:
-            return dequantize_row_q2_K_cuda;
-        case GGML_TYPE_Q3_K:
-            return dequantize_row_q3_K_cuda;
-        case GGML_TYPE_Q4_K:
-            return dequantize_row_q4_K_cuda;
-        case GGML_TYPE_Q5_K:
-            return dequantize_row_q5_K_cuda;
-        case GGML_TYPE_Q6_K:
-            return dequantize_row_q6_K_cuda;
-        case GGML_TYPE_F32:
-            return convert_fp32_to_fp16_cuda;
-        default:
-            return nullptr;
-    }
-}
-
-static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:
-            return dequantize_row_q4_0_cuda;
-        case GGML_TYPE_Q4_1:
-            return dequantize_row_q4_1_cuda;
-        case GGML_TYPE_Q5_0:
-            return dequantize_row_q5_0_cuda;
-        case GGML_TYPE_Q5_1:
-            return dequantize_row_q5_1_cuda;
-        case GGML_TYPE_Q8_0:
-            return dequantize_row_q8_0_cuda;
-        case GGML_TYPE_Q2_K:
-            return dequantize_row_q2_K_cuda;
-        case GGML_TYPE_Q3_K:
-            return dequantize_row_q3_K_cuda;
-        case GGML_TYPE_Q4_K:
-            return dequantize_row_q4_K_cuda;
-        case GGML_TYPE_Q5_K:
-            return dequantize_row_q5_K_cuda;
-        case GGML_TYPE_Q6_K:
-            return dequantize_row_q6_K_cuda;
-        case GGML_TYPE_F16:
-            return convert_fp16_to_fp32_cuda;
-        default:
-            return nullptr;
-    }
-}
-
 static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -5697,6 +6012,39 @@ static void ggml_cpy_f32_f16_cuda(
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }
 
+static void ggml_cpy_f32_q8_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK8_0 == 0);
+    const int num_blocks = ne / QK8_0;
+    cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void ggml_cpy_f32_q4_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_0 == 0);
+    const int num_blocks = ne / QK4_0;
+    cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void ggml_cpy_f32_q4_1_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_1 == 0);
+    const int num_blocks = ne / QK4_1;
+    cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
 static void ggml_cpy_f16_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -5739,20 +6087,26 @@ static void rope_cuda(
 
 template<typename T>
 static void rope_neox_cuda(
-    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
 ) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float inv_ndims = -1.0f / n_dims;
+
     if (pos == nullptr) {
         rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, inv_ndims
         );
     } else {
         rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, inv_ndims
         );
     }
 }
@@ -5777,6 +6131,27 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
     alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
 }
 
+static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(1, nrows, 1);
+    k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
+static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+    // bitonic sort requires ncols to be power of 2
+    GGML_ASSERT((ncols & (ncols - 1)) == 0);
+
+    const dim3 block_dims(ncols, 1, 1);
+    const dim3 block_nums(1, nrows, 1);
+    if (order == GGML_SORT_ASC) {
+        k_argsort_f32_i32<GGML_SORT_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    } else if (order == GGML_SORT_DESC) {
+        k_argsort_f32_i32<GGML_SORT_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    } else {
+        GGML_ASSERT(false);
+    }
+}
+
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
     const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
@@ -5784,10 +6159,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
-static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
 
 static void im2col_f32_f16_cuda(const float * x, half * dst,
@@ -5867,7 +6244,7 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
         return ptr;
     }
 #ifdef DEBUG_CUDA_MALLOC
-    fprintf(stderr, "%s: %d buffers, max_size = %u MiB, tot_size = %u MiB, requested %u MiB\n", __func__, nnz,
+    fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
             (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
 #endif
     void * ptr;
@@ -6005,7 +6382,7 @@ void * ggml_cuda_host_malloc(size_t size) {
         // The allocation error can be bypassed. A null ptr will assigned out of this function.
         // This can fixed the OOM error in WSL.
         cudaGetLastError();
-        fprintf(stderr, "WARNING: failed to allocate %.2f MiB of pinned memory: %s\n",
+        fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
             size/1024.0/1024.0, cudaGetErrorString(err));
         return nullptr;
     }
@@ -6064,63 +6441,6 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     }
 }
 
-static void ggml_cuda_op_repeat(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
-    // guaranteed to be an integer due to the check in ggml_can_repeat
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const int nr0 = (int)(ne0/ne00);
-    const int nr1 = (int)(ne1/ne01);
-    const int nr2 = (int)(ne2/ne02);
-    const int nr3 = (int)(ne3/ne03);
-
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // TODO: very inefficient, implement in a kernel, or fewer cudaMemcpyAsync calls for contiguous tensors
-    for                         (int i3 = 0; i3 < nr3;  i3++) {
-        for                     (int k3 = 0; k3 < ne03; k3++) {
-            for                 (int i2 = 0; i2 < nr2;  i2++) {
-                for             (int k2 = 0; k2 < ne02; k2++) {
-                    for         (int i1 = 0; i1 < nr1;  i1++) {
-                        for     (int k1 = 0; k1 < ne01; k1++) {
-                            for (int i0 = 0; i0 < nr0;  i0++) {
-                                CUDA_CHECK(cudaMemcpyAsync(
-                                              (char *)  dst_d + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0,
-                                        (const char *) src0_d + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01,
-                                        ne00*nb0, cudaMemcpyDeviceToDevice, stream));
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    (void) src1;
-    (void) src1_d;
-}
-
 static void ggml_cuda_op_get_rows(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
@@ -6165,47 +6485,55 @@ static void ggml_cuda_op_get_rows(
     }
 }
 
-inline void ggml_cuda_op_add(
+template<class op>
+inline void ggml_cuda_op_bin_bcast(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
+        op()(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
+        op()(src0, src1, dst, (const half *) src0_dd, src1_dd, (half *) dst_dd, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        add_f16_f32_f32_cuda((const half *) src0_dd, src1_dd, dst_dd, ggml_nelements(src0), main_stream);
+        op()(src0, src1, dst, (const half *) src0_dd, src1_dd, dst_dd, main_stream);
     } else {
-        fprintf(stderr, "src0->type: %d  dst->type: %d\n", src0->type, dst->type);
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
         GGML_ASSERT(false);
     }
+}
+
+static void ggml_cuda_op_repeat(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & main_stream) {
+
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
 
     (void) src1;
-    (void) dst;
+    (void) src1_d;
 }
 
-inline void ggml_cuda_op_mul(
+inline void ggml_cuda_op_add(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+}
 
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
+inline void ggml_cuda_op_mul(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    mul_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+}
 
-    (void) dst;
+inline void ggml_cuda_op_div(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
 }
 
 inline void ggml_cuda_op_gelu(
@@ -6274,7 +6602,10 @@ inline void ggml_cuda_op_norm(
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, eps, main_stream);
 
     (void) src1;
     (void) dst;
@@ -6429,6 +6760,8 @@ inline void ggml_cuda_op_mul_mat_vec_q(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, const cudaStream_t & stream) {
 
+    GGML_ASSERT(ggml_nrows(src1) == 1);
+
     const int64_t ne00 = src0->ne[0];
     const int64_t row_diff = row_high - row_low;
 
@@ -6488,7 +6821,8 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
     size_t ash;
     dfloat * src1_dfloat = nullptr; // dfloat == half
 
-    bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+    bool src1_convert_f16 =
+        src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
         src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
         src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
 
@@ -6710,15 +7044,14 @@ inline void ggml_cuda_op_rope(
         GGML_ASSERT(false);
         rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
     } else if (is_neox) {
-        GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_cuda(
-                (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, main_stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_neox_cuda(
-                (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                (const half *)src0_dd, (half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, main_stream
             );
         } else {
@@ -6815,6 +7148,42 @@ inline void ggml_cuda_op_im2col(
     (void) src0_dd;
 }
 
+inline void ggml_cuda_op_sum_rows(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    sum_rows_f32_cuda(src0_dd, dst_dd, ncols, nrows, main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_cuda_op_argsort(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+    argsort_f32_i32_cuda(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6842,14 +7211,18 @@ inline void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
 
-    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+    float scale = 1.0f;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
 
-    (void) src1;
     (void) dst;
-    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_scale(
@@ -7019,7 +7392,7 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
-    // const int64_t nrows0 = ggml_nrows(src0);
+    const int64_t nrows0 = ggml_nrows(src0);
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
@@ -7055,10 +7428,9 @@ static void ggml_cuda_op_mul_mat(
 
     const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
     const bool src0_is_contiguous = ggml_is_contiguous(src0);
-
     const bool src1_is_contiguous = ggml_is_contiguous(src1);
-    const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ?
-        ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
+
+    const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
 
     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
     GGML_ASSERT(!(split && ne02 > 1));
@@ -7183,7 +7555,7 @@ static void ggml_cuda_op_mul_mat(
                 const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
 
                 // for split tensors the data begins at i0 == i0_offset_low
-                char  *  src0_dd_i =  src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs;
+                char  *  src0_dd_i =  src0_dd[id] + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
                 float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10;
                 char  * src1_ddq_i = src1_ddq[id] +  src1_ddq_i_offset;
                 float *   dst_dd_i =   dst_dd[id] + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
@@ -7328,6 +7700,10 @@ static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, gg
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }
 
+static void ggml_cuda_div(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_div);
+}
+
 static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
 }
@@ -7353,7 +7729,7 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-    if (!g_cublas_loaded) { return false; }
+    if (!g_cublas_loaded) return false;
 
     const int64_t ne10 = src1->ne[0];
 
@@ -7431,7 +7807,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
-__global__ static void k_compute_batched_ptrs(
+static __global__ void k_compute_batched_ptrs(
         const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
         const void ** ptrs_src, void ** ptrs_dst,
         int ne12, int ne13,
@@ -7487,9 +7863,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
-    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
 
     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -7546,7 +7920,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
-        cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+        cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
                 &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
                             (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
@@ -7580,7 +7954,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
-        cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+        cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
                 &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
                             (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
@@ -7650,10 +8024,11 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 #ifdef GGML_CUDA_FORCE_DMMV
             const bool use_mul_mat_vec_q = false;
 #else
-            const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+            const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;
 #endif // GGML_CUDA_FORCE_DMMV
 
             if (use_mul_mat_vec_q) {
+                // NOTE: this kernel does not support ggml_nrows(src1) > 1
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
             } else {
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
@@ -7678,42 +8053,255 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     }
 }
 
-static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
-}
+#if 0
+template<typename ... Srcs>
+static __global__ void k_compute_batched_ptrs_id(
+        const void ** ptrs_src, void ** ptrs_dst,
+        int ne12, int ne13,
+        int ne23,
+        int nb02, int nb03,
+        int nb12, int nb13,
+        int nb2, int nb3,
+        int r2, int r3,
+        ggml_type src0_type, half * src0_as_f16, int64_t src0_ne,
+        const half * src1_f16, half * dst_f16,
+        const int32_t * ids, const int id,
+        Srcs... src0s) {
+
+    int i = ids[id];
+
+    half * src0_f16;
+    const void * srcs_ar[] = { (const half *) src0s... };
+    if (src0_type == GGML_TYPE_F16) {
+        src0_f16 = (half *) srcs_ar[i];
+    } else {
+        src0_f16 = src0_as_f16;
+        if (threadIdx.x == 0 && threadIdx.y == 0) {
+            const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type);
+            to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget);
+        }
+    }
 
-static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp);
+    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i13 >= ne13 || i12 >= ne12) {
+        return;
+    }
+
+    int i03 = i13 / r3;
+    int i02 = i12 / r2;
+
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02   + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)  dst_f16 + i12* nb2/2 + i13* nb3/2;
 }
 
-static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    const int64_t ne = ggml_nelements(src0);
-    GGML_ASSERT(ne == ggml_nelements(src1));
+static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
+    const struct ggml_tensor * ids = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src00 = dst->src[2];
 
-    GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
-    GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
+    const int id = dst->op_params[0];
 
-    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
-    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
+    GGML_ASSERT(!ggml_is_transposed(src00));
+    GGML_ASSERT(!ggml_is_transposed(src1));
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_ASSERT(src00->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
-    const int64_t nb00 = src0->nb[0];
-    const int64_t nb01 = src0->nb[1];
-    const int64_t nb02 = src0->nb[2];
+    const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
+    const int64_t ne01 = src00->ne[1];
+    const int64_t ne02 = src00->ne[2];
+    const int64_t ne03 = src00->ne[3];
+
+    //const int64_t nb01 = src00->nb[1];
+    const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02);
+    const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03);
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
-    GGML_ASSERT(src1->ne[3] == 1);
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
 
-    const int64_t nb10 = src1->nb[0];
-    const int64_t nb11 = src1->nb[1];
-    const int64_t nb12 = src1->nb[2];
+    //const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
+    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
 
-    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
-    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+    const int64_t ne1 = ggml_nelements(src1);
+    const int64_t ne  = ggml_nelements(dst);
+
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
+
+    //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    //void * src0_ddq = src0_extra->data_device[g_main_device];
+    //half * src0_as_f16 = (half *) src0_ddq;
+
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+    // convert src1 to fp16
+    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+    GGML_ASSERT(to_fp16_cuda != nullptr);
+
+    size_t src1_as = 0;
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
+
+    size_t dst_as = 0;
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    const half alpha_f16 = 1.0f;
+    const half beta_f16  = 0.0f;
+
+    // use cublasGemmBatchedEx
+    const int ne23 = ne12*ne13;
+
+    const void ** ptrs_src = nullptr;
+          void ** ptrs_dst = nullptr;
+
+    size_t ptrs_src_s = 0;
+    size_t ptrs_dst_s = 0;
+
+    ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
+    ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
+
+    int64_t src0_ne = ggml_nelements(src00);
+    half * src0_as_f16 = nullptr;
+    size_t src0_as = 0;
+    if (src00->type != GGML_TYPE_F16) {
+        src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as);
+    }
+
+    static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6");
+    dim3 block_dims(ne13, ne12);
+    k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>(
+            ptrs_src, ptrs_dst,
+            ne12, ne13,
+            ne23,
+            ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half),
+            nb12, nb13,
+            dst->nb[2], dst->nb[3],
+            r2, r3,
+            src00->type, src0_as_f16, src0_ne,
+            src1_as_f16, dst_f16,
+            (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id,
+            dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr,
+            dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr,
+            dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr,
+            dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr
+    );
+    CUDA_CHECK(cudaGetLastError());
+
+    CUBLAS_CHECK(
+    cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
+            ne01, ne11, ne10,
+            &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00,
+                        (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10,
+            &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
+            ne23,
+            CUBLAS_COMPUTE_16F,
+            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+    if (src0_as != 0) {
+        ggml_cuda_pool_free(src0_as_f16, src0_as);
+    }
+    if (ptrs_src_s != 0) {
+        ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
+    }
+    if (ptrs_dst_s != 0) {
+        ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
+    }
+
+    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+    ggml_cuda_pool_free(src1_as_f16, src1_as);
+    ggml_cuda_pool_free(dst_f16, dst_as);
+}
+#endif
+
+static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
+#if 0
+//#ifdef CUDA_USE_TENSOR_CORES
+//    const bool use_tensor_cores = true;
+//#else
+//    const bool use_tensor_cores = false;
+//#endif
+
+    ggml_cuda_mul_mat_id_cublas(dst);
+
+    // TODO: mmq/mmv support
+#else
+    const struct ggml_tensor * ids = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const int id = dst->op_params[0];
+
+    int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+
+    int32_t a_id;
+    CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+    CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+    GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
+    const struct ggml_tensor * src0 = dst->src[a_id + 2];
+
+    ggml_cuda_mul_mat(src0, src1, dst);
+#endif
+
+    (void) _src0;
+    (void) _src1;
+}
+
+static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
+}
+
+static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp);
+}
+
+static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const int64_t ne = ggml_nelements(src0);
+    GGML_ASSERT(ne == ggml_nelements(src1));
+
+    GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
+
+    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
+    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    GGML_ASSERT(src0->ne[3] == 1);
+
+    const int64_t nb00 = src0->nb[0];
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    GGML_ASSERT(src1->ne[3] == 1);
+
+    const int64_t nb10 = src1->nb[0];
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2];
+
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     const ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     const ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
@@ -7722,14 +8310,17 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
 
     if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7740,6 +8331,7 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
 }
 
 static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    // TODO: why do we pass dst as src1 here?
     ggml_cuda_cpy(src0, dst, nullptr);
     (void) src1;
 }
@@ -7765,6 +8357,16 @@ static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
 }
 
+static void ggml_cuda_sum_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sum_rows);
+}
+
+static void ggml_cuda_argsort(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_argsort);
+}
+
 static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -8020,8 +8622,9 @@ void ggml_cuda_set_main_device(const int main_device) {
                 main_device, g_device_count, g_main_device);
         return;
     }
-    g_main_device = main_device;
-    if (g_device_count > 1) {
+
+    if (g_main_device != main_device && g_device_count > 1) {
+        g_main_device = main_device;
         cudaDeviceProp prop;
         CUDA_CHECK(cudaGetDeviceProperties(&prop, g_main_device));
         fprintf(stderr, "%s: using device %d (%s) as main device\n", __func__, g_main_device, prop.name);
@@ -8047,7 +8650,7 @@ void ggml_cuda_free_scratch() {
 }
 
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
-    if (!g_cublas_loaded) { return false; }
+    if (!g_cublas_loaded) return false;
 
     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
@@ -8083,6 +8686,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_MUL:
             func = ggml_cuda_mul;
             break;
+        case GGML_OP_DIV:
+            func = ggml_cuda_div;
+            break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(tensor)) {
                 case GGML_UNARY_OP_GELU:
@@ -8096,7 +8702,8 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
                     break;
                 default:
                     return false;
-            } break;
+            }
+            break;
         case GGML_OP_NORM:
             func = ggml_cuda_norm;
             break;
@@ -8109,6 +8716,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_mul_mat;
             break;
+        case GGML_OP_MUL_MAT_ID:
+            if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[2], tensor->src[1], tensor)) {
+                return false;
+            }
+            func = ggml_cuda_mul_mat_id;
+            break;
         case GGML_OP_SCALE:
             func = ggml_cuda_scale;
             break;
@@ -8148,6 +8761,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_IM2COL:
             func = ggml_cuda_im2col;
             break;
+        case GGML_OP_SUM_ROWS:
+            func = ggml_cuda_sum_rows;
+            break;
+        case GGML_OP_ARGSORT:
+            func = ggml_cuda_argsort;
+            break;
         default:
             return false;
     }
@@ -8164,7 +8783,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
 
 int ggml_cuda_get_device_count() {
     int device_count;
-    CUDA_CHECK(cudaGetDeviceCount(&device_count));
+    if (cudaGetDeviceCount(&device_count) != cudaSuccess) {
+        return 0;
+    }
     return device_count;
 }
 
@@ -8180,27 +8801,16 @@ void ggml_cuda_get_device_description(int device, char * description, size_t des
 
 #define UNUSED GGML_UNUSED
 
-struct ggml_backend_context_cuda {
-};
-
-static const char * ggml_backend_cuda_name(ggml_backend_t backend) {
-    return GGML_CUDA_NAME;
-
-    UNUSED(backend);
-}
-
-static void ggml_backend_cuda_free(ggml_backend_t backend) {
-    ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
-    delete cuda_ctx;
-    delete backend;
-}
+// cuda buffer
 
 struct ggml_backend_buffer_context_cuda {
-    void * device;
-
+    int device;
+    void * dev_ptr = nullptr;
     ggml_tensor_extra_gpu * temp_tensor_extras = nullptr;
     size_t temp_tensor_extra_index = 0;
 
+    ggml_backend_buffer_context_cuda(int device, void * dev_ptr) : device(device), dev_ptr(dev_ptr) {}
+
     ~ggml_backend_buffer_context_cuda() {
         delete[] temp_tensor_extras;
     }
@@ -8221,41 +8831,20 @@ struct ggml_backend_buffer_context_cuda {
 
 static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
-    CUDA_CHECK(cudaFree(ctx->device));
+    CUDA_CHECK(cudaFree(ctx->dev_ptr));
     delete ctx;
 }
 
 static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
     ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
-    return ctx->device;
-}
-
-static size_t ggml_backend_cuda_buffer_get_alloc_size(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
-    int64_t row_low = 0;
-    int64_t row_high = ggml_nrows(tensor);
-    int64_t nrows_split = row_high - row_low;
-
-    size_t size = ggml_nbytes_split(tensor, nrows_split);
-
-    int64_t ne0 = tensor->ne[0];
-
-    if (ggml_is_quantized(tensor->type)) {
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
-                * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
-        }
-    }
-
-    return size;
-
-    UNUSED(buffer);
+    return ctx->dev_ptr;
 }
 
 static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
 
     if (tensor->view_src != NULL && tensor->view_offs == 0) {
-        assert(tensor->view_src->buffer->backend == buffer->backend);
+        assert(tensor->view_src->buffer->buft == buffer->buft); // TODO
         tensor->backend = tensor->view_src->backend;
         tensor->extra = tensor->view_src->extra;
         return;
@@ -8263,7 +8852,7 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g
 
     ggml_tensor_extra_gpu * extra = ctx->ggml_cuda_alloc_temp_tensor_extra();
 
-    extra->data_device[g_main_device] = tensor->data;
+    extra->data_device[ctx->device] = tensor->data;
 
     tensor->backend = GGML_BACKEND_GPU;
     tensor->extra = extra;
@@ -8275,64 +8864,208 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g
         int64_t nrows_split = row_high - row_low;
 
         size_t original_size = ggml_nbytes_split(tensor, nrows_split);
-        size_t padded_size = ggml_backend_cuda_buffer_get_alloc_size(tensor->buffer, tensor);
+        size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
 
         if (padded_size > original_size && tensor->view_src == nullptr) {
-            CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[g_main_device][0]));
+            CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[ctx->device][0]));
         }
     }
 
     UNUSED(buffer);
 }
 
+static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+
+    CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice));
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+
+    CUDA_CHECK(cudaMemcpy(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost));
+
+    UNUSED(buffer);
+}
+
 static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
-    /* .free_buffer    = */ ggml_backend_cuda_buffer_free_buffer,
-    /* .get_base       = */ ggml_backend_cuda_buffer_get_base,
-    /* .get_alloc_size = */ ggml_backend_cuda_buffer_get_alloc_size,
-    /* .init_tensor    = */ ggml_backend_cuda_buffer_init_tensor,
-    /* .free_tensor    = */ NULL,
+    /* .free_buffer     = */ ggml_backend_cuda_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_cuda_buffer_get_base,
+    /* .init_tensor     = */ ggml_backend_cuda_buffer_init_tensor,
+    /* .set_tensor      = */ ggml_backend_cuda_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cuda_buffer_get_tensor,
+    /* .cpy_tensor_from = */ NULL,
+    /* .cpy_tensor_to   = */ NULL,
 };
 
-static ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backend, size_t size) {
-    ggml_cuda_set_device(g_main_device);
+// cuda buffer type
 
-    ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda;
+static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    int device = (int) (intptr_t) buft->context;
+
+    ggml_cuda_set_device(device);
 
     size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
 
-    ggml_cuda_set_device(g_main_device);
-    CUDA_CHECK(cudaMalloc(&ctx->device, size));
+    void * dev_ptr;
+    CUDA_CHECK(cudaMalloc(&dev_ptr, size));
 
-    return ggml_backend_buffer_init(backend, cuda_backend_buffer_interface, ctx, size);
+    ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda(device, dev_ptr);
+
+    return ggml_backend_buffer_init(buft, cuda_backend_buffer_interface, ctx, size);
 }
 
-static size_t ggml_backend_cuda_get_alignment(ggml_backend_t backend) {
+static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return 128;
+
+    UNUSED(buft);
+}
+
+static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, ggml_tensor * tensor) {
+    int64_t row_low = 0;
+    int64_t row_high = ggml_nrows(tensor);
+    int64_t nrows_split = row_high - row_low;
+
+    size_t size = ggml_nbytes_split(tensor, nrows_split);
+
+    int64_t ne0 = tensor->ne[0];
+
+    if (ggml_is_quantized(tensor->type)) {
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
+                * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
+        }
+    }
+
+    return size;
+
+    UNUSED(buft);
+}
+
+static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return ggml_backend_is_cuda(backend);
+
+    UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_i cuda_backend_buffer_type_interface = {
+    /* .alloc_buffer     = */ ggml_backend_cuda_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_cuda_buffer_type_get_alignment,
+    /* .get_alloc_size   = */ ggml_backend_cuda_buffer_type_get_alloc_size,
+    /* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend,
+};
+
+ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda[GGML_CUDA_MAX_DEVICES];
+    static bool ggml_backend_buffer_type_cuda_initialized = false;
+    if (!ggml_backend_buffer_type_cuda_initialized) {
+        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; i++) {
+            ggml_backend_buffer_type_cuda[i] = {
+                /* .iface    = */ cuda_backend_buffer_type_interface,
+                /* .context  = */ (ggml_backend_buffer_type_context_t) (intptr_t) i,
+            };
+        }
+        ggml_backend_buffer_type_cuda_initialized = true;
+    }
+
+    return &ggml_backend_buffer_type_cuda[device];
+}
+
+// host buffer type
+
+static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
+    CUDA_CHECK(cudaFreeHost(ctx->dev_ptr));
+    delete ctx;
+}
+
+static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    void * ptr;
+    CUDA_CHECK(cudaMallocHost(&ptr, size));
+
+    // FIXME: this is a hack to avoid having to implement a new buffer type
+    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+    buffer->buft = buft;
+    buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
+
+    return buffer;
+
+    UNUSED(buft);
+}
+
+struct ggml_backend_buffer_type_i cuda_backend_host_buffer_type_interface = {
+    /* .alloc_buffer     = */ ggml_backend_cuda_host_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
+    /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
+    /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend,
+};
+
+ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda_host = {
+        /* .iface    = */ cuda_backend_host_buffer_type_interface,
+        /* .context  = */ nullptr,
+    };
+
+    return &ggml_backend_buffer_type_cuda_host;
+}
+
+// backend
+
+struct ggml_backend_context_cuda {
+    int device;
+};
+
+static const char * ggml_backend_cuda_name(ggml_backend_t backend) {
+    return GGML_CUDA_NAME;
+
     UNUSED(backend);
 }
 
+static void ggml_backend_cuda_free(ggml_backend_t backend) {
+    ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+    delete cuda_ctx;
+    delete backend;
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cuda_get_default_buffer_type(ggml_backend_t backend) {
+    ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+    return ggml_backend_cuda_buffer_type(cuda_ctx->device);
+}
+
 static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+    GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
 
-    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[g_main_device][0]));
-
-    UNUSED(backend);
+    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[cuda_ctx->device][0]));
 }
 
 static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+    GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
 
-    CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
-
-    UNUSED(backend);
+    CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[cuda_ctx->device][0]));
 }
 
 static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
-    CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+    ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+    CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[cuda_ctx->device][0]));
 
     UNUSED(backend);
 }
@@ -8346,14 +9079,14 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
     UNUSED(cgraph);
 }
 
-[[noreturn]] static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
     GGML_ASSERT(!"not implemented");
 
     UNUSED(backend);
     UNUSED(plan);
 }
 
-[[noreturn]] static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
     GGML_ASSERT(!"not implemented");
 
     UNUSED(backend);
@@ -8361,7 +9094,9 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
 }
 
 static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-    ggml_cuda_set_device(g_main_device);
+    ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
+
+    ggml_cuda_set_main_device(cuda_ctx->device);
 
     ggml_compute_params params = {};
     params.type = GGML_TASK_COMPUTE;
@@ -8369,13 +9104,18 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
-        if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE) {
+        if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE)
             continue;
-        }
+
         assert(node->backend == GGML_BACKEND_GPU);
+        assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+        assert(node->extra != nullptr);
+
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             if (node->src[j] != nullptr) {
                 assert(node->src[j]->backend == GGML_BACKEND_GPU);
+                assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+                assert(node->src[j]->extra != nullptr);
             }
         }
 
@@ -8412,27 +9152,98 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
     UNUSED(backend);
 }
 
+static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+    switch (op->op) {
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_RELU:
+                    return true;
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID:
+            {
+                struct ggml_tensor * a;
+                struct ggml_tensor * b;
+                if (op->op == GGML_OP_MUL_MAT) {
+                    a = op->src[0];
+                    b = op->src[1];
+                } else {
+                    a = op->src[2];
+                    b = op->src[1];
+                }
+                if (a->ne[3] != b->ne[3]) {
+                    return false;
+                }
+                return true;
+            } break;
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_NORM:
+        case GGML_OP_REPEAT:
+        case GGML_OP_GET_ROWS:
+        case GGML_OP_DUP:
+        case GGML_OP_ADD:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_SCALE:
+        case GGML_OP_SQR:
+        case GGML_OP_CLAMP:
+        case GGML_OP_CPY:
+        case GGML_OP_CONT:
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_SOFT_MAX:
+        case GGML_OP_ROPE:
+        case GGML_OP_ALIBI:
+        case GGML_OP_IM2COL:
+        case GGML_OP_SUM_ROWS:
+        case GGML_OP_ARGSORT:
+            return true;
+        default:
+            return false;
+    }
+
+    UNUSED(backend);
+}
+
 static ggml_backend_i cuda_backend_i = {
-    /* .get_name            = */ ggml_backend_cuda_name,
-    /* .free                = */ ggml_backend_cuda_free,
-    /* .alloc_buffer        = */ ggml_backend_cuda_alloc_buffer,
-    /* .get_alignment       = */ ggml_backend_cuda_get_alignment,
-    /* .set_tensor_async    = */ ggml_backend_cuda_set_tensor_async,
-    /* .get_tensor_async    = */ ggml_backend_cuda_get_tensor_async,
-    /* .synchronize         = */ ggml_backend_cuda_synchronize,
-    /* .cpy_tensor_from     = */ nullptr,
-    /* .cpy_tensor_to       = */ nullptr,
-    /* .graph_plan_create   = */ ggml_backend_cuda_graph_plan_create,
-    /* .graph_plan_free     = */ ggml_backend_cuda_graph_plan_free,
-    /* .graph_plan_compute  = */ ggml_backend_cuda_graph_plan_compute,
-    /* .graph_compute       = */ ggml_backend_cuda_graph_compute,
-    /* .supports_op         = */ nullptr,
+    /* .get_name                = */ ggml_backend_cuda_name,
+    /* .free                    = */ ggml_backend_cuda_free,
+    /* .get_default_buffer_type = */ ggml_backend_cuda_get_default_buffer_type,
+    /* .set_tensor_async        = */ ggml_backend_cuda_set_tensor_async,
+    /* .get_tensor_async        = */ ggml_backend_cuda_get_tensor_async,
+    /* .cpy_tensor_from_async   = */ NULL,
+    /* .cpy_tensor_to_async     = */ NULL,
+    /* .synchronize             = */ ggml_backend_cuda_synchronize,
+    /* .graph_plan_create       = */ ggml_backend_cuda_graph_plan_create,
+    /* .graph_plan_free         = */ ggml_backend_cuda_graph_plan_free,
+    /* .graph_plan_compute      = */ ggml_backend_cuda_graph_plan_compute,
+    /* .graph_compute           = */ ggml_backend_cuda_graph_compute,
+    /* .supports_op             = */ ggml_backend_cuda_supports_op,
 };
 
-ggml_backend_t ggml_backend_cuda_init() {
+ggml_backend_t ggml_backend_cuda_init(int device) {
     ggml_init_cublas(); // TODO: remove from ggml.c
 
-    ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda;
+    if (device < 0 || device >= ggml_cuda_get_device_count()) {
+        fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
+        return nullptr;
+    }
+
+    // not strictly necessary, but it may reduce the overhead of the first graph_compute
+    ggml_cuda_set_main_device(device);
+
+    ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda {
+        /* .device = */ device
+    };
 
     ggml_backend_t cuda_backend = new ggml_backend {
         /* .interface = */ cuda_backend_i,
@@ -8441,3 +9252,25 @@ ggml_backend_t ggml_backend_cuda_init() {
 
     return cuda_backend;
 }
+
+bool ggml_backend_is_cuda(ggml_backend_t backend) {
+    return backend->iface.get_name == ggml_backend_cuda_name;
+}
+
+static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * user_data) {
+    ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data);
+    return cuda_backend;
+
+    UNUSED(params);
+}
+
+extern "C" int ggml_backend_cuda_reg_devices() {
+    int device_count = ggml_cuda_get_device_count();
+    //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
+    for (int i = 0; i < device_count; i++) {
+        char name[128];
+        snprintf(name, sizeof(name), "%s%d", GGML_CUDA_NAME, i);
+        ggml_backend_register(name, ggml_backend_reg_cuda_init, ggml_backend_cuda_buffer_type(i), (void *) (intptr_t) i);
+    }
+    return device_count;
+}
index 528e66c33a20738ce185744ff5780203c869e4ad..cdb0c0c41618a0692bdad716af60b64cd0b70ad2 100644 (file)
@@ -49,7 +49,15 @@ GGML_API int    ggml_cuda_get_device_count(void);
 GGML_API void   ggml_cuda_get_device_description(int device, char * description, size_t description_size);
 
 // backend API
-GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
+GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
+
+GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
+GGML_API int  ggml_backend_cuda_get_device(ggml_backend_t backend);
+
+GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
+
+// pinned host buffer for use with CPU backend for faster copies between CPU and GPU
+GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
 
 #ifdef  __cplusplus
 }
index 06c07339e926999ed1688da0cf40d21ab08ae8ae..1f5610a86cfd9e8347952dbf415a176f1fd23baf 100644 (file)
@@ -232,7 +232,7 @@ bool   ggml_hash_contains      (const struct ggml_hash_set hash_set, struct ggml
 // returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
 size_t ggml_hash_find          (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
 
-// returns GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
+// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
 size_t ggml_hash_insert        (      struct ggml_hash_set hash_set, struct ggml_tensor * key);
 
 // return index, asserts if table is full
index de46b8804dc69fcdffb9df4d578efcf4dbaa4f95..bf52d9cd34da48246be9f343b89f4557b557296b 100644 (file)
@@ -52,11 +52,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx);
 void * ggml_metal_host_malloc(size_t n);
 void   ggml_metal_host_free  (void * data);
 
-// helper to check if the device supports a specific family
-// ideally, the user code should be doing these checks
-// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
-bool ggml_metal_supports_family(struct ggml_metal_context * ctx, int family);
-
 // set the number of command buffers to use
 void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
 
@@ -104,7 +99,11 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);
 GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
 
 GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
+GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
 
+// helper to check if the device supports a specific family
+// ideally, the user code should be doing these checks
+// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
 GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
 
 #ifdef __cplusplus
index 51382adea57aaf500ef7672995e42a2cc5e54828..f9bd69dc84bbe78128c5b9a01917decc40b627f0 100644 (file)
@@ -62,6 +62,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+    GGML_METAL_DECL_KERNEL(div);
+    GGML_METAL_DECL_KERNEL(div_row);
     GGML_METAL_DECL_KERNEL(scale);
     GGML_METAL_DECL_KERNEL(scale_4);
     GGML_METAL_DECL_KERNEL(silu);
@@ -112,15 +114,35 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
     GGML_METAL_DECL_KERNEL(rope_f32);
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(im2col_f16);
+    GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
+    GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
+    GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
+    GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
+    //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
+    //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
     GGML_METAL_DECL_KERNEL(concat);
     GGML_METAL_DECL_KERNEL(sqr);
+    GGML_METAL_DECL_KERNEL(sum_rows);
 
 #undef GGML_METAL_DECL_KERNEL
 };
@@ -164,12 +186,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
     }
 }
 
-
-
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
-    id <MTLDevice> device;
+    id<MTLDevice> device;
     NSString * s;
 
 #if TARGET_OS_OSX
@@ -215,6 +235,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 
             NSString * sourcePath;
             NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+
+            GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
+
             if (ggmlMetalPathResources) {
                 sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
             } else {
@@ -245,6 +268,29 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         }
     }
 
+#if TARGET_OS_OSX
+    // print MTL GPU family:
+    GGML_METAL_LOG_INFO("%s: GPU name:   %s\n", __func__, [[ctx->device name] UTF8String]);
+
+    // determine max supported GPU family
+    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+    for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+        if ([ctx->device supportsFamily:i]) {
+            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+            break;
+        }
+    }
+
+    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+    if (ctx->device.maxTransferRate != 0) {
+        GGML_METAL_LOG_INFO("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
+    } else {
+        GGML_METAL_LOG_INFO("%s: maxTransferRate               = built-in GPU\n", __func__);
+    }
+#endif
+
     // load kernels
     {
         NSError * error = nil;
@@ -266,6 +312,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(add_row);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
+        GGML_METAL_ADD_KERNEL(div);
+        GGML_METAL_ADD_KERNEL(div_row);
         GGML_METAL_ADD_KERNEL(scale);
         GGML_METAL_ADD_KERNEL(scale_4);
         GGML_METAL_ADD_KERNEL(silu);
@@ -317,43 +365,40 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
             GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
         }
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(im2col_f16);
+        GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
+        GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+        GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
+        GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
+        GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
+        //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
+        //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
         GGML_METAL_ADD_KERNEL(concat);
         GGML_METAL_ADD_KERNEL(sqr);
+        GGML_METAL_ADD_KERNEL(sum_rows);
 
 #undef GGML_METAL_ADD_KERNEL
     }
 
-#if TARGET_OS_OSX
-    // print MTL GPU family:
-    GGML_METAL_LOG_INFO("%s: GPU name:   %s\n", __func__, [[ctx->device name] UTF8String]);
-
-    // determine max supported GPU family
-    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
-    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
-    for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
-        if ([ctx->device supportsFamily:i]) {
-            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
-            break;
-        }
-    }
-
-    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
-    GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
-    if (ctx->device.maxTransferRate != 0) {
-        GGML_METAL_LOG_INFO("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
-    } else {
-        GGML_METAL_LOG_INFO("%s: maxTransferRate               = built-in GPU\n", __func__);
-    }
-#endif
-
     return ctx;
 }
 
@@ -367,6 +412,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(add_row);
     GGML_METAL_DEL_KERNEL(mul);
     GGML_METAL_DEL_KERNEL(mul_row);
+    GGML_METAL_DEL_KERNEL(div);
+    GGML_METAL_DEL_KERNEL(div_row);
     GGML_METAL_DEL_KERNEL(scale);
     GGML_METAL_DEL_KERNEL(scale_4);
     GGML_METAL_DEL_KERNEL(silu);
@@ -418,16 +465,36 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
         GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
     }
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
     GGML_METAL_DEL_KERNEL(im2col_f16);
+    GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
+    GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
+    GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
+    GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
+    //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
+    //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
     GGML_METAL_DEL_KERNEL(concat);
     GGML_METAL_DEL_KERNEL(sqr);
+    GGML_METAL_DEL_KERNEL(sum_rows);
 
 #undef GGML_METAL_DEL_KERNEL
 
@@ -459,10 +526,6 @@ void ggml_metal_host_free(void * data) {
     free(data);
 }
 
-bool ggml_metal_supports_family(struct ggml_metal_context * ctx, int family) {
-    return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
-}
-
 void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
     ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
 }
@@ -475,6 +538,13 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
     return ctx->concur_list;
 }
 
+// temporarily defined here for compatibility between ggml-backend and the old API
+struct ggml_backend_metal_buffer_context {
+    void * data;
+
+    id<MTLBuffer> metal;
+};
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -484,8 +554,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
 
     const int64_t tsize = ggml_nbytes(t);
 
-    if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
-        ctx = t->buffer->backend->context;
+    // compatibility with ggml-backend
+    if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) {
+        struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context;
+
+        const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
+
+        GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
+
+        *offs = (size_t) ioffs;
+
+        return buf_ctx->metal;
     }
 
     // find the view that contains the tensor fully
@@ -545,11 +624,11 @@ bool ggml_metal_add_buffer(
             ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
             if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1e6);
+                GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
                 return false;
             }
 
-            GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1e6);
+            GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
 
             ++ctx->n_buffers;
         } else {
@@ -569,11 +648,11 @@ bool ggml_metal_add_buffer(
                 ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
                 if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                    GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1e6);
+                    GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
                     return false;
                 }
 
-                GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1e6, i);
+                GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
                 if (i + size_step < size) {
                     GGML_METAL_LOG_INFO("\n");
                 }
@@ -584,8 +663,8 @@ bool ggml_metal_add_buffer(
 
 #if TARGET_OS_OSX
         GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
-                ctx->device.currentAllocatedSize / 1e6,
-                ctx->device.recommendedMaxWorkingSetSize / 1e6);
+                ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
+                ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
         if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
             GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
@@ -593,7 +672,7 @@ bool ggml_metal_add_buffer(
             GGML_METAL_LOG_INFO("\n");
         }
 #else
-        GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1e6);
+        GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
 #endif
     }
 
@@ -710,6 +789,51 @@ void ggml_metal_graph_find_concurrency(
     }
 }
 
+static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+    switch (op->op) {
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_GELU:
+                    return true;
+                default:
+                    return false;
+            }
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_CONCAT:
+        case GGML_OP_ADD:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+        case GGML_OP_SCALE:
+        case GGML_OP_SQR:
+        case GGML_OP_SUM_ROWS:
+        case GGML_OP_SOFT_MAX:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_NORM:
+        case GGML_OP_ALIBI:
+        case GGML_OP_ROPE:
+        case GGML_OP_IM2COL:
+        case GGML_OP_ARGSORT:
+        case GGML_OP_DUP:
+        case GGML_OP_CPY:
+        case GGML_OP_CONT:
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID:
+            return true;
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_GET_ROWS:
+            {
+                return op->ne[0] % 4 == 0;
+            }
+        default:
+            return false;
+    }
+}
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
                struct ggml_cgraph * gf) {
@@ -780,6 +904,8 @@ void ggml_metal_graph_compute(
                         } break;
                 }
 
+                GGML_ASSERT(ggml_metal_supports_op(dst));
+
                 const int64_t  ne00 = src0 ? src0->ne[0] : 0;
                 const int64_t  ne01 = src0 ? src0->ne[1] : 0;
                 const int64_t  ne02 = src0 ? src0->ne[2] : 0;
@@ -872,6 +998,8 @@ void ggml_metal_graph_compute(
                             [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_ADD:
+                    case GGML_OP_MUL:
+                    case GGML_OP_DIV:
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
                             GGML_ASSERT(ggml_is_contiguous(src1));
@@ -885,11 +1013,21 @@ void ggml_metal_graph_compute(
                                 GGML_ASSERT(ne11 == 1);
 
                                 nb = ne00 / 4;
-                                [encoder setComputePipelineState:ctx->pipeline_add_row];
+                                switch (dst->op) {
+                                    case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
+                                    case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
+                                    case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+                                    default: GGML_ASSERT(false);
+                                }
 
                                 bcast_row = true;
                             } else {
-                                [encoder setComputePipelineState:ctx->pipeline_add];
+                                switch (dst->op) {
+                                    case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
+                                    case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
+                                    case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+                                    default: GGML_ASSERT(false);
+                                }
                             }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -930,31 +1068,6 @@ void ggml_metal_graph_compute(
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                             }
                         } break;
-                    case GGML_OP_MUL:
-                        {
-                            GGML_ASSERT(ggml_is_contiguous(src0));
-                            GGML_ASSERT(ggml_is_contiguous(src1));
-
-                            // utilize float4
-                            GGML_ASSERT(ne00 % 4 == 0);
-                            const int64_t nb = ne00/4;
-
-                            if (ggml_nelements(src1) == ne10) {
-                                // src1 is a row
-                                GGML_ASSERT(ne11 == 1);
-                                [encoder setComputePipelineState:ctx->pipeline_mul_row];
-                            } else {
-                                [encoder setComputePipelineState:ctx->pipeline_mul];
-                            }
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&nb     length:sizeof(nb) atIndex:3];
-
-                            const int64_t n = ggml_nelements(dst)/4;
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        } break;
                     case GGML_OP_SCALE:
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1027,25 +1140,66 @@ void ggml_metal_graph_compute(
                             const int64_t n = ggml_nelements(dst);
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
+                    case GGML_OP_SUM_ROWS:
+                        {
+                            GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+
+                            [encoder setComputePipelineState:ctx->pipeline_sum_rows];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:19];
+                            [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:20];
+                            [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:21];
+                            [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:22];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:23];
+                            [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:24];
+                            [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:25];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
                     case GGML_OP_SOFT_MAX:
                         {
                             int nth = 32; // SIMD width
 
                             if (ne00%4 == 0) {
+                                while (nth < ne00/4 && nth < 256) {
+                                    nth *= 2;
+                                }
                                 [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
                             } else {
-                                do {
+                                while (nth < ne00 && nth < 1024) {
                                     nth *= 2;
-                                } while (nth <= ne00 && nth <= 1024);
-                                nth /= 2;
+                                }
                                 [encoder setComputePipelineState:ctx->pipeline_soft_max];
                             }
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+
+                            const float scale = ((float *) dst->op_params)[0];
+
+                            [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst    atIndex:2];
+                            [encoder setBytes:&ne00  length:sizeof(ne00)  atIndex:3];
+                            [encoder setBytes:&ne01  length:sizeof(ne01)  atIndex:4];
+                            [encoder setBytes:&ne02  length:sizeof(ne02)  atIndex:5];
+                            [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+                            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
@@ -1074,9 +1228,13 @@ void ggml_metal_graph_compute(
                     case GGML_OP_MUL_MAT:
                         {
                             GGML_ASSERT(ne00 == ne10);
-                            GGML_ASSERT(ne03 == ne13);
 
-                            const unsigned int gqa = ne12/ne02;
+                            // TODO: assert that dim2 and dim3 are contiguous
+                            GGML_ASSERT(ne12 % ne02 == 0);
+                            GGML_ASSERT(ne13 % ne03 == 0);
+
+                            const uint r2 = ne12/ne02;
+                            const uint r3 = ne13/ne03;
 
                             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                             // to the matrix-vector kernel
@@ -1111,7 +1269,7 @@ void ggml_metal_graph_compute(
                                 !ggml_is_transposed(src1) &&
                                 src1t == GGML_TYPE_F32 &&
                                 ne00 % 32 == 0 && ne00 >= 64 &&
-                                ne11 > ne11_mm_min) {
+                                (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
                                 //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
                                 switch (src0->type) {
                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break;
@@ -1141,9 +1299,10 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
                                 [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
                                 [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
-                                [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:13];
+                                [encoder setBytes:&r2      length:sizeof(r2)   atIndex:13];
+                                [encoder setBytes:&r3      length:sizeof(r3)   atIndex:14];
                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                             } else {
                                 int nth0 = 32;
                                 int nth1 = 1;
@@ -1179,90 +1338,60 @@ void ggml_metal_graph_compute(
                                         } break;
                                     case GGML_TYPE_Q4_0:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
                                         } break;
                                     case GGML_TYPE_Q4_1:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
                                         } break;
                                     case GGML_TYPE_Q5_0:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
                                         } break;
                                     case GGML_TYPE_Q5_1:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
                                         } break;
                                     case GGML_TYPE_Q8_0:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
                                         } break;
                                     case GGML_TYPE_Q2_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 2;
                                             nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
                                         } break;
                                     case GGML_TYPE_Q3_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 2;
                                             nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
                                         } break;
                                     case GGML_TYPE_Q4_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 4; //1;
                                             nth1 = 8; //32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
                                         } break;
                                     case GGML_TYPE_Q5_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 2;
                                             nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
                                         } break;
                                     case GGML_TYPE_Q6_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 2;
                                             nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
@@ -1291,32 +1420,125 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
                                 [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
-                                [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
+                                [encoder setBytes:&r2   length:sizeof(r2)   atIndex:17];
+                                [encoder setBytes:&r3   length:sizeof(r3)   atIndex:18];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                                     src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
                                     src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q4_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                                 }
                                 else if (src0t == GGML_TYPE_Q5_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q6_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
                                     int64_t ny = (ne11 + nrows - 1)/nrows;
-                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                            }
+                        } break;
+                    case GGML_OP_MUL_MAT_ID:
+                        {
+                            //GGML_ASSERT(ne00 == ne10);
+                            //GGML_ASSERT(ne03 == ne13);
+
+                            GGML_ASSERT(src0t == GGML_TYPE_I32);
+
+                            const int n_as = ne00;
+
+                            // TODO: make this more general
+                            GGML_ASSERT(n_as <= 8);
+
+                            struct ggml_tensor * src2 = gf->nodes[i]->src[2];
+
+                            const int64_t  ne20 = src2 ? src2->ne[0] : 0;
+                            const int64_t  ne21 = src2 ? src2->ne[1] : 0;
+                            const int64_t  ne22 = src2 ? src2->ne[2] : 0;
+                            const int64_t  ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+
+                            const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+                            const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+                            const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+                            const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+
+                            const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+                            GGML_ASSERT(!ggml_is_transposed(src2));
+                            GGML_ASSERT(!ggml_is_transposed(src1));
+
+                            GGML_ASSERT(ne20 % 32 == 0);
+                            // !!!!!!!!! TODO: this assert is probably required but not sure!
+                            //GGML_ASSERT(ne20 >= 64);
+                            GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+                            const uint r2 = ne12/ne22;
+                            const uint r3 = ne13/ne23;
+
+                            // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+                            // to the matrix-vector kernel
+                            int ne11_mm_min = 0;
+
+                            const int idx = ((int32_t *) dst->op_params)[0];
+
+                            // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+                            // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                                ne11 > ne11_mm_min) {
+                                switch (src2->type) {
+                                    case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32];  break;
+                                    case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32];  break;
+                                    case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
+                                    case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
+                                    case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
+                                    case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
+                                    case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
+                                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
+                                    case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
+                                    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
+                                    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
+                                    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+                                    default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
+                                }
+                                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
+                                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
+                                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
+                                [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:3];
+                                [encoder setBytes:&ne22    length:sizeof(ne22) atIndex:4];
+                                [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:5];
+                                [encoder setBytes:&nb22    length:sizeof(nb22) atIndex:6];
+                                [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
+                                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:8];
+                                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:9];
+                                [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
+                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
+                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
+                                [encoder setBytes:&r2      length:sizeof(r2)   atIndex:13];
+                                [encoder setBytes:&r3      length:sizeof(r3)   atIndex:14];
+                                [encoder setBytes:&idx     length:sizeof(idx)  atIndex:15];
+                                // TODO: how to make this an array? read Metal docs
+                                for (int j = 0; j < n_as; ++j) {
+                                    struct ggml_tensor * src_cur = dst->src[2 + j];
+
+                                    size_t offs_src_cur = 0;
+                                    id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+                                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
                                 }
+
+                                [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                             }
                         } break;
                     case GGML_OP_GET_ROWS:
@@ -1355,15 +1577,19 @@ void ggml_metal_graph_compute(
                             float eps;
                             memcpy(&eps, dst->op_params, sizeof(float));
 
-                            const int nth = MIN(512, ne00);
+                            int nth = 32; // SIMD width
+
+                            while (nth < ne00/4 && nth < 1024) {
+                                nth *= 2;
+                            }
 
                             [encoder setComputePipelineState:ctx->pipeline_rms_norm];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
-                            [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
+                            [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
 
@@ -1437,7 +1663,8 @@ void ggml_metal_graph_compute(
                             const int n_past     = ((int32_t *) dst->op_params)[0];
                             const int n_dims     = ((int32_t *) dst->op_params)[1];
                             const int mode       = ((int32_t *) dst->op_params)[2];
-                            const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
+                            // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+                            const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
 
                             float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                             memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
@@ -1537,18 +1764,48 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
                         } break;
+                    case GGML_OP_ARGSORT:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F32);
+                            GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+                            const int nrows = ggml_nrows(src0);
+
+                            enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+                            switch (order) {
+                                case GGML_SORT_ASC:  [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc];  break;
+                                case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
+                                default: GGML_ASSERT(false);
+                            };
+
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+                        } break;
                     case GGML_OP_DUP:
                     case GGML_OP_CPY:
                     case GGML_OP_CONT:
                         {
-                            const int nth = MIN(1024, ne00);
+                            GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+                            int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
 
                             switch (src0t) {
                                 case GGML_TYPE_F32:
                                     {
+                                        GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
+
                                         switch (dstt) {
-                                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
-                                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+                                            case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16];  break;
+                                            case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];  break;
+                                            case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
+                                            case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
+                                            case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
+                                            //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
+                                            //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
                                             default: GGML_ASSERT(false && "not implemented");
                                         };
                                     } break;
@@ -1623,81 +1880,150 @@ void ggml_metal_graph_compute(
 
 // backend interface
 
-static const char * ggml_backend_metal_name(ggml_backend_t backend) {
-    return "Metal";
+static id<MTLDevice> g_backend_device = nil;
+static int g_backend_device_ref_count = 0;
 
-    UNUSED(backend);
+static id<MTLDevice> ggml_backend_metal_get_device(void) {
+    if (g_backend_device == nil) {
+        g_backend_device = MTLCreateSystemDefaultDevice();
+    }
+
+    g_backend_device_ref_count++;
+
+    return g_backend_device;
 }
 
-static void ggml_backend_metal_free(ggml_backend_t backend) {
-    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
-    ggml_metal_free(ctx);
-    free(backend);
+static void ggml_backend_metal_free_device(void) {
+    assert(g_backend_device_ref_count > 0);
+
+    g_backend_device_ref_count--;
+
+    if (g_backend_device_ref_count == 0) {
+        [g_backend_device release];
+        g_backend_device = nil;
+    }
 }
 
 static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
-    return (void *)buffer->context;
+    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+    return ctx->data;
 }
 
 static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    free(buffer->context);
+    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+    [ctx->metal release];
+    ggml_backend_metal_free_device();
+
+    free(ctx->data);
+    free(ctx);
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy((char *)tensor->data + offset, data, size);
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+
     UNUSED(buffer);
 }
 
 static struct ggml_backend_buffer_i metal_backend_buffer_i = {
-    /* .free_buffer    = */ ggml_backend_metal_buffer_free_buffer,
-    /* .get_base       = */ ggml_backend_metal_buffer_get_base,
-    /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
-    /* .init_tensor    = */ NULL, // no initialization required
-    /* .free_tensor    = */ NULL, // no cleanup required
+    /* .free_buffer     = */ ggml_backend_metal_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_metal_buffer_get_base,
+    /* .init_tensor     = */ NULL,
+    /* .set_tensor      = */ ggml_backend_metal_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_metal_buffer_get_tensor,
+    /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
+    /* .cpy_tensor_to   = */ ggml_backend_metal_buffer_cpy_tensor_to,
 };
 
-static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
-    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
 
-    void * data = ggml_metal_host_malloc(size);
+    const size_t size_page = sysconf(_SC_PAGESIZE);
 
-    // TODO: set proper name of the buffers
-    ggml_metal_add_buffer(ctx, "backend", data, size, 0);
+    size_t size_aligned = size;
+    if ((size_aligned % size_page) != 0) {
+        size_aligned += (size_page - (size_aligned % size_page));
+    }
+
+    ctx->data  = ggml_metal_host_malloc(size);
+    ctx->metal = [ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
+                    length:size_aligned
+                    options:MTLResourceStorageModeShared
+                    deallocator:nil];
 
-    return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
+    return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
 }
 
-static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
+static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return 32;
-    UNUSED(backend);
+    UNUSED(buft);
 }
 
-static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
-    memcpy((char *)tensor->data + offset, data, size);
+static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
 
-    UNUSED(backend);
+    GGML_UNUSED(buft);
 }
 
-static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
-    memcpy(data, (const char *)tensor->data + offset, size);
+ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
+        /* .iface = */ {
+            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_metal_buffer_type_get_alignment,
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
+        },
+        /* .context = */ NULL,
+    };
 
-    UNUSED(backend);
+    return &ggml_backend_buffer_type_metal;
 }
 
-static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
+static const char * ggml_backend_metal_name(ggml_backend_t backend) {
+    return "Metal";
+
     UNUSED(backend);
 }
 
-static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
-    ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+static void ggml_backend_metal_free(ggml_backend_t backend) {
+    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+    ggml_metal_free(ctx);
+    free(backend);
+}
 
+static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
     UNUSED(backend);
 }
 
-static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
-    ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
+static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
+    return ggml_backend_metal_buffer_type();
 
     UNUSED(backend);
 }
@@ -1709,32 +2035,43 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
 }
 
 static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    return true;
+    return ggml_metal_supports_op(op);
+
     UNUSED(backend);
-    UNUSED(op);
 }
 
 static struct ggml_backend_i metal_backend_i = {
-    /* .get_name            = */ ggml_backend_metal_name,
-    /* .free                = */ ggml_backend_metal_free,
-    /* .alloc_buffer        = */ ggml_backend_metal_alloc_buffer,
-    /* .get_alignment       = */ ggml_backend_metal_get_alignment,
-    /* .set_tensor_async    = */ ggml_backend_metal_set_tensor_async,
-    /* .get_tensor_async    = */ ggml_backend_metal_get_tensor_async,
-    /* .synchronize         = */ ggml_backend_metal_synchronize,
-    /* .cpy_tensor_from     = */ ggml_backend_metal_cpy_tensor_from,
-    /* .cpy_tensor_to       = */ ggml_backend_metal_cpy_tensor_to,
-    /* .graph_plan_create   = */ NULL, // the metal implementation does not require creating graph plans atm
-    /* .graph_plan_free     = */ NULL,
-    /* .graph_plan_compute  = */ NULL,
-    /* .graph_compute       = */ ggml_backend_metal_graph_compute,
-    /* .supports_op         = */ ggml_backend_metal_supports_op,
+    /* .get_name                = */ ggml_backend_metal_name,
+    /* .free                    = */ ggml_backend_metal_free,
+    /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_from_async   = */ NULL,
+    /* .cpy_tensor_to_async     = */ NULL,
+    /* .synchronize             = */ ggml_backend_metal_synchronize,
+    /* .graph_plan_create       = */ NULL, // the metal implementation does not require creating graph plans atm
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_metal_graph_compute,
+    /* .supports_op             = */ ggml_backend_metal_supports_op,
 };
 
+// TODO: make a common log callback for all backends in ggml-backend
+static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
+    fprintf(stderr, "%s", msg);
+
+    UNUSED(level);
+    UNUSED(user_data);
+}
+
 ggml_backend_t ggml_backend_metal_init(void) {
-    struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+    ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
 
-    ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+    struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+
+    if (ctx == NULL) {
+        return NULL;
+    }
 
     ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
 
@@ -1751,13 +2088,26 @@ bool ggml_backend_is_metal(ggml_backend_t backend) {
 }
 
 void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
+    GGML_ASSERT(ggml_backend_is_metal(backend));
+
     struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
 
     ggml_metal_set_n_cb(ctx, n_cb);
 }
 
 bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
+    GGML_ASSERT(ggml_backend_is_metal(backend));
+
     struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
 
-    return ggml_metal_supports_family(ctx, family);
+    return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+}
+
+ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+
+ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
+    return ggml_backend_metal_init();
+
+    GGML_UNUSED(params);
+    GGML_UNUSED(user_data);
 }
index 5d1357cd72d4592782802a60222e81d2cacb8d8f..2f8ea22d66226040f9b9b1de4b70466994f6ff94 100644 (file)
@@ -3,6 +3,8 @@
 using namespace metal;
 
 #define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
 
 #define QK4_0 32
 #define QR4_0 2
@@ -39,8 +41,15 @@ typedef struct {
     int8_t  qs[QK8_0]; // quants
 } block_q8_0;
 
-// general-purpose kernel for addition of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
+enum ggml_sort_order {
+    GGML_SORT_ASC,
+    GGML_SORT_DESC,
+};
+
+// general-purpose kernel for addition, multiplication and division of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
 kernel void kernel_add(
         device const char * src0,
@@ -81,16 +90,111 @@ kernel void kernel_add(
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_mul(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant  int64_t & nb00,
+        constant  int64_t & nb01,
+        constant  int64_t & nb02,
+        constant  int64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant  int64_t & nb10,
+        constant  int64_t & nb11,
+        constant  int64_t & nb12,
+        constant  int64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant  int64_t & nb0,
+        constant  int64_t & nb1,
+        constant  int64_t & nb2,
+        constant  int64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
 
     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_div(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant  int64_t & nb00,
+        constant  int64_t & nb01,
+        constant  int64_t & nb02,
+        constant  int64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant  int64_t & nb10,
+        constant  int64_t & nb11,
+        constant  int64_t & nb12,
+        constant  int64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant  int64_t & nb0,
+        constant  int64_t & nb1,
+        constant  int64_t & nb2,
+        constant  int64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
 
-        src0_ptr += ntg.x*nb00;
-        src1_ptr += ntg.x*nb10;
-        dst_ptr  += ntg.x*nb0;
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
     }
 }
 
@@ -105,23 +209,22 @@ kernel void kernel_add_row(
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
-kernel void kernel_mul(
+kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
+        constant    int64_t & nb  [[buffer(27)]],
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig];
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
 
-// assumption: src1 is a row
-// broadcast src1 into src0
-kernel void kernel_mul_row(
+kernel void kernel_div_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb,
+        constant    int64_t & nb  [[buffer(27)]],
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig % nb];
+    dst[tpig] = src0[tpig] / src1[tpig % nb];
 }
 
 kernel void kernel_scale(
@@ -162,6 +265,54 @@ kernel void kernel_sqr(
     dst[tpig] = src0[tpig] * src0[tpig];
 }
 
+kernel void kernel_sum_rows(
+        device const float * src0,
+        device       float * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant  int64_t & nb00,
+        constant  int64_t & nb01,
+        constant  int64_t & nb02,
+        constant  int64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant  int64_t & nb10,
+        constant  int64_t & nb11,
+        constant  int64_t & nb12,
+        constant  int64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant  int64_t & nb0,
+        constant  int64_t & nb1,
+        constant  int64_t & nb2,
+        constant  int64_t & nb3,
+        uint3 tpig[[thread_position_in_grid]]) {
+    int64_t i3 = tpig.z;
+    int64_t i2 = tpig.y;
+    int64_t i1 = tpig.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
+
+    float row_sum = 0;
+
+    for (int64_t i0 = 0; i0 < ne00; i0++) {
+        row_sum += src_row[i0];
+    }
+
+    dst_row[0] = row_sum;
+}
+
 constant float GELU_COEF_A    = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
@@ -180,10 +331,12 @@ kernel void kernel_gelu(
 
 kernel void kernel_soft_max(
         device const float * src0,
+        device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
+        constant     float & scale,
         threadgroup float  * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
@@ -194,73 +347,77 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * psrc0 =        src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * pmask = src1 ? src1                                      + i01*ne00 : nullptr;
+    device       float * pdst  =        dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+    float lmax = -INFINITY;
 
-    for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
-        lmax = MAX(lmax, psrc0[i00]);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
     }
 
-    float max = simd_max(lmax);
-    if (tiisg == 0) {
-        buf[sgitg] = max;
-    }
+    // find the max value in the block
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
-       }
-    }
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    max = buf[0];
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
 
     // parallel sum
     float lsum = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp(psrc0[i00] - max);
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum += exp_psrc0;
-        // Remember the result of exp here. exp is expensive, so we really do not
-        // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
     float sum = simd_sum(lsum);
-    if (tiisg == 0) {
-        buf[sgitg] = sum;
-    }
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           buf[tpitg] += buf[tpitg + i];
-       }
-    }
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
 
-    sum = buf[0];
+    const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        pdst[i00] /= sum;
+        pdst[i00] *= inv_sum;
     }
 }
 
 kernel void kernel_soft_max_4(
         device const float * src0,
+        device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
+        constant     float & scale,
         threadgroup float  * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
@@ -271,64 +428,68 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * psrc4 =        (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * pmask = src1 ? (device const float4 *)(src1 +                                      i01*ne00) : nullptr;
+    device       float4 * pdst4 =        (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
-    float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+    float4 lmax4 = -INFINITY;
 
-    for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]);
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
-    float max = simd_max(lmax);
-    if (tiisg == 0) {
-        buf[sgitg] = max;
-    }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
-       }
-    }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
 
-    max = buf[0];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
 
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
 
     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
     float sum = simd_sum(lsum);
-    if (tiisg == 0) {
-        buf[sgitg] = sum;
-    }
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           buf[tpitg] += buf[tpitg + i];
-       }
-    }
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
 
-    sum = buf[0];
+    const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        pdst4[i00] /= sum;
+        pdst4[i00] *= inv_sum;
     }
 }
 
@@ -435,14 +596,13 @@ kernel void kernel_rms_norm(
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
         constant     float & eps,
-        threadgroup float  * sum [[threadgroup(0)]],
+        threadgroup float  * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint   ntg[[threads_per_threadgroup]]) {
-    device const float4 * x        = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
-    device const float  * x_scalar = (device const float  *) x;
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
 
     float4 sumf = 0;
     float all_sum = 0;
@@ -453,40 +613,30 @@ kernel void kernel_rms_norm(
     }
     all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
     all_sum = simd_sum(all_sum);
-    if (tiisg == 0) {
-        sum[sgitg] = all_sum;
-    }
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           sum[tpitg] += sum[tpitg + i];
-       }
-    }
-    if (tpitg == 0) {
-        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
-            sum[0] += x_scalar[i];
+        if (tiisg == 0) {
+            buf[sgitg] = all_sum;
         }
-        sum[0] /= ne00;
-    }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        all_sum = buf[tiisg];
+        all_sum = simd_sum(all_sum);
+    }
 
-    const float mean  = sum[0];
+    const float mean  = all_sum/ne00;
     const float scale = 1.0f/sqrt(mean + eps);
 
     device float4 * y = (device float4 *) (dst + tgpig*ne00);
-    device float * y_scalar = (device float *) y;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         y[i00] = x[i00] * scale;
     }
-    if (tpitg == 0) {
-        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
-            y_scalar[i00] = x_scalar[i00] * scale;
-        }
-    }
 }
 
 // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -576,15 +726,25 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4        // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2  // number of SIMD groups in a thread group
-#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
 //      giard against the number of rows not being divisible by
 //      N_DST, so this is another explicit assumption of the implementation.
 template<typename block_q_type, int nr, int nsg, int nw>
-void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
-                    int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
-                    uint3 tgpig, uint tiisg, uint sgitg) {
+void mul_vec_q_n_f32(
+        device const void  * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+                   uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
 
     const int r0 = tgpig.x;
@@ -593,7 +753,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
 
     const int first_row = (r0 * nsg + sgitg) * nr;
 
-    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
     device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
@@ -643,13 +806,14 @@ kernel void kernel_mul_mv_q4_0_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -661,13 +825,14 @@ kernel void kernel_mul_mv_q4_1_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -679,13 +844,14 @@ kernel void kernel_mul_mv_q5_0_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -697,13 +863,14 @@ kernel void kernel_mul_mv_q5_1_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 
@@ -718,9 +885,10 @@ kernel void kernel_mul_mv_q8_0_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -732,8 +900,14 @@ kernel void kernel_mul_mv_q8_0_f32(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
+
     const int first_row = (r0 * nsg + sgitg) * nr;
-    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
     device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
 
@@ -791,6 +965,8 @@ kernel void kernel_mul_mv_f32_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
 
@@ -798,7 +974,12 @@ kernel void kernel_mul_mv_f32_f32(
     const int64_t rb = tgpig.y*N_F32_F32;
     const int64_t im = tgpig.z;
 
-    device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const float * x = (device const float *) (src0 + offset0);
 
     if (ne00 < 128) {
         for (int row = 0; row < N_F32_F32; ++row) {
@@ -864,6 +1045,8 @@ kernel void kernel_mul_mv_f16_f16(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
 
@@ -871,7 +1054,12 @@ kernel void kernel_mul_mv_f16_f16(
     const int64_t rb = tgpig.y*N_F16_F16;
     const int64_t im = tgpig.z;
 
-    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half * x = (device const half *) (src0 + offset0);
 
     if (ne00 < 128) {
         for (int row = 0; row < N_F16_F16; ++row) {
@@ -935,6 +1123,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
 
@@ -942,7 +1132,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
     const int64_t r1 = tgpig.y;
     const int64_t im = tgpig.z;
 
-    device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half  * x = (device const half  *) (src0 + offset0);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
     float sumf = 0;
@@ -989,6 +1184,8 @@ kernel void kernel_mul_mv_f16_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {
 
@@ -996,7 +1193,12 @@ kernel void kernel_mul_mv_f16_f32(
     const int64_t rb = tgpig.y*N_F16_F32;
     const int64_t im = tgpig.z;
 
-    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half * x = (device const half *) (src0 + offset0);
 
     if (ne00 < 128) {
         for (int row = 0; row < N_F16_F32; ++row) {
@@ -1061,6 +1263,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1068,7 +1272,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
     const int64_t r0 = tgpig.x;
     const int64_t im = tgpig.z;
 
-    device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half4 * x4 = (device const half4 *) (src0 + offset0);
 
     for (int r1 = 0; r1 < nrows; ++r1) {
         device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
@@ -1120,17 +1329,21 @@ kernel void kernel_alibi_f32(
     const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
     const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+    const int64_t k = i3*ne3 + i2;
 
-    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
     float m_k;
-    if (i2 < n_heads_log2_floor) {
-        m_k = pow(m0, i2 + 1);
+    if (k < n_heads_log2_floor) {
+        m_k = pow(m0, k + 1);
     } else {
-        m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
+        m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
     }
+
+    device       char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
+    device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+        const  float   src_v = *(device float *)(src_row + i00*nb00);
+        device float * dst_v =  (device float *)(dst_row + i00*nb0);
+        *dst_v = i00 * m_k + src_v;
     }
 }
 
@@ -1335,6 +1548,58 @@ kernel void kernel_im2col_f16(
     }
 }
 
+// bitonic sort implementation following the CUDA kernels as reference
+typedef void (argsort_t)(
+        device const float * x,
+        device     int32_t * dst,
+        constant   int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+        device const float   * x,
+        device       int32_t * dst,
+        constant     int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
+    // bitonic sort
+    int col = tpitg[0];
+    int row = tgpig[1];
+
+    if (col >= ncols) return;
+
+    device const float   * x_row   = x   + row * ncols;
+    device       int32_t * dst_row = dst + row * ncols;
+
+    // initialize indices
+    if (col < ncols) {
+        dst_row[col] = col;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int k = 2; k <= ncols; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+        }
+    }
+}
+
+template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device       half * dst,
@@ -1460,6 +1725,197 @@ kernel void kernel_cpy_f32_f32(
     }
 }
 
+kernel void kernel_cpy_f32_q8_0(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
+
+    device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_0; j++) {
+            const float v = src[j];
+            amax = MAX(amax, fabs(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK8_0].d = d;
+
+        for (int j = 0; j < QK8_0; ++j) {
+            const float x0 = src[j]*id;
+
+            dst_data[i00/QK8_0].qs[j] = round(x0);
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q4_0(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
+
+    device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < QK4_0; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max  = v;
+            }
+        }
+
+        const float d = max / -8;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK4_0].d = d;
+
+        for (int j = 0; j < QK4_0/2; ++j) {
+            const float x0 = src[0       + j]*id;
+            const float x1 = src[QK4_0/2 + j]*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+            dst_data[i00/QK4_0].qs[j]  = xi0;
+            dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q4_1(
+        device const float * src0,
+        device        void * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
+
+    device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < QK4_1; j++) {
+            const float v = src[j];
+            if (min > v) min = v;
+            if (max < v) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK4_1].d = d;
+        dst_data[i00/QK4_1].m = min;
+
+        for (int j = 0; j < QK4_1/2; ++j) {
+            const float x0 = (src[0       + j] - min)*id;
+            const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+            dst_data[i00/QK4_1].qs[j]  = xi0;
+            dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
 kernel void kernel_concat(
     device const char * src0,
     device const char * src1,
@@ -1617,23 +2073,30 @@ kernel void kernel_mul_mv_q2_K_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    const int r2 = tgpig.z;
+    const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int ib_row = first_row * nb;
-    const uint offset0 = r2/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
 
@@ -1642,11 +2105,11 @@ kernel void kernel_mul_mv_q2_K_f32(
 #if QK_K == 256
     const int ix = tiisg/8;  // 0...3
     const int it = tiisg%8;  // 0...7
-    const int im = it/4;     // 0 or 1
+    const int iq = it/4;     // 0 or 1
     const int ir = it%4;     // 0...3
     const int is = (8*ir)/16;// 0 or 1
 
-    device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
+    device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
 
     for (int ib = ix; ib < nb; ib += 4) {
 
@@ -1658,8 +2121,8 @@ kernel void kernel_mul_mv_q2_K_f32(
             yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
         }
 
-        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*im + is;
-        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*iq + is;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
         device const half     * dh = &x[ib].d;
 
         for (int row = 0; row < N_DST; row++) {
@@ -1746,7 +2209,7 @@ kernel void kernel_mul_mv_q2_K_f32(
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
         }
     }
 }
@@ -1761,9 +2224,10 @@ kernel void kernel_mul_mv_q3_K_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1772,12 +2236,17 @@ kernel void kernel_mul_mv_q3_K_f32(
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
-    const int64_t r2 = tgpig.z;
+    const int64_t im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-    const uint offset0 = r2/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
 
     float yl[32];
 
@@ -1899,7 +2368,7 @@ kernel void kernel_mul_mv_q3_K_f32(
     }
     if (tiisg == 0) {
         for (int row = 0; row < 2; ++row) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
 }
@@ -1913,26 +2382,33 @@ kernel void kernel_mul_mv_q3_K_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
-    const int64_t r2 = tgpig.z;
+    const int64_t im = tgpig.z;
 
     const int row = 2 * r0 + sgitg;
-    const uint offset0 = r2/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
     const int ix = tiisg/4;
     const int il = 4 * (tiisg%4);// 0, 4, 8, 12
-    const int im = il/8;         // 0, 0, 1, 1
+    const int iq = il/8;         // 0, 0, 1, 1
     const int in = il%8;         // 0, 4, 0, 4
 
     float2 sum = {0.f, 0.f};
@@ -1952,7 +2428,7 @@ kernel void kernel_mul_mv_q3_K_f32(
         const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
 
         for (int l = 0; l < 4; l += 2) {
-            const uint16_t hm = h[l/2] >> im;
+            const uint16_t hm = h[l/2] >> iq;
             sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 :  4))
                     + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
                     + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
@@ -1968,7 +2444,7 @@ kernel void kernel_mul_mv_q3_K_f32(
 
     const float tot = simd_sum(sumf);
     if (tiisg == 0) {
-        dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
+        dst[r1*ne0 + im*ne0*ne1 + row] = tot;
     }
 
 }
@@ -1986,10 +2462,11 @@ kernel void kernel_mul_mv_q4_K_f32(
         constant   int64_t & ne12 [[buffer(11)]],
         constant   int64_t & ne0  [[buffer(15)]],
         constant   int64_t & ne1  [[buffer(16)]],
-        constant   uint    & gqa  [[buffer(17)]],
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
@@ -1997,26 +2474,32 @@ kernel void kernel_mul_mv_q4_K_f32(
 
     const int ix = tiisg/8;  // 0...3
     const int it = tiisg%8;  // 0...7
-    const int im = it/4;     // 0 or 1
+    const int iq = it/4;     // 0 or 1
     const int ir = it%4;     // 0...3
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    const int r2 = tgpig.z;
+    const int im = tgpig.z;
     //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int first_row = r0 * N_DST;
     const int ib_row = first_row * nb;
-    const uint offset0 = r2/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
     float yl[16];
     float yh[16];
     float sumf[N_DST]={0.f}, all_sum;
 
     const int step = sizeof(block_q4_K) * nb / 2;
 
-    device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
+    device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
 
     uint16_t sc16[4];
     thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
@@ -2031,8 +2514,8 @@ kernel void kernel_mul_mv_q4_K_f32(
             yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
         }
 
-        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
-        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
+        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
         device const half     * dh = &x[ib].d;
 
         for (int row = 0; row < N_DST; row++) {
@@ -2076,7 +2559,7 @@ kernel void kernel_mul_mv_q4_K_f32(
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
         }
     }
 }
@@ -2090,9 +2573,10 @@ kernel void kernel_mul_mv_q4_K_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2103,12 +2587,18 @@ kernel void kernel_mul_mv_q4_K_f32(
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    const int r2 = tgpig.z;
+    const int im = tgpig.z;
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int ib_row = first_row * nb;
-    const uint offset0 = r2/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
     float yl[8];
     float yh[8];
     float sumf[N_DST]={0.f}, all_sum;
@@ -2164,7 +2654,7 @@ kernel void kernel_mul_mv_q4_K_f32(
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
+            dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
         }
     }
 }
@@ -2179,9 +2669,10 @@ kernel void kernel_mul_mv_q5_K_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2190,12 +2681,17 @@ kernel void kernel_mul_mv_q5_K_f32(
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
-    const int r2 = tgpig.z;
+    const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-    const uint offset0 = r2/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
 
     float sumf[2]={0.f};
 
@@ -2211,15 +2707,15 @@ kernel void kernel_mul_mv_q5_K_f32(
 
     const int tid = tiisg/4;
     const int ix  = tiisg%4;
-    const int im  = tid/4;
+    const int iq  = tid/4;
     const int ir  = tid%4;
     const int n   = 8;
 
     const int l0 = n*ir;
-    const int q_offset = 32*im + l0;
-    const int y_offset = 64*im + l0;
+    const int q_offset = 32*iq + l0;
+    const int y_offset = 64*iq + l0;
 
-    const uint8_t hm1 = 1u << (2*im);
+    const uint8_t hm1 = 1u << (2*iq);
     const uint8_t hm2 = hm1 << 1;
     const uint8_t hm3 = hm1 << 4;
     const uint8_t hm4 = hm2 << 4;
@@ -2234,7 +2730,7 @@ kernel void kernel_mul_mv_q5_K_f32(
         device const uint8_t * q1 = x[i].qs + q_offset;
         device const uint8_t * qh = x[i].qh + l0;
         device const half * dh = &x[i].d;
-        device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
+        device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
 
         device const float * y2 = y1 + 128;
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
@@ -2290,7 +2786,7 @@ kernel void kernel_mul_mv_q5_K_f32(
 
     const int il = 4 * (tiisg/8);  // 0, 4, 8, 12
     const int ix = tiisg%8;
-    const int im = il/8;         // 0, 0, 1, 1
+    const int iq = il/8;         // 0, 0, 1, 1
     const int in = il%8;         // 0, 4, 0, 4
 
     device const float * y = yy + ix*QK_K + il;
@@ -2315,7 +2811,7 @@ kernel void kernel_mul_mv_q5_K_f32(
 
             float2 acc = {0.f, 0.f};
             for (int l = 0; l < 4; ++l) {
-                const uint8_t hl = h[l] >> im;
+                const uint8_t hl = h[l] >> iq;
                 acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
                         + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
                 acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
@@ -2337,7 +2833,7 @@ kernel void kernel_mul_mv_q5_K_f32(
     for (int row = 0; row < 2; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
         }
     }
 
@@ -2352,9 +2848,10 @@ kernel void kernel_mul_mv_q6_K_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2368,12 +2865,17 @@ kernel void kernel_mul_mv_q6_K_f32(
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
-    const int r2 = tgpig.z;
+    const int     im = tgpig.z;
 
     const int row = 2 * r0 + sgitg;
-    const uint offset0 = r2/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
 
     float sumf = 0;
 
@@ -2439,7 +2941,7 @@ kernel void kernel_mul_mv_q6_K_f32(
 
     const float tot = simd_sum(sumf);
     if (tiisg == 0) {
-        dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
+        dst[r1*ne0 + im*ne0*ne1 + row] = tot;
     }
 }
 
@@ -2749,24 +3251,25 @@ kernel void kernel_get_rows(
 
 // each block_q contains 16*nl weights
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-kernel void kernel_mul_mm(device const  uchar * src0,
-                          device const  uchar * src1,
-                          device        float * dst,
-                          constant    int64_t & ne00,
-                          constant    int64_t & ne02,
-                          constant    int64_t & nb01,
-                          constant    int64_t & nb02,
-                          constant    int64_t & ne12,
-                          constant    int64_t & nb10,
-                          constant    int64_t & nb11,
-                          constant    int64_t & nb12,
-                          constant    int64_t & ne0,
-                          constant    int64_t & ne1,
-                          constant       uint & gqa,
-                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
-                          uint3                 tgpig[[threadgroup_position_in_grid]],
-                          uint                  tiitg[[thread_index_in_threadgroup]],
-                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+void kernel_mul_mm_impl(device const  uchar * src0,
+                        device const  uchar * src1,
+                        device        float * dst,
+                        constant    int64_t & ne00,
+                        constant    int64_t & ne02,
+                        constant    int64_t & nb01,
+                        constant    int64_t & nb02,
+                        constant    int64_t & ne12,
+                        constant    int64_t & nb10,
+                        constant    int64_t & nb11,
+                        constant    int64_t & nb12,
+                        constant    int64_t & ne0,
+                        constant    int64_t & ne1,
+                        constant       uint & r2,
+                        constant       uint & r3,
+                        threadgroup   uchar * shared_memory [[threadgroup(0)]],
+                        uint3                 tgpig[[threadgroup_position_in_grid]],
+                        uint                  tiitg[[thread_index_in_threadgroup]],
+                        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     threadgroup half  * sa = (threadgroup half  *)(shared_memory);
     threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
@@ -2792,7 +3295,10 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 
     short il = (tiitg % THREAD_PER_ROW);
 
-    uint   offset0 = im/gqa*nb02;
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
     ushort offset1 = il/nl;
 
     device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
@@ -2876,14 +3382,116 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     }
 }
 
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm(device const  uchar * src0,
+                          device const  uchar * src1,
+                          device        float * dst,
+                          constant    int64_t & ne00,
+                          constant    int64_t & ne02,
+                          constant    int64_t & nb01,
+                          constant    int64_t & nb02,
+                          constant    int64_t & ne12,
+                          constant    int64_t & nb10,
+                          constant    int64_t & nb11,
+                          constant    int64_t & nb12,
+                          constant    int64_t & ne0,
+                          constant    int64_t & ne1,
+                          constant       uint & r2,
+                          constant       uint & r3,
+                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
+                          uint3                 tgpig[[threadgroup_position_in_grid]],
+                          uint                  tiitg[[thread_index_in_threadgroup]],
+                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mm_impl<block_q, nl, dequantize_func>(
+        src0,
+        src1,
+        dst,
+        ne00,
+        ne02,
+        nb01,
+        nb02,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        shared_memory,
+        tgpig,
+        tiitg,
+        sgitg);
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm_id(
+        device const int32_t * ids,
+        device const   uchar * src1,
+        device         float * dst,
+        constant     int64_t & ne00,
+        constant     int64_t & ne02,
+        constant     int64_t & nb01,
+        constant     int64_t & nb02,
+        constant     int64_t & ne12,
+        constant     int64_t & nb10,
+        constant     int64_t & nb11,
+        constant     int64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const   uchar * src00,
+        device const   uchar * src01,
+        device const   uchar * src02,
+        device const   uchar * src03,
+        device const   uchar * src04,
+        device const   uchar * src05,
+        device const   uchar * src06,
+        device const   uchar * src07,
+        threadgroup    uchar * shared_memory [[threadgroup(0)]],
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    kernel_mul_mm_impl<block_q, nl, dequantize_func>(
+        src0[ids[idx]],
+        src1,
+        dst,
+        ne00,
+        ne02,
+        nb01,
+        nb02,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        shared_memory,
+        tgpig,
+        tiitg,
+        sgitg);
+}
+
 #if QK_K == 256
 #define QK_NL 16
 #else
 #define QK_NL 4
 #endif
 
-typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
-                          constant uint64_t &, constant uint64_t &, uint, uint, uint);
+typedef void (get_rows_t)(
+        device const void * src0,
+        device const  int * src1,
+        device      float * dst,
+        constant  int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb1,
+        uint, uint, uint);
 
 template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
 template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
@@ -2912,8 +3520,10 @@ typedef void (mat_mm_t)(
         constant    int64_t & nb12,
         constant    int64_t & ne0,
         constant    int64_t & ne1,
-        constant       uint & gqa,
-        threadgroup uchar *, uint3, uint, uint);
+        constant       uint & r2,
+        constant       uint & r3,
+        threadgroup   uchar *,
+        uint3, uint, uint);
 
 template [[host_name("kernel_mul_mm_f32_f32")]]  kernel mat_mm_t kernel_mul_mm<float4x4,   1,     dequantize_f32>;
 template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1,     dequantize_f16>;
@@ -2927,3 +3537,44 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
 template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+
+typedef void (mat_mm_id_t)(
+        device const int32_t * ids,
+        device const   uchar * src1,
+        device         float * dst,
+        constant     int64_t & ne00,
+        constant     int64_t & ne02,
+        constant     int64_t & nb01,
+        constant     int64_t & nb02,
+        constant     int64_t & ne12,
+        constant     int64_t & nb10,
+        constant     int64_t & nb11,
+        constant     int64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const   uchar * src00,
+        device const   uchar * src01,
+        device const   uchar * src02,
+        device const   uchar * src03,
+        device const   uchar * src04,
+        device const   uchar * src05,
+        device const   uchar * src06,
+        device const   uchar * src07,
+        threadgroup    uchar *,
+        uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_id_f32_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<float4x4,   1,     dequantize_f32>;
+template [[host_name("kernel_mul_mm_id_f16_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<half4x4,    1,     dequantize_f16>;
+template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2,     dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2,     dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2,     dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2,     dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2,     dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
index 202bcb4853893c7332906418059c9111ddedda8d..496f9cdca542d348e8edc4c08e0c25e6fefe6fe2 100644 (file)
@@ -1,20 +1,18 @@
+#include "ggml.h"
 #include "ggml-opencl.h"
 
 #include <array>
 #include <atomic>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <limits>
 #include <sstream>
 #include <vector>
-#include <limits>
 
 #define CL_TARGET_OPENCL_VERSION 110
 #include <clblast.h>
 
-#include <stdlib.h>
-#include <stdio.h>
-#include <string.h>
-
-#include "ggml.h"
-
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
index cf2860b8cbd5924c3da459a5feb78541dd120323..7285d5f7fbcc00ce41b5d481b702ecc52c5671f6 100644 (file)
@@ -19,7 +19,7 @@
 #ifdef __wasm_simd128__
 #include <wasm_simd128.h>
 #else
-#ifdef __POWER9_VECTOR__
+#if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
 #include <altivec.h>
 #undef bool
 #define bool _Bool
diff --git a/ggml.c b/ggml.c
index 9612aa554028810741732462390c16b1e7a1473d..ca56f063c3a87440353e3efce428c18e003517fa 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -233,24 +233,6 @@ inline static void * ggml_aligned_malloc(size_t size) {
 #define UNUSED GGML_UNUSED
 #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
 
-//
-// tensor access macros
-//
-
-#define GGML_TENSOR_UNARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
-
-#define GGML_TENSOR_BINARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
-
 #if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
 #if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions
@@ -1613,6 +1595,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GROUP_NORM",
 
     "MUL_MAT",
+    "MUL_MAT_ID",
     "OUT_PROD",
 
     "SCALE",
@@ -1640,6 +1623,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "POOL_1D",
     "POOL_2D",
     "UPSCALE",
+    "ARGSORT",
 
     "FLASH_ATTN",
     "FLASH_FF",
@@ -1666,7 +1650,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1695,6 +1679,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "group_norm(x)",
 
     "X*Y",
+    "X[i]*Y",
     "X*Y",
 
     "x*v",
@@ -1722,6 +1707,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "pool_1d(x)",
     "pool_2d(x)",
     "upscale(x)",
+    "argsort(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
@@ -1748,10 +1734,28 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
+
+static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
+    "ABS",
+    "SGN",
+    "NEG",
+    "STEP",
+    "TANH",
+    "ELU",
+    "RELU",
+    "GELU",
+    "GELU_QUICK",
+    "SILU",
+    "LEAKY",
+};
+
+static_assert(GGML_UNARY_OP_COUNT == 11, "GGML_UNARY_OP_COUNT != 11");
+
+
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
 
@@ -1771,6 +1775,7 @@ static void ggml_setup_op_has_task_pass(void) {
 
         p[GGML_OP_ACC                    ] = true;
         p[GGML_OP_MUL_MAT                ] = true;
+        p[GGML_OP_MUL_MAT_ID             ] = true;
         p[GGML_OP_OUT_PROD               ] = true;
         p[GGML_OP_SET                    ] = true;
         p[GGML_OP_GET_ROWS_BACK          ] = true;
@@ -2023,6 +2028,20 @@ const char * ggml_op_symbol(enum ggml_op op) {
     return GGML_OP_SYMBOL[op];
 }
 
+const char * ggml_unary_op_name(enum ggml_unary_op op) {
+    return GGML_UNARY_OP_NAME[op];
+}
+
+const char * ggml_op_desc(const struct ggml_tensor * t) {
+    if (t->op == GGML_OP_UNARY) {
+        enum ggml_unary_op uop = ggml_get_unary_op(t);
+        return ggml_unary_op_name(uop);
+    }
+    else {
+        return ggml_op_name(t->op);
+    }
+}
+
 size_t ggml_element_size(const struct ggml_tensor * tensor) {
     return ggml_type_size(tensor->type);
 }
@@ -3154,9 +3173,7 @@ static struct ggml_tensor * ggml_add_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    // TODO: support less-strict constraint
-    //       GGML_ASSERT(ggml_can_repeat(b, a));
-    GGML_ASSERT(ggml_can_repeat_rows(b, a));
+    GGML_ASSERT(ggml_can_repeat(b, a));
 
     bool is_node = false;
 
@@ -3371,9 +3388,7 @@ static struct ggml_tensor * ggml_mul_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    // TODO: support less-strict constraint
-    //       GGML_ASSERT(ggml_can_repeat(b, a));
-    GGML_ASSERT(ggml_can_repeat_rows(b, a));
+    GGML_ASSERT(ggml_can_repeat(b, a));
 
     bool is_node = false;
 
@@ -3418,7 +3433,7 @@ static struct ggml_tensor * ggml_div_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_can_repeat(b, a));
 
     bool is_node = false;
 
@@ -4056,6 +4071,49 @@ struct ggml_tensor * ggml_mul_mat(
     return result;
 }
 
+// ggml_mul_mat_id
+
+struct ggml_tensor * ggml_mul_mat_id(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * as[],
+        struct ggml_tensor  * ids,
+        int                   id,
+        struct ggml_tensor  * b) {
+
+    int64_t n_as = ids->ne[0];
+
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_vector(ids));
+    GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
+    GGML_ASSERT(id >= 0 && id < n_as);
+
+    bool is_node = false;
+
+    if (as[0]->grad || b->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
+
+    ggml_set_op_params_i32(result, 0, id);
+
+    result->op   = GGML_OP_MUL_MAT_ID;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = ids;
+    result->src[1] = b;
+
+    for (int64_t i = 0; i < n_as; i++) {
+        struct ggml_tensor * a = as[i];
+        GGML_ASSERT(ggml_are_same_shape(as[0], a));
+        GGML_ASSERT(ggml_can_mul_mat(a, b));
+        GGML_ASSERT(!ggml_is_transposed(a));
+        result->src[i + 2] = a;
+    }
+
+    return result;
+}
+
 // ggml_out_prod
 
 struct ggml_tensor * ggml_out_prod(
@@ -4209,7 +4267,7 @@ struct ggml_tensor * ggml_set_2d_inplace(
         struct ggml_tensor *  b,
         size_t                nb1,
         size_t                offset) {
-    return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
+    return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
 }
 
 // ggml_cpy
@@ -4826,7 +4884,17 @@ struct ggml_tensor * ggml_diag_mask_zero_inplace(
 static struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale,
         bool                  inplace) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    if (mask) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
+        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+    }
+
     bool is_node = false;
 
     if (a->grad) {
@@ -4835,9 +4903,13 @@ static struct ggml_tensor * ggml_soft_max_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
+    float params[] = { scale };
+    ggml_set_op_params(result, params, sizeof(params));
+
     result->op   = GGML_OP_SOFT_MAX;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
+    result->src[1] = mask;
 
     return result;
 }
@@ -4845,13 +4917,21 @@ static struct ggml_tensor * ggml_soft_max_impl(
 struct ggml_tensor * ggml_soft_max(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_soft_max_impl(ctx, a, false);
+    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
 }
 
 struct ggml_tensor * ggml_soft_max_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_soft_max_impl(ctx, a, true);
+    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
+}
+
+struct ggml_tensor * ggml_soft_max_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale) {
+    return ggml_soft_max_impl(ctx, a, mask, scale, false);
 }
 
 // ggml_soft_max_back
@@ -5446,6 +5526,43 @@ struct ggml_tensor * ggml_upscale(
     return ggml_upscale_impl(ctx, a, scale_factor);
 }
 
+// ggml_argsort
+
+struct ggml_tensor * ggml_argsort(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_sort_order  order) {
+    bool is_node = false;
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, a->n_dims, a->ne);
+
+    ggml_set_op_params_i32(result, 0, (int32_t) order);
+
+    result->op   = GGML_OP_ARGSORT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_top_k
+
+struct ggml_tensor * ggml_top_k(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   k) {
+    GGML_ASSERT(a->ne[0] >= k);
+
+    struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_DESC);
+
+    result = ggml_view_4d(ctx, result,
+                k, result->ne[1], result->ne[2], result->ne[3],
+                   result->nb[1], result->nb[2], result->nb[3],
+                0);
+
+    return result;
+}
+
 // ggml_flash_attn
 
 struct ggml_tensor * ggml_flash_attn(
@@ -6805,7 +6922,7 @@ static void ggml_compute_forward_add_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -6838,16 +6955,19 @@ static void ggml_compute_forward_add_f32(
             const int64_t i13 = i03 % ne13;
             const int64_t i12 = i02 % ne12;
             const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
 
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
             float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
+            for (int64_t r = 0; r < nr0; ++r) {
 #ifdef GGML_USE_ACCELERATE
-            vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
+                vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
 #else
-            ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
+                ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 #endif
+            }
         }
     } else {
         // src1 is not contiguous
@@ -6864,8 +6984,9 @@ static void ggml_compute_forward_add_f32(
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
-            for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
             }
@@ -7585,7 +7706,7 @@ static void ggml_compute_forward_mul_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -7608,7 +7729,6 @@ static void ggml_compute_forward_mul_f32(
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(ne00 == ne10);
 
     if (nb10 == sizeof(float)) {
         for (int64_t ir = ith; ir < nr; ir += nth) {
@@ -7620,20 +7740,21 @@ static void ggml_compute_forward_mul_f32(
             const int64_t i13 = i03 % ne13;
             const int64_t i12 = i02 % ne12;
             const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
 
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
             float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
+            for (int64_t r = 0 ; r < nr0; ++r) {
 #ifdef GGML_USE_ACCELERATE
-            UNUSED(ggml_vec_mul_f32);
+                UNUSED(ggml_vec_mul_f32);
 
-            vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr,  1, ne00);
+                vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
 #else
-            ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
+                ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 #endif
-                // }
-            // }
+            }
         }
     } else {
         // src1 is not contiguous
@@ -7651,8 +7772,9 @@ static void ggml_compute_forward_mul_f32(
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
-            for (int64_t i0 = 0; i0 < ne00; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
+            for (int64_t i0 = 0; i0 < ne00; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
             }
@@ -7686,14 +7808,16 @@ static void ggml_compute_forward_div_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int nr  = ggml_nrows(src0);
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -7701,41 +7825,50 @@ static void ggml_compute_forward_div_f32(
     GGML_ASSERT(nb00 == sizeof(float));
 
     if (nb10 == sizeof(float)) {
-        for (int ir = 0; ir < nr; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
 
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
 #ifdef GGML_USE_ACCELERATE
-            UNUSED(ggml_vec_div_f32);
+                UNUSED(ggml_vec_div_f32);
 
-            vDSP_vdiv(
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                    ne0);
+                vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
 #else
-            ggml_vec_div_f32(ne0,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+                ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 #endif
-                // }
-            // }
+            }
         }
     } else {
         // src1 is not contiguous
-        for (int ir = 0; ir < nr; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne00; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
             }
@@ -8181,7 +8314,7 @@ static void ggml_compute_forward_repeat_f16(
         return;
     }
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
     const int nr0 = (int)(ne0/ne00);
@@ -8326,6 +8459,7 @@ static void ggml_compute_forward_concat_f32(
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
+    const int nth = params->nth;
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -8335,7 +8469,7 @@ static void ggml_compute_forward_concat_f32(
     GGML_ASSERT(nb10 == sizeof(float));
 
     for (int i3 = 0; i3 < ne3; i3++) {
-        for (int i2 = ith; i2 < ne2; i2++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
             if (i2 < ne02) { // src0
                 for (int i1 = 0; i1 < ne1; i1++) {
                     for (int i0 = 0; i0 < ne0; i0++) {
@@ -9495,6 +9629,8 @@ static void ggml_compute_forward_mul_mat(
             char * wdata = params->wdata;
             const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
 
+            assert(params->wsize >= ne11*ne12*ne13*row_size);
+
             for (int64_t i13 = 0; i13 < ne13; ++i13) {
                 for (int64_t i12 = 0; i12 < ne12; ++i12) {
                     for (int64_t i11 = 0; i11 < ne11; ++i11) {
@@ -9596,6 +9732,26 @@ static void ggml_compute_forward_mul_mat(
     }
 }
 
+// ggml_compute_forward_mul_mat_id
+
+static void ggml_compute_forward_mul_mat_id(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * ids = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    const int id = ggml_get_op_params_i32(dst, 0);
+
+    const int a_id = ((int32_t *)ids->data)[id];
+
+    GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
+
+    const struct ggml_tensor * src0 = dst->src[a_id + 2];
+
+    ggml_compute_forward_mul_mat(params, src0, src1, dst);
+}
+
 // ggml_compute_forward_out_prod
 
 static void ggml_compute_forward_out_prod_f32(
@@ -9611,10 +9767,12 @@ static void ggml_compute_forward_out_prod_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
+    GGML_ASSERT(ne0  == ne00);
+    GGML_ASSERT(ne1  == ne10);
+    GGML_ASSERT(ne2  == ne02);
     GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2  == ne12);
     GGML_ASSERT(ne3  == ne13);
+    GGML_ASSERT(ne03 == ne13);
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9625,18 +9783,25 @@ static void ggml_compute_forward_out_prod_f32(
     // GGML_ASSERT(nb1 <= nb2);
     // GGML_ASSERT(nb2 <= nb3);
 
-    GGML_ASSERT(ne0 == ne00);
-    GGML_ASSERT(ne1 == ne10);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
-
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
     // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
-    // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+    // TODO: #if defined(GGML_USE_CLBLAST)
+
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    bool use_blas = ggml_is_matrix(src0) &&
+        ggml_is_matrix(src1) &&
+        ggml_is_contiguous(src0) &&
+        (ggml_is_contiguous(src1) || ggml_is_transposed(src1));
+#endif
 
     if (params->type == GGML_TASK_INIT) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst
+        if (use_blas) {
+            return;
+        }
+#endif
         ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
         return;
     }
@@ -9645,6 +9810,50 @@ static void ggml_compute_forward_out_prod_f32(
         return;
     }
 
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    if (use_blas) {
+        if (params->ith != 0) { // All threads other than the first do no work.
+            return;
+        }
+        // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
+        // src0: (k,n)
+        // src1: (k,m)
+        // dst:  (m,n)
+        //
+        // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
+        // Also expressed as (major,minor)
+        // a: (m,k): so src1 transposed
+        // b: (k,n): so src0
+        // c: (m,n)
+        //
+        // However, if ggml_is_transposed(src1) is true, then
+        // src1->data already contains a transposed version, so sgemm mustn't
+        // transpose it further.
+
+        int n = src0->ne[0];
+        int k = src0->ne[1];
+        int m = src1->ne[0];
+
+        int transposeA, lda;
+
+        if (!ggml_is_transposed(src1)) {
+            transposeA = CblasTrans;
+            lda = m;
+        } else {
+            transposeA = CblasNoTrans;
+            lda = k;
+        }
+
+        float * a = (float *) ((char *) src1->data);
+        float * b = (float *) ((char *) src0->data);
+        float * c = (float *) ((char *) dst->data);
+
+        cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
+
+        return;
+    }
+#endif
+
     // dst[:,:,:,:] = 0
     // for i2,i3:
     //   for i1:
@@ -10498,20 +10707,25 @@ static void ggml_compute_forward_diag_mask_zero(
 static void ggml_compute_forward_soft_max_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    assert(ggml_is_contiguous(dst));
+    assert(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
+    float scale = 1.0f;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
     const int nth = params->nth;
 
+    const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
     const int nc = src0->ne[0];
     const int nr = ggml_nrows(src0);
 
@@ -10522,29 +10736,40 @@ static void ggml_compute_forward_soft_max_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
     for (int i1 = ir0; i1 < ir1; i1++) {
-        float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float *dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
+        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
+
+        // broadcast the mask across rows
+        float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+
+        ggml_vec_cpy_f32  (nc, wp, sp);
+        ggml_vec_scale_f32(nc, wp, scale);
+        if (mp) {
+            ggml_vec_acc_f32(nc, wp, mp);
+        }
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(sp[i]));
+            assert(!isnan(wp[i]));
         }
 #endif
 
         float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, sp);
+        ggml_vec_max_f32(nc, &max, wp);
 
         ggml_float sum = 0.0;
 
         uint16_t scvt;
         for (int i = 0; i < nc; i++) {
-            if (sp[i] == -INFINITY) {
+            if (wp[i] == -INFINITY) {
                 dp[i] = 0.0f;
             } else {
-                // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
-                ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
+                // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
+                ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
                 memcpy(&scvt, &s, sizeof(scvt));
                 const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
                 sum += (ggml_float)val;
@@ -10569,11 +10794,12 @@ static void ggml_compute_forward_soft_max_f32(
 static void ggml_compute_forward_soft_max(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_f32(params, src0, dst);
+                ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -11929,6 +12155,67 @@ static void ggml_compute_forward_upscale(
     }
 }
 
+// ggml_compute_forward_argsort
+
+static void ggml_compute_forward_argsort_f32(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
+
+    for (int64_t i = ith; i < nr; i += nth) {
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+        const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+        for (int64_t j = 0; j < ne0; j++) {
+            dst_data[j] = j;
+        }
+
+        // C doesn't have a functional sort, so we do a bubble sort instead
+        for (int64_t j = 0; j < ne0; j++) {
+            for (int64_t k = j + 1; k < ne0; k++) {
+                if ((order == GGML_SORT_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
+                    (order == GGML_SORT_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
+                    int32_t tmp = dst_data[j];
+                    dst_data[j] = dst_data[k];
+                    dst_data[k] = tmp;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_argsort(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_argsort_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_flash_attn
 
 static void ggml_compute_forward_flash_attn_f32(
@@ -13752,6 +14039,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
             } break;
+        case GGML_OP_MUL_MAT_ID:
+            {
+                ggml_compute_forward_mul_mat_id(params, tensor);
+            } break;
         case GGML_OP_OUT_PROD:
             {
                 ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor);
@@ -13810,7 +14101,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX:
             {
-                ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
+                ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
@@ -13856,6 +14147,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_upscale(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_ARGSORT:
+            {
+                ggml_compute_forward_argsort(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 const int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -14506,6 +14801,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 zero_table);
                 }
             } break;
+        case GGML_OP_MUL_MAT_ID:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_OUT_PROD:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -14844,6 +15143,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_ARGSORT:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 struct ggml_tensor * flash_grad = NULL;
@@ -15204,12 +15507,8 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
     return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false);
 }
 
-struct ggml_cgraph * ggml_graph_view(struct ggml_context * ctx, struct ggml_cgraph * cgraph0, int i0, int i1) {
-    const size_t obj_size = sizeof(struct ggml_cgraph);
-    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, obj_size);
-    struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
-
-    *cgraph = (struct ggml_cgraph) {
+struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) {
+    struct ggml_cgraph cgraph = {
         /*.size         =*/ 0,
         /*.n_nodes      =*/ i1 - i0,
         /*.n_leafs      =*/ 0,
@@ -15444,7 +15743,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_SUB:
-        case GGML_OP_DIV:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_LOG:
@@ -15477,10 +15775,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                     {
                         n_tasks = n_threads;
                     } break;
+                default:
+                    GGML_ASSERT(false);
             }
             break;
         case GGML_OP_SILU_BACK:
         case GGML_OP_MUL:
+        case GGML_OP_DIV:
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_RMS_NORM_BACK:
@@ -15518,6 +15819,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 }
 #endif
             } break;
+        case GGML_OP_MUL_MAT_ID:
+            {
+                // FIXME: blas
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_OUT_PROD:
             {
                 n_tasks = n_threads;
@@ -15537,7 +15843,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_DIAG_MASK_ZERO:
         case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_SOFT_MAX:
         case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ROPE:
         case GGML_OP_ROPE_BACK:
@@ -15553,6 +15858,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = 1; //TODO
             } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
+            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 n_tasks = n_threads;
@@ -15574,6 +15883,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = n_threads;
             } break;
+        case GGML_OP_ARGSORT:
+            {
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 n_tasks = n_threads;
@@ -15642,7 +15955,12 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         default:
             {
-                printf("%s: op %s not implemented\n", __func__, ggml_op_name(node->op));
+                fprintf(stderr, "%s: op not implemented: ", __func__);
+                if (node->op < GGML_OP_COUNT) {
+                    fprintf(stderr, "%s\n", ggml_op_name(node->op));
+                } else {
+                    fprintf(stderr, "%d\n", node->op);
+                }
                 GGML_ASSERT(false);
             } break;
     }
@@ -15783,18 +16101,16 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
     // thread scheduling for the different operations + work buffer size estimation
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        int n_tasks = 1;
-
         struct ggml_tensor * node = cgraph->nodes[i];
 
+        const int n_tasks = ggml_get_n_tasks(node, n_threads);
+
         size_t cur = 0;
 
         switch (node->op) {
             case GGML_OP_CPY:
             case GGML_OP_DUP:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                     }
@@ -15802,16 +16118,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
             case GGML_OP_ADD:
             case GGML_OP_ADD1:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
             case GGML_OP_ACC:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
                     }
@@ -15837,14 +16149,33 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         cur = ggml_type_size(vec_dot_type)*ggml_nelements(node->src[1])/ggml_blck_size(vec_dot_type);
                     }
                 } break;
+            case GGML_OP_MUL_MAT_ID:
+                {
+                    const struct ggml_tensor * a = node->src[2];
+                    const struct ggml_tensor * b = node->src[1];
+                    const enum ggml_type vec_dot_type = type_traits[a->type].vec_dot_type;
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                    if (ggml_compute_forward_mul_mat_use_blas(a, b, node)) {
+                        if (a->type != GGML_TYPE_F32) {
+                            // here we need memory just for single 2D matrix from src0
+                            cur = ggml_type_size(GGML_TYPE_F32)*(a->ne[0]*a->ne[1]);
+                        }
+                    } else
+#endif
+                    if (b->type != vec_dot_type) {
+                        cur = ggml_type_size(vec_dot_type)*ggml_nelements(b)/ggml_blck_size(vec_dot_type);
+                    }
+                } break;
             case GGML_OP_OUT_PROD:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
+            case GGML_OP_SOFT_MAX:
+                {
+                    cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+                } break;
             case GGML_OP_CONV_TRANSPOSE_1D:
                 {
                     GGML_ASSERT(node->src[0]->ne[3] == 1);
@@ -15870,10 +16201,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         GGML_ASSERT(false);
                     }
                 } break;
-            case GGML_OP_IM2COL:
-                {
-                    n_tasks = n_threads;
-                } break;
             case GGML_OP_CONV_TRANSPOSE_2D:
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // W
@@ -15890,8 +16217,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_ATTN:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
 
                     if (node->src[1]->type == GGML_TYPE_F32) {
@@ -15904,8 +16229,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_FF:
                 {
-                    n_tasks = n_threads;
-
                     if (node->src[1]->type == GGML_TYPE_F32) {
                         cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
                         cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
@@ -15916,8 +16239,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t    D = node->src[0]->ne[0];
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
                     const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
@@ -15932,8 +16253,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
             case GGML_OP_CROSS_ENTROPY_LOSS:
                 {
-                    n_tasks = n_threads;
-
                     cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
                 } break;
             case GGML_OP_COUNT:
@@ -17720,8 +18039,8 @@ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t *
             memcpy(&qh, &y[i].qh, sizeof(qh));
 
             for (int j = 0; j < QK5_0; j += 2) {
-                const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-                const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12));
+                const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
+                const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
 
                 // cast to 16 bins
                 const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
@@ -17750,8 +18069,8 @@ size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t *
             memcpy(&qh, &y[i].qh, sizeof(qh));
 
             for (int j = 0; j < QK5_1; j += 2) {
-                const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-                const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12));
+                const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
+                const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
 
                 // cast to 16 bins
                 const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
@@ -17941,6 +18260,7 @@ struct gguf_kv {
 
 struct gguf_header {
     char magic[4];
+
     uint32_t version;
     uint64_t n_tensors; // GGUFv2
     uint64_t n_kv;      // GGUFv2
@@ -18030,7 +18350,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
 
         for (uint32_t i = 0; i < sizeof(magic); i++) {
             if (magic[i] != GGUF_MAGIC[i]) {
-                fprintf(stderr, "%s: invalid magic characters %s.\n", __func__, magic);
+                fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
                 fclose(file);
                 return NULL;
             }
@@ -18045,7 +18365,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
     {
         strncpy(ctx->header.magic, magic, 4);
 
-
         ctx->kv    = NULL;
         ctx->infos = NULL;
         ctx->data  = NULL;
@@ -18399,24 +18718,29 @@ int gguf_find_key(const struct gguf_context * ctx, const char * key) {
 }
 
 const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     return ctx->kv[key_id].key.data;
 }
 
 enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     return ctx->kv[key_id].type;
 }
 
 enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
     return ctx->kv[key_id].value.arr.type;
 }
 
 const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
     return ctx->kv[key_id].value.arr.data;
 }
 
 const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
     struct gguf_kv * kv = &ctx->kv[key_id];
     struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
@@ -18424,70 +18748,90 @@ const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i
 }
 
 int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
     return ctx->kv[key_id].value.arr.n;
 }
 
 uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
     return ctx->kv[key_id].value.uint8;
 }
 
 int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
     return ctx->kv[key_id].value.int8;
 }
 
 uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
     return ctx->kv[key_id].value.uint16;
 }
 
 int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
     return ctx->kv[key_id].value.int16;
 }
 
 uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
     return ctx->kv[key_id].value.uint32;
 }
 
 int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
     return ctx->kv[key_id].value.int32;
 }
 
 float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
     return ctx->kv[key_id].value.float32;
 }
 
 uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
     return ctx->kv[key_id].value.uint64;
 }
 
 int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
     return ctx->kv[key_id].value.int64;
 }
 
 double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
     return ctx->kv[key_id].value.float64;
 }
 
 bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
     return ctx->kv[key_id].value.bool_;
 }
 
 const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
     return ctx->kv[key_id].value.str.data;
 }
 
+const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
+    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING);
+    return &ctx->kv[key_id].value;
+}
+
 int gguf_get_n_tensors(const struct gguf_context * ctx) {
     return ctx->header.n_tensors;
 }
diff --git a/ggml.h b/ggml.h
index 8e6b646066b7a488197becae814d17e504916194..a8f10cbd5c1d8ead7963644753cd3d8bdd776f05 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 #define GGML_ASSERT(x) \
     do { \
         if (!(x)) { \
-            fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
-            fflush(stderr); \
             fflush(stdout); \
+            fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
             ggml_print_backtrace(); \
-            exit(1); \
+            abort(); \
         } \
     } while (0)
 
     const type prefix##3 = (pointer)->array[3]; \
     GGML_UNUSED(prefix##3);
 
+#define GGML_TENSOR_UNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
+#define GGML_TENSOR_BINARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
 #ifdef  __cplusplus
 extern "C" {
 #endif
@@ -382,6 +395,7 @@ extern "C" {
         GGML_OP_GROUP_NORM,
 
         GGML_OP_MUL_MAT,
+        GGML_OP_MUL_MAT_ID,
         GGML_OP_OUT_PROD,
 
         GGML_OP_SCALE,
@@ -408,8 +422,8 @@ extern "C" {
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
-
         GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_ARGSORT,
 
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
@@ -449,7 +463,9 @@ extern "C" {
         GGML_UNARY_OP_GELU,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
-        GGML_UNARY_OP_LEAKY
+        GGML_UNARY_OP_LEAKY,
+
+        GGML_UNARY_OP_COUNT,
     };
 
     enum ggml_object_type {
@@ -632,6 +648,9 @@ extern "C" {
     GGML_API const char * ggml_op_name  (enum ggml_op   op);
     GGML_API const char * ggml_op_symbol(enum ggml_op   op);
 
+    GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
+    GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
+
     GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
 
     GGML_API bool    ggml_is_quantized(enum ggml_type type);
@@ -1028,6 +1047,15 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // indirect matrix multiplication
+    //  ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
+    GGML_API struct ggml_tensor * ggml_mul_mat_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * as[],
+            struct ggml_tensor  * ids,
+            int                   id,
+            struct ggml_tensor  * b);
+
     // A: m columns, n rows,
     // B: p columns, n rows,
     // result is m columns, p rows
@@ -1283,6 +1311,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // fused soft_max(a*scale + mask)
+    // mask is optional
+    GGML_API struct ggml_tensor * ggml_soft_max_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * mask,
+            float                 scale);
+
     GGML_API struct ggml_tensor * ggml_soft_max_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1513,6 +1549,23 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   scale_factor);
 
+    // sort rows
+    enum ggml_sort_order {
+        GGML_SORT_ASC,
+        GGML_SORT_DESC,
+    };
+
+    GGML_API struct ggml_tensor * ggml_argsort(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_sort_order  order);
+
+    // top k elements per row
+    GGML_API struct ggml_tensor * ggml_top_k(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   k);
+
     GGML_API struct ggml_tensor * ggml_flash_attn(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
@@ -1574,7 +1627,6 @@ extern "C" {
             int                   kh);
 
     // used in sam
-
     GGML_API struct ggml_tensor * ggml_add_rel_pos(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1749,7 +1801,7 @@ extern "C" {
     GGML_API struct ggml_cgraph * ggml_new_graph         (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
     GGML_API struct ggml_cgraph * ggml_new_graph_custom  (struct ggml_context * ctx, size_t size, bool grads);
     GGML_API struct ggml_cgraph * ggml_graph_dup         (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
-    GGML_API struct ggml_cgraph * ggml_graph_view        (struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i0, int i1);
+    GGML_API struct ggml_cgraph   ggml_graph_view        (struct ggml_cgraph * cgraph, int i0, int i1);
     GGML_API void                 ggml_graph_cpy         (struct ggml_cgraph * src, struct ggml_cgraph * dst);
     GGML_API void                 ggml_graph_reset       (struct ggml_cgraph * cgraph);  // zero grads
     GGML_API void                 ggml_graph_clear       (struct ggml_cgraph * cgraph);
@@ -2045,6 +2097,7 @@ extern "C" {
     GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
     GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
     GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
+    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
     GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int key_id);
     GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
     GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
index 971b8e6f4d712789dd2ad642ddfed4fa10ff2797..e709e29fd5006ad9e1149fef1287c91e7619a526 100644 (file)
@@ -1063,7 +1063,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
 #ifdef GGML_USE_CUBLAS
     if (params.use_gpu && ggml_cublas_loaded()) {
         WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
-        backend_gpu = ggml_backend_cuda_init();
+        backend_gpu = ggml_backend_cuda_init(0);
         if (!backend_gpu) {
             WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
         }
@@ -1077,8 +1077,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
         backend_gpu = ggml_backend_metal_init();
         if (!backend_gpu) {
             WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
-        }
-        if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
+        } else if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
             WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
             ggml_backend_free(backend_gpu);
             backend_gpu = NULL;
@@ -1346,10 +1345,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
 
             model.e_conv_1_w     = ggml_new_tensor_3d(ctx, vtype,         3, n_mels,     n_audio_state);
-            model.e_conv_1_b     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_audio_ctx, n_audio_state);
+            model.e_conv_1_b     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,         1,     n_audio_state);
 
             model.e_conv_2_w     = ggml_new_tensor_3d(ctx, vtype,         3, n_audio_state, n_audio_state);
-            model.e_conv_2_b     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,    n_audio_ctx,   n_audio_state);
+            model.e_conv_2_b     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,                1, n_audio_state);
 
             model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
             model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
@@ -1579,29 +1578,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
             auto tensor = model.tensors[name.data()];
 
-            const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias");
-
-            if (!is_conv_bias) {
-                if (ggml_nelements(tensor) != nelements) {
-                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                    WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
-                            __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
-                    return false;
-                }
+            if (ggml_nelements(tensor) != nelements) {
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
+                        __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
+                return false;
+            }
 
-                if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
-                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
-                            __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
-                    return false;
-                }
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                        __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
+                return false;
+            }
 
-                const size_t bpe = ggml_type_size(ggml_type(ttype));
+            const size_t bpe = ggml_type_size(ggml_type(ttype));
 
-                if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
-                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
-                            __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
-                    return false;
-                }
+            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                return false;
             }
 
             ggml_backend_t backend = wctx.backend;
@@ -1612,7 +1607,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 #ifdef GGML_USE_METAL
                 || ggml_backend_is_metal(backend)
 #endif
-                ) && !is_conv_bias) {
+                )) {
                 // for the CPU and Metal backend, we can read directly into the tensor
                 loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
                 BYTESWAP_TENSOR(tensor);
@@ -1620,24 +1615,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 // read into a temporary buffer first, then copy to device memory
                 read_buf.resize(ggml_nbytes(tensor));
 
-                // we repeat the 2 bias tensors along dim 0:
-                // [1, 512] -> [3000, 512] (conv1.bias)
-                // [1, 512] -> [1500, 512] (conv2.bias)
-                if (is_conv_bias) {
-                    loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]);
-
-                    float * data_f32 = (float *) read_buf.data();
-                    for (int64_t y = 0; y < tensor->ne[1]; ++y) {
-                        const int64_t yy = tensor->ne[1] - y - 1;
-                        const float val = data_f32[yy];
-
-                        for (int64_t x = 0; x < tensor->ne[0]; ++x) {
-                            data_f32[yy*tensor->ne[0] + x] = val;
-                        }
-                    }
-                } else {
-                    loader->read(loader->context, read_buf.data(), read_buf.size());
-                }
+                loader->read(loader->context, read_buf.data(), read_buf.size());
 
                 ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
             }
@@ -1737,20 +1715,12 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         // convolution + gelu
         {
             cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
-            if (n_ctx == hparams.n_audio_ctx) {
-                cur = ggml_add(ctx0, cur, model.e_conv_1_b);
-            } else {
-                cur = ggml_add(ctx0, cur, ggml_cont(ctx0, ggml_view_2d(ctx0, model.e_conv_1_b, cur->ne[0], cur->ne[1], model.e_conv_1_b->nb[1], 0)));
-            }
+            cur = ggml_add(ctx0, cur, model.e_conv_1_b);
 
             cur = ggml_gelu(ctx0, cur);
 
             cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
-            if (n_ctx == hparams.n_audio_ctx) {
-                cur = ggml_add(ctx0, cur, model.e_conv_2_b);
-            } else {
-                cur = ggml_add(ctx0, cur, ggml_cont(ctx0, ggml_view_2d(ctx0, model.e_conv_2_b, cur->ne[0], cur->ne[1], model.e_conv_2_b->nb[1], 0)));
-            }
+            cur = ggml_add(ctx0, cur, model.e_conv_2_b);
 
             cur = ggml_gelu(ctx0, cur);
         }