]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : whisper.cpp (Metal + ggml sched_yield fix + reduce ggml-alloc size) (#522)
authorGeorgi Gerganov <redacted>
Fri, 15 Sep 2023 16:07:30 +0000 (19:07 +0300)
committerGitHub <redacted>
Fri, 15 Sep 2023 16:07:30 +0000 (19:07 +0300)
ggml-ci

examples/whisper/whisper.cpp
include/ggml/ggml.h
scripts/sync-whisper.sh
src/ggml-alloc.c
src/ggml-metal.m
src/ggml-metal.metal
src/ggml.c

index 3192fbc640425b7d2714cc587971cb45e1d7e34c..1224be9bf51bcb4ae10dc5adfef9015aca807da0 100644 (file)
@@ -3,11 +3,16 @@
 #include "coreml/whisper-encoder.h"
 #endif
 
+#ifdef GGML_USE_METAL
+#  include "ggml-metal.h"
+#endif
+
 #ifdef WHISPER_USE_OPENVINO
 #include "openvino/whisper-openvino-encoder.h"
 #endif
 
 #include "ggml.h"
+#include "ggml-alloc.h"
 
 #include <algorithm>
 #include <cassert>
 #include <cstring>
 #include <fstream>
 #include <map>
+#include <set>
 #include <string>
 #include <thread>
 #include <vector>
 #include <regex>
 #include <random>
+#include <functional>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -114,8 +121,58 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 //#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 16
 
-#define WHISPER_USE_SCRATCH
-#define WHISPER_MAX_SCRATCH_BUFFERS 16
+//
+// ggml helpers
+//
+
+static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+
+    if (plan.work_size > 0) {
+        buf.resize(plan.work_size);
+        plan.work_data = buf.data();
+    }
+
+    ggml_graph_compute(graph, &plan);
+}
+
+// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
+// the idea is to represent the original matrix multiplication:
+//
+//   Z = X @ Y
+//
+// with the sum of two matrix multiplications:
+//
+//   Z = (X_0 @ Y_0) + (X_1 @ Y_1)
+//
+// here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad"
+// and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
+// general-purpose kernels
+//
+static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) {
+    // use padding only if dimension 0 is at least 8 times larger than the padding
+    // else we won't get much benefit from the optimization
+    const int n_pad_req = 8;
+
+    if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) {
+        return ggml_mul_mat(ctx, x, y);
+    }
+
+    struct ggml_tensor * x_0 = ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0);
+    struct ggml_tensor * x_1 = ggml_view_3d(ctx, x,  x->ne[0]%pad,      x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]);
+
+    struct ggml_tensor * y_0 = ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0);
+    struct ggml_tensor * y_1 = ggml_view_3d(ctx, y,  y->ne[0]%pad,      y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]);
+
+    return ggml_add(ctx,
+            ggml_mul_mat(ctx, x_0, y_0),
+            ggml_mul_mat(ctx, x_1, y_1));
+}
+
+// TODO: check if other platforms can benefit from this optimization
+#if defined(GGML_USE_METAL)
+#define ggml_mul_mat ggml_mul_mat_pad
+#endif
 
 // available whisper models
 enum e_model {
@@ -231,38 +288,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
 
 static const size_t MB = 1ull*1024*1024;
 
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
-    { MODEL_TINY,     62ull*MB },
-    { MODEL_BASE,     80ull*MB },
-    { MODEL_SMALL,   120ull*MB },
-    { MODEL_MEDIUM,  158ull*MB },
-    { MODEL_LARGE,   198ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
-    { MODEL_TINY,     18ull*MB },
-    { MODEL_BASE,     24ull*MB },
-    { MODEL_SMALL,    36ull*MB },
-    { MODEL_MEDIUM,   48ull*MB },
-    { MODEL_LARGE,    60ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
-    { MODEL_TINY,      4ull*MB },
-    { MODEL_BASE,      4ull*MB },
-    { MODEL_SMALL,     6ull*MB },
-    { MODEL_MEDIUM,    7ull*MB },
-    { MODEL_LARGE,     9ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
-    { MODEL_TINY,      4ull*MB },
-    { MODEL_BASE,      4ull*MB },
-    { MODEL_SMALL,     6ull*MB },
-    { MODEL_MEDIUM,    7ull*MB },
-    { MODEL_LARGE,     9ull*MB },
-};
-
+// TODO: avoid using GGUF
 static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
     { GGML_TYPE_F32,
         {
@@ -329,38 +355,6 @@ static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
     },
 };
 
-static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
-    { MODEL_TINY,      3ull*MB },
-    { MODEL_BASE,      6ull*MB },
-    { MODEL_SMALL,    16ull*MB },
-    { MODEL_MEDIUM,   43ull*MB },
-    { MODEL_LARGE,    71ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
-    { MODEL_TINY,      9ull*MB },
-    { MODEL_BASE,     18ull*MB },
-    { MODEL_SMALL,    53ull*MB },
-    { MODEL_MEDIUM,  141ull*MB },
-    { MODEL_LARGE,   235ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
-    { MODEL_TINY,     30ull*MB },
-    { MODEL_BASE,     38ull*MB },
-    { MODEL_SMALL,    56ull*MB },
-    { MODEL_MEDIUM,   74ull*MB },
-    { MODEL_LARGE,    94ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_DECODE = {
-    { MODEL_TINY,      3ull*MB },
-    { MODEL_BASE,      5ull*MB },
-    { MODEL_SMALL,    10ull*MB },
-    { MODEL_MEDIUM,   18ull*MB },
-    { MODEL_LARGE,    27ull*MB },
-};
-
 struct whisper_mel {
     int n_len;
     int n_len_org;
@@ -537,6 +531,7 @@ struct whisper_kv_cache {
 
     struct ggml_context * ctx;
 
+    // buf points to the memory allocated for both ggml_tensor 'k' and 'v' (see kv_cache_init)
     std::vector<uint8_t> buf;
 
     int n; // number of tokens currently in the cache
@@ -602,7 +597,7 @@ struct whisper_sequence {
 
 // TAGS: WHISPER_DECODER_INIT
 struct whisper_decoder {
-    // each decoders keeps its own KV-cache
+    // each decoder keeps its own KV-cache
     whisper_kv_cache kv_self;
 
     // the currently generated sequence of tokens
@@ -622,15 +617,75 @@ struct whisper_decoder {
     std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
 };
 
+// replace std::pair by using customized pair struct (reason: std::pair is very slow)
+template<typename A, typename B>
+struct whisper_pair {
+    A first;
+    B second;
+
+    // Define a constructor that takes two arguments.
+    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
+    // Define a constructor that takes no argument.
+    whisper_pair() : first(A()), second(B()) {}
+};
+
+// beam-search helpers
+struct kv_buf {
+    std::vector<uint8_t> k;
+    std::vector<uint8_t> v;
+};
+
+// ggml_allocr wrapper for whisper usage
+struct whisper_allocr {
+    ggml_allocr * alloc = nullptr;
+
+    std::vector<uint8_t> meta;
+    std::vector<uint8_t> data;
+};
+
+static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
+    return allocr.meta.size() + allocr.data.size();
+}
+
+// measure the memory usage of a graph and prepare the allocr's internal data buffer
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct ggml_cgraph *()> && get_graph) {
+    const int tensor_alignment = 32;
+
+    auto & alloc = allocr.alloc;
+    auto & meta  = allocr.meta;
+    auto & data  = allocr.data;
+
+    meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+
+    alloc = ggml_allocr_new_measure(tensor_alignment);
+
+    const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
+
+    ggml_allocr_free(alloc);
+
+    data.resize(alloc_size);
+
+    alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+}
+
+static void whisper_allocr_free(struct whisper_allocr & allocr) {
+    if (allocr.alloc) {
+        ggml_allocr_free(allocr.alloc);
+        allocr.alloc = nullptr;
+    }
+}
+
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
     int64_t t_decode_us = 0;
+    int64_t t_prompt_us = 0;
     int64_t t_mel_us = 0;
 
     int32_t n_sample = 0; // number of tokens sampled
     int32_t n_encode = 0; // number of encoder calls
-    int32_t n_decode = 0; // number of decoder calls
+    int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
+    int32_t n_prompt = 0; // number of decoder calls with n_tokens >  1 (prompt encoding)
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
@@ -641,12 +696,23 @@ struct whisper_state {
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
 
-    // memory buffers used by encode / decode contexts
-    std::vector<uint8_t> buf_compute;
-    std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
+    // buffer for swapping KV caches between decoders during beam-search
+    std::vector<kv_buf> kv_swap_bufs;
 
-    int    buf_last = 0;
-    size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
+    // reusable buffer for `struct ggml_graph_plan.work_data`
+    std::vector<uint8_t> work_buffer;
+
+    // ggml-alloc:
+    // - stores meta info about the intermediate tensors into the `meta` buffers
+    // - stores the actual tensor data into the `data` buffers
+    whisper_allocr alloc_conv;
+    whisper_allocr alloc_encode;
+    whisper_allocr alloc_cross;
+    whisper_allocr alloc_decode;
+
+    // result of the encoder
+    struct ggml_tensor * embd_conv = nullptr;
+    struct ggml_tensor * embd_enc  = nullptr;
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
@@ -655,7 +721,7 @@ struct whisper_state {
     std::vector<whisper_token>   prompt_past;
 
     // work container used to avoid memory allocations
-    std::vector<std::pair<double, whisper_vocab::id>> logits_id;
+    std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
 
     mutable std::mt19937 rng; // used for sampling at t > 0.0
 
@@ -666,6 +732,10 @@ struct whisper_state {
     whisper_coreml_context * ctx_coreml = nullptr;
 #endif
 
+#ifdef GGML_USE_METAL
+    ggml_metal_context * ctx_metal = nullptr;
+#endif
+
 #ifdef WHISPER_USE_OPENVINO
     whisper_openvino_context * ctx_openvino = nullptr;
 #endif
@@ -678,37 +748,6 @@ struct whisper_state {
 
     // [EXPERIMENTAL] speed-up techniques
     int32_t exp_n_audio_ctx = 0; // 0 - use default
-
-    void use_buf(struct ggml_context * ctx, int i) {
-#if defined(WHISPER_USE_SCRATCH)
-        size_t last_size = 0;
-
-        if (i == -1) {
-            last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
-        } else {
-            auto & buf = buf_scratch[i];
-            last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
-        }
-
-        if (buf_last >= 0) {
-            buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
-        }
-
-        buf_last = i;
-#else
-        (void) i;
-        (void) ctx;
-#endif
-    }
-
-    size_t get_buf_max_mem(int i) const {
-#if defined(WHISPER_USE_SCRATCH)
-        return buf_max_size[i];
-#else
-        (void) i;
-        return 0;
-#endif
-    }
 };
 
 struct whisper_context {
@@ -755,10 +794,17 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
 
 static bool kv_cache_init(
         const struct whisper_hparams & hparams,
-                        const size_t   mem_bytes,
              struct whisper_kv_cache & cache,
                            ggml_type   wtype,
                                  int   n_ctx) {
+    const int64_t n_text_state = hparams.n_text_state;
+    const int64_t n_text_layer = hparams.n_text_layer;
+
+    const int64_t n_mem      = n_text_layer*n_ctx;
+    const int64_t n_elements = n_text_state*n_mem;
+
+    const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead());
+
     cache.buf.resize(mem_bytes);
 
     struct ggml_init_params params = {
@@ -774,12 +820,6 @@ static bool kv_cache_init(
         return false;
     }
 
-    const int n_text_state = hparams.n_text_state;
-    const int n_text_layer = hparams.n_text_layer;
-
-    const int n_mem      = n_text_layer*n_ctx;
-    const int n_elements = n_text_state*n_mem;
-
     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
@@ -922,22 +962,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         // print memory requirements
         {
-            // this is the total memory required to run the inference
-            const size_t mem_required =
-                     MEM_REQ_SCRATCH0.at(model.type) +
-                     MEM_REQ_SCRATCH1.at(model.type) +
-                     MEM_REQ_SCRATCH2.at(model.type) +
-                     MEM_REQ_SCRATCH3.at(model.type) +
-                scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
-                scale*MEM_REQ_KV_CROSS.at(model.type) +
-                scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
-
-            // this is the memory required by one decoder
-            const size_t mem_required_decoder =
-                scale*MEM_REQ_KV_SELF.at(model.type);
-
-            log("%s: mem required  = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
-                    mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
+            // TODO
+            //log("%s: mem required  = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
+            //        mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
         }
 
         // initialize all memory buffers
@@ -1446,49 +1473,56 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
     return true;
 }
 
-// evaluate the encoder with the given state
-//
-// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
-// part of the transformer model and returns the encoded features
-//
-//   - wctx:      the model
-//   - wstate:     the state of the encoder
-//   - n_threads:  number of threads to use
-//   - mel_offset: offset in the mel spectrogram (i.e. audio offset)
-//
-static bool whisper_encode_internal(
-        whisper_context & wctx,
-          whisper_state & wstate,
-              const int   mel_offset,
-              const int   n_threads){
+static bool whisper_encode_external(const whisper_state & wstate) {
+    GGML_UNUSED(wstate);
 
-    const int64_t t_start_us = ggml_time_us();
+#ifndef WHISPER_USE_COREML
+    const bool use_coreml = false;
+#else
+    const bool use_coreml = wstate.ctx_coreml != nullptr;
+#endif
 
+#ifndef WHISPER_USE_OPENVINO
+    const bool use_openvino = false;
+#else
+    const bool use_openvino = wstate.ctx_openvino != nullptr;
+#endif
+
+    return use_coreml || use_openvino;
+}
+
+static struct ggml_cgraph * whisper_build_graph_conv(
+        whisper_context & wctx,
+          whisper_state & wstate,
+              const int   mel_offset) {
     const auto & model   = wctx.model;
     const auto & mel_inp = wstate.mel;
     const auto & hparams = model.hparams;
 
     const int n_ctx   = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
-    const int n_state = hparams.n_audio_state;
-    const int n_head  = hparams.n_audio_head;
-    const int n_layer = hparams.n_audio_layer;
+    const int n_state = hparams.n_audio_state; GGML_UNUSED(n_state);
 
     const int n_mels = hparams.n_mels;
-    assert(mel_inp.n_mel == n_mels);
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.buf_compute.size(),
-        /*.mem_buffer =*/ wstate.buf_compute.data(),
-        /*.no_alloc   =*/ false,
+        /*.mem_size   =*/ wstate.alloc_conv.meta.size(),
+        /*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
+        /*.no_alloc   =*/ true,
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
 
-    wstate.use_buf(ctx0, 0);
+    ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    ggml_allocr * alloc = wstate.alloc_conv.alloc;
 
     struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
+    ggml_allocr_alloc(alloc, mel);
+
     assert(mel->type == GGML_TYPE_F32);
-    {
+    if (!ggml_allocr_is_measure(alloc)) {
+        assert(mel_inp.n_mel == n_mels);
+
         float * dst = (float *) mel->data;
         memset(dst, 0, ggml_nbytes(mel));
 
@@ -1502,25 +1536,11 @@ static bool whisper_encode_internal(
         }
     }
 
-    struct ggml_tensor * cur;
+    struct ggml_tensor * cur = nullptr;
 
-#ifndef WHISPER_USE_COREML
-    const bool use_coreml = false;
-#else
-    const bool use_coreml = wstate.ctx_coreml != nullptr;
-#endif
-
-#ifndef WHISPER_USE_OPENVINO
-    const bool use_openvino = false;
-#else
-    const bool use_openvino = wstate.ctx_openvino != nullptr;
-#endif
-
-    if (!use_coreml && !use_openvino) {
+    if (!whisper_encode_external(wstate)) {
         // convolution + gelu
         {
-            wstate.use_buf(ctx0, 1);
-
             cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0,
@@ -1530,8 +1550,6 @@ static bool whisper_encode_internal(
 
             cur = ggml_gelu(ctx0, cur);
 
-            wstate.use_buf(ctx0, 0);
-
             cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0,
@@ -1542,371 +1560,431 @@ static bool whisper_encode_internal(
             cur = ggml_gelu(ctx0, cur);
         }
 
-        wstate.use_buf(ctx0, 3);
+        wstate.embd_conv = cur;
+    } else {
+#ifdef WHISPER_USE_COREML
+        cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+        ggml_allocr_alloc(alloc, cur);
 
-        // ===================================================================
-        // NOTE: experimenting with partial evaluation of the encoder (ignore)
-        //static int iter = -1;
-        //const int n_iter = 1500/n_ctx;
+        if (!ggml_allocr_is_measure(alloc)) {
+            whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
+        }
+#endif
+#ifdef WHISPER_USE_OPENVINO
+        cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+        ggml_allocr_alloc(alloc, cur);
 
-        //iter = (iter + 1) % n_iter;
+        if (!ggml_allocr_is_measure(alloc)) {
+            whisper_openvino_encode(wstate.ctx_openvino, mel, cur);
+        }
+#endif
 
-        //if (iter == 0) {
-        //    memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
-        //    memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
-        //}
+        wstate.embd_enc = cur;
+    }
 
-        static int iter = 0;
+    ggml_build_forward_expand(gf, cur);
 
-        const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
-        const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
+    ggml_free(ctx0);
 
-        struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
+    return gf;
+}
 
-        cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
+static struct ggml_cgraph * whisper_build_graph_encoder(
+        whisper_context & wctx,
+          whisper_state & wstate) {
+    const auto & model   = wctx.model;
+    const auto & hparams = model.hparams;
 
-        // ===================================================================
+    const int n_ctx   = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
+    const int n_state = hparams.n_audio_state;
+    const int n_head  = hparams.n_audio_head;
+    const int n_layer = hparams.n_audio_layer;
 
-        // original:
-        //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ wstate.alloc_encode.meta.size(),
+        /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
+        /*.no_alloc   =*/ true,
+    };
 
-        struct ggml_tensor * inpL = cur;
+    struct ggml_context * ctx0 = ggml_init(params);
 
-        for (int il = 0; il < n_layer; ++il) {
-            const auto & layer = model.layers_encoder[il];
+    ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-            // norm
-            {
-                wstate.use_buf(ctx0, 0);
+    ggml_allocr * alloc = wstate.alloc_encode.alloc;
 
-                cur = ggml_norm(ctx0, inpL, hparams.eps);
+    struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_allocr_alloc(alloc, KQscale);
 
-                // cur = ln_0_w*cur + ln_0_b
-                cur = ggml_add(ctx0,
-                        ggml_mul(ctx0,
-                            ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
-                            cur),
-                        ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
-            }
+    if (!ggml_allocr_is_measure(alloc)) {
+        ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
+    }
 
-            // self-attention
-            {
-                wstate.use_buf(ctx0, 1);
+    struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
 
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
-                        layer.attn_q_w,
-                        cur);
+    // ===================================================================
+    // NOTE: experimenting with partial evaluation of the encoder (ignore)
+    //static int iter = -1;
+    //const int n_iter = 1500/n_ctx;
 
-                Qcur = ggml_add(ctx0,
-                        ggml_repeat(ctx0,
-                            layer.attn_q_b,
-                            Qcur),
-                        Qcur);
+    //iter = (iter + 1) % n_iter;
 
-                //Qcur = ggml_scale_inplace(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+    //if (iter == 0) {
+    //    memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
+    //    memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
+    //}
 
-                // note: no bias for Key
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
-                        layer.attn_k_w,
-                        cur);
+    static int iter = 0;
 
-                //Kcur = ggml_scale_inplace(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+    const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
+    const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
-                        layer.attn_v_w,
-                        cur);
+    struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
 
-                Vcur = ggml_add(ctx0,
-                        ggml_repeat(ctx0,
-                            layer.attn_v_b,
-                            Vcur),
-                        Vcur);
+    cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
 
-                // ------
+    // ===================================================================
+
+    // original:
+    //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
+
+    struct ggml_tensor * inpL = cur;
 
-                wstate.use_buf(ctx0, 0);
+    for (int il = 0; il < n_layer; ++il) {
+        const auto & layer = model.layers_encoder[il];
+
+        // norm
+        {
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
+
+            // cur = ln_0_w*cur + ln_0_b
+            cur = ggml_add(ctx0,
+                    ggml_mul(ctx0, cur, layer.attn_ln_0_w),
+                    layer.attn_ln_0_b);
+        }
+
+        // self-attention
+        {
+            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
+                    layer.attn_q_w,
+                    cur);
+
+            Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
+
+            //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+
+            // note: no bias for Key
+            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
+                    layer.attn_k_w,
+                    cur);
+
+            //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+
+            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
+                    layer.attn_v_w,
+                    cur);
+
+            Vcur = ggml_add(ctx0, Vcur, layer.attn_v_b);
+
+            // ------
 
 #ifdef WHISPER_USE_FLASH_ATTN
-                struct ggml_tensor * Q =
-                    ggml_permute(ctx0,
-                            ggml_cpy(ctx0,
-                                Qcur,
-                                ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
-                            0, 2, 1, 3);
-
-                struct ggml_tensor * K =
-                    ggml_permute(ctx0,
-                            ggml_cpy(ctx0,
-                                Kcur,
-                                ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
-                            0, 2, 1, 3);
-
-                struct ggml_tensor * V =
-                    ggml_cpy(ctx0,
-                            ggml_permute(ctx0,
-                                ggml_reshape_3d(ctx0,
-                                    Vcur,
-                                    n_state/n_head, n_head, n_ctx),
-                                1, 2, 0, 3),
-                            ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
-
-                struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
+                            Qcur,
+                            ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
+                            Kcur,
+                            ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * V =
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                Vcur,
+                                n_state/n_head, n_head, n_ctx),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
+
+            struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
 #else
-                struct ggml_tensor * Q =
-                    ggml_permute(ctx0,
-                            ggml_cpy(ctx0,
-                                Qcur,
-                                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
-                            0, 2, 1, 3);
-
-                struct ggml_tensor * K =
-                    ggml_permute(ctx0,
-                            ggml_cpy(ctx0,
-                                Kcur,
-                                ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
-                            0, 2, 1, 3);
-
-                // K * Q
-                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-
-                struct ggml_tensor * KQ_scaled =
-                    ggml_scale_inplace(ctx0,
-                            KQ,
-                            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
-                            );
-
-                struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_scaled);
-
-                struct ggml_tensor * V =
-                    ggml_cpy(ctx0,
-                            ggml_permute(ctx0,
-                                ggml_reshape_3d(ctx0,
-                                    Vcur,
-                                    n_state/n_head, n_head, n_ctx),
-                                1, 2, 0, 3),
-                            ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
-                            );
-
-                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
+                            Qcur,
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
+                            Kcur,
+                            ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
+                        0, 2, 1, 3);
+
+            // K * Q
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+            struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);
+
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
+
+            struct ggml_tensor * V =
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                Vcur,
+                                n_state/n_head, n_head, n_ctx),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
+                        );
+
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 #endif
-                struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-                wstate.use_buf(ctx0, 1);
+            cur = ggml_cpy(ctx0,
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+        }
 
-                cur = ggml_cpy(ctx0,
-                        KQV_merged,
-                        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
-            }
+        // projection
+        {
+            cur = ggml_mul_mat(ctx0,
+                    layer.attn_ln_1_w,
+                    cur);
 
-            // projection
-            {
-                wstate.use_buf(ctx0, 0);
+            cur = ggml_add(ctx0, cur, layer.attn_ln_1_b);
+        }
 
-                cur = ggml_mul_mat(ctx0,
-                        layer.attn_ln_1_w,
-                        cur);
+        // add the input
+        cur = ggml_add(ctx0, cur, inpL);
 
-                wstate.use_buf(ctx0, 1);
+        struct ggml_tensor * inpFF = cur;
 
+        // feed-forward network
+        {
+            // norm
+            {
+                cur = ggml_norm(ctx0, inpFF, hparams.eps);
+
+                // cur = mlp_ln_w*cur + mlp_ln_b
                 cur = ggml_add(ctx0,
-                        ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
-                        cur);
+                        ggml_mul(ctx0, cur, layer.mlp_ln_w),
+                        layer.mlp_ln_b);
             }
 
-            wstate.use_buf(ctx0, 2);
+#ifdef WHISPER_USE_FLASH_FF
+            cur = ggml_flash_ff(ctx0,
+                    ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
+                    layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
+#else
+            // fully connected
+            cur = ggml_mul_mat(ctx0,
+                    layer.mlp_0_w,
+                    cur);
 
-            // add the input
-            cur = ggml_add(ctx0, cur, inpL);
+            cur = ggml_add(ctx0, cur, layer.mlp_0_b);
 
-            struct ggml_tensor * inpFF = cur;
+            // GELU activation
+            cur = ggml_gelu(ctx0, cur);
 
-            // feed-forward network
-            {
-                // norm
-                {
-                    wstate.use_buf(ctx0, 0);
+            // projection
+            cur = ggml_mul_mat(ctx0,
+                    layer.mlp_1_w,
+                    cur);
 
-                    cur = ggml_norm(ctx0, inpFF, hparams.eps);
+            cur = ggml_add(ctx0, cur, layer.mlp_1_b);
+#endif
+        }
 
-                    wstate.use_buf(ctx0, 1);
+        inpL = ggml_add(ctx0, cur, inpFF);
+    }
 
-                    // cur = mlp_ln_w*cur + mlp_ln_b
-                    cur = ggml_add(ctx0,
-                            ggml_mul(ctx0,
-                                ggml_repeat(ctx0, layer.mlp_ln_w, cur),
-                                cur),
-                            ggml_repeat(ctx0, layer.mlp_ln_b, cur));
-                }
+    cur = inpL;
 
-#ifdef WHISPER_USE_FLASH_FF
-                wstate.use_buf(ctx0, 0);
+    // norm
+    {
+        cur = ggml_norm(ctx0, cur, hparams.eps);
 
-                cur = ggml_flash_ff(ctx0,
-                        ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
-                        layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
-#else
-                wstate.use_buf(ctx0, 0);
+        // cur = ln_f_g*cur + ln_f_b
+        cur = ggml_add(ctx0,
+                ggml_mul(ctx0, cur, model.e_ln_w),
+                model.e_ln_b);
+    }
 
-                // fully connected
-                cur = ggml_mul_mat(ctx0,
-                        layer.mlp_0_w,
-                        cur);
+    ggml_build_forward_expand(gf, cur);
 
-                wstate.use_buf(ctx0, 1);
+    wstate.embd_enc = cur;
 
-                cur = ggml_add(ctx0,
-                        ggml_repeat(ctx0, layer.mlp_0_b, cur),
-                        cur);
+    //ggml_graph_print(gf);
 
-                wstate.use_buf(ctx0, 0);
+    ////////////////////////////////////////////////////////////////////////////
 
-                // GELU activation
-                cur = ggml_gelu(ctx0, cur);
+    //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
+    //        ggml_used_mem(ctx0)/1024.0/1024.0,
+    //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
+    //        wstate.get_buf_max_mem(1)/1024.0/1024.0,
+    //        wstate.get_buf_max_mem(2)/1024.0/1024.0,
+    //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
 
-                wstate.use_buf(ctx0, 1);
+    ggml_free(ctx0);
 
-                // projection
-                cur = ggml_mul_mat(ctx0,
-                        layer.mlp_1_w,
-                        cur);
+    return gf;
+}
 
-                wstate.use_buf(ctx0, 0);
+// pre-compute cross-attention memory
+static struct ggml_cgraph * whisper_build_graph_cross(
+        whisper_context & wctx,
+          whisper_state & wstate) {
+    const auto & model   = wctx.model;
+    const auto & hparams = model.hparams;
 
-                cur = ggml_add(ctx0,
-                        ggml_repeat(ctx0, layer.mlp_1_b, cur),
-                        cur);
-#endif
-            }
+    const int n_ctx   = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
+    const int n_state = hparams.n_audio_state;
+    const int n_head  = hparams.n_audio_head;
 
-            wstate.use_buf(ctx0, 3);
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ wstate.alloc_cross.meta.size(),
+        /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
+        /*.no_alloc   =*/ true,
+    };
 
-            inpL = ggml_add(ctx0, cur, inpFF);
-        }
+    struct ggml_context * ctx0 = ggml_init(params);
 
-        cur = inpL;
+    ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-        // norm
-        {
-            wstate.use_buf(ctx0, 0);
+    ggml_allocr * alloc = wstate.alloc_cross.alloc;
 
-            cur = ggml_norm(ctx0, cur, hparams.eps);
+    struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
 
-            wstate.use_buf(ctx0, 1);
+    struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_allocr_alloc(alloc, Kscale);
 
-            // cur = ln_f_g*cur + ln_f_b
-            cur = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, model.e_ln_w, cur),
-                        cur),
-                    ggml_repeat(ctx0, model.e_ln_b, cur));
-        }
+    if (!ggml_allocr_is_measure(alloc)) {
+        ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25));
+    }
 
-        wstate.use_buf(ctx0, -1);
+    for (int il = 0; il < model.hparams.n_text_layer; ++il) {
+        auto & layer = model.layers_decoder[il];
 
-        // run the computation
-        {
-            struct ggml_cgraph gf = {};
+        struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
+                layer.cross_attn_k_w,
+                cur);
 
-            ggml_build_forward_expand  (&gf, cur);
-            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+        Kcross = ggml_scale(ctx0, Kcross, Kscale);
 
-            //ggml_graph_print(&gf);
-        }
-    }
-#ifdef WHISPER_USE_COREML
-    else if (use_coreml) {
-        wstate.use_buf(ctx0, -1);
+        struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
+                layer.cross_attn_v_w,
+                cur);
 
-        cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+        Vcross = ggml_add(ctx0,
+                    Vcross,
+                    layer.cross_attn_v_b);
 
-        whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
-    }
-#endif
-#ifdef WHISPER_USE_OPENVINO
-    else if (use_openvino) {
-        wstate.use_buf(ctx0, -1);
+        Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
 
-        cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+        struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k,
+                n_state*n_ctx,
+                (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
 
-        if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) {
-            return false;
-        }
+        struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
+                (   n_ctx)*ggml_element_size(wstate.kv_cross.v),
+                (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
+
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
     }
-#endif
 
-    // cur
-    //{
-    //    printf("ne0 = %d\n", cur->ne[0]);
-    //    printf("ne1 = %d\n", cur->ne[1]);
-    //    for (int i = 0; i < 10; ++i) {
-    //        printf("%8.4f ", ((float *)(cur->data))[i]);
-    //    }
-    //    printf("... ");
-    //    for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
-    //        printf("%8.4f ", ((float *)(cur->data))[i]);
-    //    }
-    //    printf("\n");
-    //}
+    //ggml_graph_print(gf);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+// evaluate the encoder with the given state
+//
+// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
+// part of the transformer model and returns the encoded features
+//
+//   - wctx:      the model
+//   - wstate:     the state of the encoder
+//   - n_threads:  number of threads to use
+//   - mel_offset: offset in the mel spectrogram (i.e. audio offset)
+//
+static bool whisper_encode_internal(
+        whisper_context & wctx,
+          whisper_state & wstate,
+              const int   mel_offset,
+              const int   n_threads) {
+    const int64_t t_start_us = ggml_time_us();
 
-    // pre-compute cross-attention memory
+    // conv
     {
-        struct ggml_cgraph gf = {};
+        auto & alloc = wstate.alloc_conv.alloc;
 
-        // TODO: hack to disconnect the encoded features from the previous graph
-        cur->op = GGML_OP_NONE;
-        cur->src[0] = nullptr;
-        cur->src[1] = nullptr;
+        ggml_allocr_reset(alloc);
 
-        for (int il = 0; il < model.hparams.n_text_layer; ++il) {
-            auto& layer = model.layers_decoder[il];
+        ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
 
-            wstate.use_buf(ctx0, 0);
+        ggml_allocr_alloc_graph(alloc, gf);
 
-            struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
-                layer.cross_attn_k_w,
-                cur);
-
-            Kcross = ggml_scale_inplace(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
+        if (!whisper_encode_external(wstate)) {
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        }
+    }
 
-            wstate.use_buf(ctx0, 1);
+    // encoder
+    if (!whisper_encode_external(wstate)) {
+        auto & alloc = wstate.alloc_encode.alloc;
 
-            struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
-                layer.cross_attn_v_w,
-                cur);
+        ggml_allocr_reset(alloc);
 
-            Vcross = ggml_add(ctx0,
-                ggml_repeat(ctx0,
-                    layer.cross_attn_v_b,
-                    Vcross),
-                Vcross);
+        ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
 
-            wstate.use_buf(ctx0, -1);
+        ggml_allocr_alloc_graph(alloc, gf);
 
-            Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+#ifdef GGML_USE_METAL
+        if (wstate.ctx_metal) {
+            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
+            ggml_metal_graph_compute(wstate.ctx_metal, gf);
+        } else {
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        }
+#else
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+#endif
+    }
 
-            struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
-            struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
-                    (   n_ctx)*ggml_element_size(wstate.kv_cross.v),
-                    (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
+    // cross
+    {
+        auto & alloc = wstate.alloc_cross.alloc;
 
-            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
-            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
-        }
+        ggml_allocr_reset(alloc);
 
-        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
-        //ggml_graph_print(&gf);
-    }
+        ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
 
-    ////////////////////////////////////////////////////////////////////////////
+        ggml_allocr_alloc_graph(alloc, gf);
 
-    //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
-    //        ggml_used_mem(ctx0)/1024.0/1024.0,
-    //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
-    //        wstate.get_buf_max_mem(1)/1024.0/1024.0,
-    //        wstate.get_buf_max_mem(2)/1024.0/1024.0,
-    //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
+#ifdef GGML_USE_METAL
+        if (wstate.ctx_metal) {
+            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
+            ggml_metal_graph_compute(wstate.ctx_metal, gf);
+        } else {
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        }
+#else
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+#endif
+    }
 
-    ggml_free(ctx0);
+    // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     wstate.t_encode_us += ggml_time_us() - t_start_us;
     wstate.n_encode++;
@@ -1914,26 +1992,13 @@ static bool whisper_encode_internal(
     return true;
 }
 
-// evaluate the decoder
-//
-// given text prompt + audio features -> computes the logits for the next token
-//
-//   - model:      the model
-//   - n_threads:  number of threads to use
-//   - tokens:     text prompt
-//   - n_tokens:   number of tokens in the prompt
-//   - n_past:     number of past tokens to prefix the prompt with
-//
-static bool whisper_decode_internal(
-        whisper_context & wctx,
-          whisper_state & wstate,
-        whisper_decoder & decoder,
-    const whisper_token * tokens,
-              const int   n_tokens,
-              const int   n_past,
-              const int   n_threads) {
-    const int64_t t_start_us = ggml_time_us();
-
+static struct ggml_cgraph * whisper_build_graph_decoder(
+         whisper_context & wctx,
+         whisper_state   & wstate,
+         whisper_decoder & decoder,
+     const whisper_token * tokens,
+                   int   n_tokens,
+                   int   n_past) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
@@ -1941,10 +2006,6 @@ static bool whisper_decode_internal(
 
     WHISPER_ASSERT(!!kv_self.ctx);
 
-    auto & logits_out = wstate.logits;
-
-    const int n_vocab = hparams.n_vocab;
-
     const int n_ctx   = hparams.n_text_ctx;
     const int n_state = hparams.n_text_state;
     const int n_head  = hparams.n_text_head;
@@ -1956,24 +2017,39 @@ static bool whisper_decode_internal(
     //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.buf_compute.size(),
-        /*.mem_buffer =*/ wstate.buf_compute.data(),
-        /*.no_alloc   =*/ false,
+        /*.mem_size   =*/ wstate.alloc_decode.meta.size(),
+        /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
+        /*.no_alloc   =*/ true,
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
 
-    struct ggml_cgraph gf = {};
+    ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    ggml_allocr * alloc = wstate.alloc_decode.alloc;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    memcpy(embd->data, tokens, N*ggml_element_size(embd));
+    ggml_allocr_alloc(alloc, embd);
+
+    if (!ggml_allocr_is_measure(alloc)) {
+        memcpy(embd->data, tokens, N*ggml_element_size(embd));
+    }
 
     struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    for (int i = 0; i < N; ++i) {
-        ((int32_t *) position->data)[i] = n_past + i;
+    ggml_allocr_alloc(alloc, position);
+
+    if (!ggml_allocr_is_measure(alloc)) {
+        for (int i = 0; i < N; ++i) {
+            ((int32_t *) position->data)[i] = n_past + i;
+        }
     }
 
-    wstate.use_buf(ctx0, 3);
+    struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_allocr_alloc(alloc, KQscale);
+
+    if (!ggml_allocr_is_measure(alloc)) {
+        ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25));
+    }
 
     // token encoding + position encoding
     struct ggml_tensor * cur =
@@ -1988,16 +2064,14 @@ static bool whisper_decode_internal(
 
         // norm
         {
-            wstate.use_buf(ctx0, 0);
-
             cur = ggml_norm(ctx0, inpL, hparams.eps);
 
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
                     ggml_mul(ctx0,
-                        ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
-                        cur),
-                    ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
+                        cur,
+                        layer.attn_ln_0_w),
+                    layer.attn_ln_0_b);
         }
 
         // self-attention
@@ -2007,19 +2081,17 @@ static bool whisper_decode_internal(
                     cur);
 
             Qcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.attn_q_b,
-                        Qcur),
-                    Qcur);
+                        Qcur,
+                        layer.attn_q_b);
 
-            Qcur = ggml_scale_inplace(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+            Qcur = ggml_scale(ctx0, Qcur, KQscale);
 
             // note: no bias for Key
             struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
                     layer.attn_k_w,
                     cur);
 
-            Kcur = ggml_scale_inplace(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+            Kcur = ggml_scale(ctx0, Kcur, KQscale);
 
             // store key and value to memory
             {
@@ -2028,10 +2100,8 @@ static bool whisper_decode_internal(
                         cur);
 
                 Vcur = ggml_add(ctx0,
-                        ggml_repeat(ctx0,
-                            layer.attn_v_b,
-                            Vcur),
-                        Vcur);
+                            Vcur,
+                            layer.attn_v_b);
 
                 Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
 
@@ -2040,42 +2110,32 @@ static bool whisper_decode_internal(
                         (   n_ctx)*ggml_element_size(kv_self.v),
                         (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
 
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
 
             // ------
 
-            wstate.use_buf(ctx0, 0);
-
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0,
-                            ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
-                            n_state/n_head, n_head, n_past + N),
-                        0, 2, 1, 3);
-
-            wstate.use_buf(ctx0, 1);
+                ggml_view_3d(ctx0, kv_self.k,
+                        n_state/n_head, n_past + N, n_head,
+                        ggml_element_size(kv_self.k)*n_state,
+                        ggml_element_size(kv_self.k)*n_state/n_head,
+                        ggml_element_size(kv_self.k)*n_state*n_ctx*il);
 
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
-            //struct ggml_tensor * KQ_scaled =
-            //    ggml_scale_inplace(ctx0,
-            //            KQ,
-            //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
-            //            );
+            //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
 
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ, n_past);
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
@@ -2095,36 +2155,28 @@ static bool whisper_decode_internal(
 
         // projection
         {
-            wstate.use_buf(ctx0, 0);
-
             cur = ggml_mul_mat(ctx0,
                     layer.attn_ln_1_w,
                     cur);
 
-            wstate.use_buf(ctx0, 1);
-
             cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
-                    cur);
+                    cur,
+                    layer.attn_ln_1_b);
         }
 
-        wstate.use_buf(ctx0, 2);
-
         // add the input
         struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
 
         // norm
         {
-            wstate.use_buf(ctx0, 0);
-
             cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
 
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
                     ggml_mul(ctx0,
-                        ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur),
-                        cur),
-                    ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur));
+                        cur,
+                        layer.cross_attn_ln_0_w),
+                    layer.cross_attn_ln_0_b);
         }
 
         // cross-attention
@@ -2134,18 +2186,18 @@ static bool whisper_decode_internal(
                     cur);
 
             Qcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.cross_attn_q_b,
-                        Qcur),
-                    Qcur);
+                        Qcur,
+                        layer.cross_attn_q_b);
 
-            Qcur = ggml_scale_inplace(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+            Qcur = ggml_scale(ctx0, Qcur, KQscale);
 
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
-                ggml_reshape_3d(ctx0,
-                        ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
-                        n_state/n_head, n_head, M);
+                ggml_view_3d(ctx0, wstate.kv_cross.k,
+                        n_state/n_head, M, n_head,
+                        ggml_element_size(wstate.kv_cross.k)*n_state,
+                        ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
+                        ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
 
             //struct ggml_tensor * Vcross =
             //    ggml_reshape_3d(ctx0,
@@ -2168,26 +2220,22 @@ static bool whisper_decode_internal(
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
                         0, 2, 1, 3);
 
-            struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
-
             // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
 
             //struct ggml_tensor * KQ_scaled =
-            //    ggml_scale_inplace(ctx0,
+            //    ggml_scale(ctx0,
             //            KQ,
             //            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
             //            );
 
             // no masking for cross-attention
-            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ);
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
@@ -2201,21 +2249,15 @@ static bool whisper_decode_internal(
 
         // projection
         {
-            wstate.use_buf(ctx0, 0);
-
             cur = ggml_mul_mat(ctx0,
                     layer.cross_attn_ln_1_w,
                     cur);
 
-            wstate.use_buf(ctx0, 1);
-
             cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
-                    cur);
+                    cur,
+                    layer.cross_attn_ln_1_b);
         }
 
-        wstate.use_buf(ctx0, 2);
-
         // add the input
         cur = ggml_add(ctx0, cur, inpCA);
 
@@ -2225,54 +2267,38 @@ static bool whisper_decode_internal(
         {
             // norm
             {
-                wstate.use_buf(ctx0, 0);
-
                 cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
-                wstate.use_buf(ctx0, 1);
-
                 // cur = mlp_ln_w*cur + mlp_ln_b
                 cur = ggml_add(ctx0,
                         ggml_mul(ctx0,
-                            ggml_repeat(ctx0, layer.mlp_ln_w, cur),
-                            cur),
-                        ggml_repeat(ctx0, layer.mlp_ln_b, cur));
+                            cur,
+                            layer.mlp_ln_w),
+                        layer.mlp_ln_b);
             }
 
-            wstate.use_buf(ctx0, 0);
-
             // fully connected
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_0_w,
                     cur);
 
-            wstate.use_buf(ctx0, 1);
-
             cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.mlp_0_b, cur),
-                    cur);
-
-            wstate.use_buf(ctx0, 0);
+                    cur,
+                    layer.mlp_0_b);
 
             // GELU activation
             cur = ggml_gelu(ctx0, cur);
 
-            wstate.use_buf(ctx0, 1);
-
             // projection
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_1_w,
                     cur);
 
-            wstate.use_buf(ctx0, 0);
-
             cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.mlp_1_b, cur),
-                    cur);
+                    cur,
+                    layer.mlp_1_b);
         }
 
-        wstate.use_buf(ctx0, 3);
-
         inpL = ggml_add(ctx0, cur, inpFF);
     }
 
@@ -2280,21 +2306,15 @@ static bool whisper_decode_internal(
 
     // norm
     {
-        wstate.use_buf(ctx0, 0);
-
         cur = ggml_norm(ctx0, cur, hparams.eps);
 
-        wstate.use_buf(ctx0, 1);
-
         cur = ggml_add(ctx0,
                 ggml_mul(ctx0,
-                    ggml_repeat(ctx0, model.d_ln_w, cur),
-                    cur),
-                ggml_repeat(ctx0, model.d_ln_b, cur));
+                    cur,
+                    model.d_ln_w),
+                model.d_ln_b);
     }
 
-    wstate.use_buf(ctx0, 0);
-
     // compute logits only for the last token
     // comment this line to compute logits for all N tokens
     // might be useful in the future
@@ -2302,23 +2322,75 @@ static bool whisper_decode_internal(
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
-    wstate.use_buf(ctx0, -1);
+    ggml_build_forward_expand(gf, logits);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+// evaluate the decoder
+//
+// given text prompt + audio features -> computes the logits for the next token
+//
+//   - model:      the model
+//   - n_threads:  number of threads to use
+//   - tokens:     text prompt
+//   - n_tokens:   number of tokens in the prompt
+//   - n_past:     number of past tokens to prefix the prompt with
+//
+static bool whisper_decode_internal(
+        whisper_context & wctx,
+          whisper_state & wstate,
+        whisper_decoder & decoder,
+    const whisper_token * tokens,
+              const int   n_tokens,
+              const int   n_past,
+              const int   n_threads) {
+    const int64_t t_start_us = ggml_time_us();
+
+    const auto & model   = wctx.model;
+    const auto & hparams = model.hparams;
+
+    const int n_vocab = hparams.n_vocab;
+
+    auto & logits_out = wstate.logits;
+
+    struct ggml_tensor * logits;
 
-    // run the computation
+    // decoder
     {
-        ggml_build_forward_expand  (&gf, logits);
-        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+        auto & alloc = wstate.alloc_decode.alloc;
+
+        ggml_allocr_reset(alloc);
+
+        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
+
+        ggml_allocr_alloc_graph(alloc, gf);
+
+        logits = gf->nodes[gf->n_nodes - 1];
+
+#ifdef GGML_USE_METAL
+        if (wstate.ctx_metal) {
+            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
+            ggml_metal_graph_compute(wstate.ctx_metal, gf);
+        } else {
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        }
+#else
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+#endif
     }
 
     // extract logits for all N tokens
-    //logits_out.resize(N*n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+    //logits_out.resize(n_tokens*n_vocab);
+    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
 
     // extract logits only for the last token
     logits_out.resize(n_vocab);
     memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
 
-    if (N > 1) {
+    if (n_tokens > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
         //        ggml_used_mem(ctx0)/1024.0/1024.0,
         //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
@@ -2327,14 +2399,18 @@ static bool whisper_decode_internal(
         //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
     }
 
-    ggml_free(ctx0);
-
-    wstate.t_decode_us += ggml_time_us() - t_start_us;
-    wstate.n_decode++;
+    if (n_tokens == 1) {
+        wstate.t_decode_us += ggml_time_us() - t_start_us;
+        wstate.n_decode++;
+    } else {
+        wstate.t_prompt_us += ggml_time_us() - t_start_us;
+        wstate.n_prompt++;
+    }
 
     return true;
 }
 
+
 //  500 -> 00:05.000
 // 6000 -> 01:00.000
 static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2743,9 +2819,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     fill_sin_cos_table();
     whisper_state * state = new whisper_state;
 
-    const size_t scale = ctx->model.hparams.ftype ? 1 : 2;
-
-    if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
+    if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
         log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
         return nullptr;
@@ -2756,7 +2830,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         log("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
-    if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
+    if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
         log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
         delete state;
         return nullptr;
@@ -2777,6 +2851,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     if (!state->ctx_coreml) {
         log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
 #ifndef WHISPER_COREML_ALLOW_FALLBACK
+        delete state;
         return nullptr;
 #endif
     } else {
@@ -2791,15 +2866,111 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     // TAGS: WHISPER_DECODER_INIT
     state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
 
-    state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
-    state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
+    state->decoders[0].probs.reserve   (ctx->vocab.n_vocab);
+    state->decoders[0].logits.reserve  (ctx->vocab.n_vocab);
     state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
-    state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
 
-    state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
-    state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
-    state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
-    state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
+    // conv allocator
+    {
+        whisper_allocr_graph_init(state->alloc_conv,
+                [&]() {
+                    return whisper_build_graph_conv(*ctx, *state, 0);
+                });
+
+        log("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
+    }
+
+    // encoder allocator
+    if (!whisper_encode_external(*state)) {
+        whisper_allocr_graph_init(state->alloc_encode,
+                [&]() {
+                    return whisper_build_graph_encoder(*ctx, *state);
+                });
+
+        log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
+    }
+
+    // cross allocator
+    {
+        whisper_allocr_graph_init(state->alloc_cross,
+                [&]() {
+                    return whisper_build_graph_cross(*ctx, *state);
+                });
+
+        log("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
+    }
+
+    // decoder allocator
+    {
+        whisper_allocr_graph_init(state->alloc_decode,
+                [&]() {
+                    const auto & hparams = ctx->model.hparams;
+
+                    // TODO: make sure this is the worst-case scenario
+                    const int n_tokens = hparams.n_text_ctx;
+                    const int n_past   = 0;
+
+                    return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
+                });
+
+        log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
+    }
+
+#ifdef GGML_USE_METAL
+    state->ctx_metal = ggml_metal_init(1);
+    if (!state->ctx_metal) {
+        log("%s: ggml_metal_init() failed\n", __func__);
+        delete state;
+        return nullptr;
+    }
+
+    log("%s: Metal context initialized\n", __func__);
+
+    // this allocates all Metal resources and memory buffers
+
+    void * data_ptr  = NULL;
+    size_t data_size = 0;
+
+    // TODO: add mmap support
+    //if (params.use_mmap) {
+    //    data_ptr  = ctx->model.mapping->addr;
+    //    data_size = ctx->model.mapping->size;
+    //} else {
+    //    data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
+    //    data_size = ggml_get_mem_size  (ctx->model.ctx);
+    //}
+
+    data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
+    data_size = ggml_get_mem_size  (ctx->model.ctx);
+
+    const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
+
+    log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
+
+#define WHISPER_METAL_CHECK_BUF(result)              \
+    if (!(result)) {                                 \
+        log("%s: failed to add metal buffer\n", __func__); \
+        delete state;                                \
+        return nullptr;                              \
+    }
+
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
+
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv",   state->alloc_conv.meta.data(),   state->alloc_conv.meta.size(),   0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross",  state->alloc_cross.meta.data(),  state->alloc_cross.meta.size(),  0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
+
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv",   state->alloc_conv.data.data(),   state->alloc_conv.data.size(),   0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross",  state->alloc_cross.data.data(),  state->alloc_cross.data.size(),  0));
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
+
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross",  state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
+
+    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
+#undef WHISPER_METAL_CHECK_BUF
+#endif
 
     state->rng = std::mt19937(0);
 
@@ -2856,7 +3027,6 @@ int whisper_ctx_init_openvino_encoder(
 }
 
 struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
-
     log("%s: loading model from '%s'\n", __func__, path_model);
 
     auto fin = std::ifstream(path_model, std::ios::binary);
@@ -3009,6 +3179,13 @@ void whisper_free_state(struct whisper_state * state)
         }
 #endif
 
+#ifdef GGML_USE_METAL
+        if (state->ctx_metal) {
+            ggml_metal_free(state->ctx_metal);
+            state->ctx_metal = nullptr;
+        }
+#endif
+
 #ifdef WHISPER_USE_OPENVINO
         if (state->ctx_openvino != nullptr) {
             whisper_openvino_free(state->ctx_openvino);
@@ -3016,6 +3193,11 @@ void whisper_free_state(struct whisper_state * state)
         }
 #endif
 
+        whisper_allocr_free(state->alloc_conv);
+        whisper_allocr_free(state->alloc_decode);
+        whisper_allocr_free(state->alloc_cross);
+        whisper_allocr_free(state->alloc_encode);
+
         delete state;
     }
 }
@@ -3436,12 +3618,14 @@ void whisper_print_timings(struct whisper_context * ctx) {
         const int32_t n_sample = std::max(1, ctx->state->n_sample);
         const int32_t n_encode = std::max(1, ctx->state->n_encode);
         const int32_t n_decode = std::max(1, ctx->state->n_decode);
+        const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
 
         log("%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
         log("%s:      mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
         log("%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
         log("%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
         log("%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+        log("%s:   prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
     }
     log("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
@@ -3451,6 +3635,11 @@ void whisper_reset_timings(struct whisper_context * ctx) {
         ctx->state->t_sample_us = 0;
         ctx->state->t_encode_us = 0;
         ctx->state->t_decode_us = 0;
+        ctx->state->t_prompt_us = 0;
+        ctx->state->n_sample = 0;
+        ctx->state->n_encode = 0;
+        ctx->state->n_decode = 0;
+        ctx->state->n_prompt = 0;
     }
 }
 
@@ -3480,6 +3669,7 @@ const char * whisper_print_system_info(void) {
     s += "FMA = "       + std::to_string(ggml_cpu_has_fma())       + " | ";
     s += "NEON = "      + std::to_string(ggml_cpu_has_neon())      + " | ";
     s += "ARM_FMA = "   + std::to_string(ggml_cpu_has_arm_fma())   + " | ";
+    s += "METAL = "     + std::to_string(ggml_cpu_has_metal())     + " | ";
     s += "F16C = "      + std::to_string(ggml_cpu_has_f16c())      + " | ";
     s += "FP16_VA = "   + std::to_string(ggml_cpu_has_fp16_va())   + " | ";
     s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
@@ -3975,17 +4165,21 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
 
     auto & logits_id = state.logits_id;
 
-    logits_id.clear();
+    logits_id.resize(n_logits);
     for (int i = 0; i < n_logits; ++i) {
-        logits_id.push_back({ logits[i], i });
+        logits_id[i].first = logits[i];
+        logits_id[i].second = i;
     }
 
-    std::partial_sort(
-            logits_id.begin(),
-            logits_id.begin() + k, logits_id.end(),
-            [](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
-                return a.first > b.first;
-            });
+    {
+        using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
+        std::partial_sort(
+                logits_id.begin(),
+                logits_id.begin() + k, logits_id.end(),
+                [](const pair_type & a, const pair_type & b) {
+            return a.first > b.first;
+        });
+    }
 
     std::vector<whisper_token_data> result;
     result.reserve(k);
@@ -4080,6 +4274,115 @@ static void whisper_sequence_score(
     }
 }
 
+static bool whisper_kv_swap_fast(
+                   std::vector<int> & view,
+                    whisper_decoder   src[],
+                std::vector<kv_buf> & kv_swap_bufs,
+                          const int & n_decoders) {
+    WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
+
+    // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
+    std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
+
+    // (buffer->decoder or decoder->decoder)
+    std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
+
+    // (decoder<->decoder)
+    std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
+    std::vector<whisper_pair<int, int>> p_swap_vec;
+    p_swap_vec.reserve(n_decoders);
+
+    // see https://github.com/ggerganov/whisper.cpp/wiki
+    for (int i = 0; i < n_decoders; i++) {
+        // zero-copy (no modification)
+        if (i == view[i] || view[i] < 0) {
+            continue;
+        }
+
+        bool is_one_copy = true;
+        // since we modify data sequentially, we only consider decoder indices after current index
+        for (int j = i + 1; j < n_decoders; j++) {
+            if (i == view[j]) {
+                // detect symmetric diagram
+                if (j == view[i]) {
+                    p_swap_set.insert(i);
+                    p_swap_set.insert(j);
+                    p_swap_vec.emplace_back(i, j);
+                } else {
+                    two_copy.insert(i);
+                    is_one_copy = false;
+                }
+                break;
+            }
+        }
+        if (is_one_copy) {
+            one_copy.insert(i);
+        }
+    }
+
+    kv_swap_bufs.resize(n_decoders);
+
+    for (int i = 0; i < n_decoders; i++) {
+        kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k));
+        kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v));
+    }
+
+    for (auto & i : two_copy) {
+        // make a copy of KV caches
+        WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
+        memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
+        memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
+    }
+
+    // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
+    for (auto & i : two_copy) {
+        // skip the decoder indices that require pointer swapping
+        if (p_swap_set.find(i) != p_swap_set.end()) {
+            continue;
+        }
+
+        if (two_copy.find(view[i]) != two_copy.end()) {
+            // modify KV caches of decoder using data from kv_swap_bufs
+            WHISPER_PRINT_DEBUG("%s: two-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
+            memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
+            memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+        } else {
+            // modify KV caches of decoder using data from correspond decoder KV caches directly
+            WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers:      %d  -> %d\n", __func__, view[i], i);
+            memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
+            memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+        }
+    }
+
+    // then modify one-copy decoder KV caches
+    for (auto & i : one_copy) {
+        // skip the decoder indices that require pointer swapping
+        if (p_swap_set.find(i) != p_swap_set.end()) {
+            continue;
+        }
+
+        if (two_copy.find(view[i]) != two_copy.end()) {
+            // modify KV caches of decoder using data from kv_swap_bufs
+            WHISPER_PRINT_DEBUG("%s: one-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
+            memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
+            memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+        } else {
+            // modify KV caches of decoder using data from correspond decoder KV caches directly
+            WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers:      %d  -> %d\n", __func__, view[i], i);
+            memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
+            memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+        }
+    }
+
+    // swap the pointers
+    for (auto & i : p_swap_vec) {
+        WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
+        std::swap(src[i.first].kv_self, src[i.second].kv_self);
+    }
+
+    return true;
+}
+
 int whisper_full_with_state(
         struct whisper_context * ctx,
           struct whisper_state * state,
@@ -4187,6 +4490,21 @@ int whisper_full_with_state(
             decoder.probs.resize   (ctx->vocab.n_vocab);
             decoder.logits.resize  (ctx->vocab.n_vocab);
             decoder.logprobs.resize(ctx->vocab.n_vocab);
+
+            // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
+#ifdef GGML_USE_METAL
+#define WHISPER_METAL_CHECK_BUF(result)              \
+            if (!(result)) {                                 \
+                log("%s: failed to add metal buffer\n", __func__); \
+                return 0;                              \
+            }
+
+            const std::string kv_name = "kv_self_" + std::to_string(j);
+            auto & kv_self = decoder.kv_self;
+
+            WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
+#undef WHISPER_METAL_CHECK_BUF
+#endif
         }
     }
 
@@ -4243,14 +4561,6 @@ int whisper_full_with_state(
     std::vector<whisper_token> prompt;
     prompt.reserve(whisper_n_text_ctx(ctx));
 
-    // beam-search helpers
-    struct kv_buf {
-        std::vector<uint8_t> k;
-        std::vector<uint8_t> v;
-    };
-
-    std::vector<kv_buf> kv_bufs;
-
     struct beam_candidate {
         int decoder_idx;
         int seek_delta;
@@ -4387,8 +4697,8 @@ int whisper_full_with_state(
 
                         decoder.kv_self.n += prompt.size();
 
-                        memcpy(decoder.probs.data(), state->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
-                        memcpy(decoder.logits.data(), state->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
+                        memcpy(decoder.probs.data(),    state->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
+                        memcpy(decoder.logits.data(),   state->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
                         memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
                     }
 
@@ -4399,23 +4709,7 @@ int whisper_full_with_state(
             for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                // store the KV caches of all decoders when doing beam-search
                 if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
-                    kv_bufs.resize(n_decoders_cur);
-                    for (int j = 0; j < n_decoders_cur; ++j) {
-                        auto & decoder = state->decoders[j];
-
-                        if (decoder.completed || decoder.failed) {
-                            continue;
-                        }
-
-                        kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k));
-                        kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v));
-
-                        memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
-                        memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
-                    }
-
                     beam_candidates.clear();
                 }
 
@@ -4463,6 +4757,7 @@ int whisper_full_with_state(
                     });
 
                     uint32_t cur_c = 0;
+                    std::vector<int> decoder_idx(n_decoders_cur, -1);
 
                     for (int j = 0; j < n_decoders_cur; ++j) {
                         auto & decoder = state->decoders[j];
@@ -4481,12 +4776,13 @@ int whisper_full_with_state(
                         decoder.seek_delta = cur.seek_delta;
                         decoder.has_ts     = cur.has_ts;
 
-                        memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
-                        memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
-
+                        decoder_idx[j] = cur.decoder_idx;
                         WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
                                 __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
                     }
+
+                    // update KV caches
+                    whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
                 }
 
                 // update the decoder state
@@ -4915,6 +5211,12 @@ int whisper_full_parallel(
         ctx->state->t_sample_us += states[i]->t_sample_us;
         ctx->state->t_encode_us += states[i]->t_encode_us;
         ctx->state->t_decode_us += states[i]->t_decode_us;
+        ctx->state->t_prompt_us += states[i]->t_prompt_us;
+
+        ctx->state->n_sample += states[i]->n_sample;
+        ctx->state->n_encode += states[i]->n_encode;
+        ctx->state->n_decode += states[i]->n_decode;
+        ctx->state->n_prompt += states[i]->n_prompt;
 
         whisper_free_state(states[i]);
     }
@@ -5111,7 +5413,8 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
     // b: N*N*sizeof(float)
     // c: N*N*sizeof(float)
     // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
-    std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*512);
+    std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead());
+    std::vector<uint8_t> work;
 
     // put a bunch of random data in the buffer
     for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
@@ -5166,12 +5469,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
             double tsum = 0.0;
 
             // heat-up
-            ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+            ggml_graph_compute_helper(work, &gf, n_threads);
 
             for (int i = 0; i < n_max; ++i) {
                 const int64_t t0 = ggml_time_us();
 
-                ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+                ggml_graph_compute_helper(work, &gf, n_threads);
 
                 const int64_t t1 = ggml_time_us();
 
index eff6090eb1f41797cd8fc7835838d953f742b3c8..f45456876da621d5beafdc2419d9d575a9fb02ff 100644 (file)
@@ -1970,6 +1970,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_fma        (void);
     GGML_API int ggml_cpu_has_neon       (void);
     GGML_API int ggml_cpu_has_arm_fma    (void);
+    GGML_API int ggml_cpu_has_metal      (void);
     GGML_API int ggml_cpu_has_f16c       (void);
     GGML_API int ggml_cpu_has_fp16_va    (void);
     GGML_API int ggml_cpu_has_wasm_simd  (void);
index 1c74859b65e6ae9e5483048fe02f5a0c919f9876..c17a091da9b49b4ba1d46297d84046dfc9b606e9 100755 (executable)
@@ -1,6 +1,7 @@
 #!/bin/bash
 
 cp -rpv ../whisper.cpp/ggml.c                         src/ggml.c
+cp -rpv ../whisper.cpp/ggml-alloc.c                   src/ggml-alloc.c
 cp -rpv ../whisper.cpp/ggml-cuda.h                    src/ggml-cuda.h
 cp -rpv ../whisper.cpp/ggml-cuda.cu                   src/ggml-cuda.cu
 cp -rpv ../whisper.cpp/ggml-opencl.h                  src/ggml-opencl.h
@@ -9,6 +10,7 @@ cp -rpv ../whisper.cpp/ggml-metal.h                   src/ggml-metal.h
 cp -rpv ../whisper.cpp/ggml-metal.m                   src/ggml-metal.m
 cp -rpv ../whisper.cpp/ggml-metal.metal               src/ggml-metal.metal
 cp -rpv ../whisper.cpp/ggml.h                         include/ggml/ggml.h
+cp -rpv ../whisper.cpp/ggml-alloc.h                   include/ggml/ggml-alloc.h
 cp -rpv ../whisper.cpp/examples/common.h              examples/common.h
 cp -rpv ../whisper.cpp/examples/common.cpp            examples/common.cpp
 cp -rpv ../whisper.cpp/examples/common-ggml.h         examples/common-ggml.h
index a1f6e7bf4f66eec4d9437c86624cb85d7d72bc39..304964be4f3f07e1c10db0f46c4e013e813c989b 100644 (file)
@@ -131,6 +131,10 @@ static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_ten
     return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size;
 }
 
+static bool ggml_is_view(struct ggml_tensor * t) {
+    return t->view_src != NULL;
+}
+
 void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
 #ifdef GGML_ALLOCATOR_DEBUG
     GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
@@ -338,8 +342,8 @@ static void free_vmem(void * base_addr, size_t size) {
 
 // allocate uncommitted virtual memory to measure the size of the graph
 static void alloc_measure_vmem(void ** base_addr, size_t * size) {
-    // 1TB for 64-bit, 1GB for 32-bit
-    *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<40;
+    // 128GB for 64-bit, 1GB for 32-bit
+    *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37;
     do {
         *base_addr = alloc_vmem(*size);
         if (*base_addr != NULL) {
@@ -399,10 +403,6 @@ bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
 
 //////////// compute graph allocator
 
-static bool ggml_is_view(struct ggml_tensor * t) {
-    return t->view_src != NULL;
-}
-
 static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
     if (a->type != b->type) {
         return false;
index 7e2355ce6bcc7eeb8abe7f5290cad19e9629cb35..1139ee3114610e05726b7e4319843eade838271c 100644 (file)
@@ -63,7 +63,10 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);
     GGML_METAL_DECL_KERNEL(soft_max);
+    GGML_METAL_DECL_KERNEL(soft_max_4);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
+    GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
+    GGML_METAL_DECL_KERNEL(get_rows_f32);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@@ -75,8 +78,10 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
+    GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
+    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -85,6 +90,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
@@ -117,14 +123,17 @@ static NSString * const msl_library_source = @"see metal.metal";
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
     metal_printf("%s: allocating\n", __func__);
 
-    // Show all the Metal device instances in the system
-    NSArray * devices = MTLCopyAllDevices();
     id <MTLDevice> device;
     NSString * s;
+
+#if TARGET_OS_OSX
+    // Show all the Metal device instances in the system
+    NSArray * devices = MTLCopyAllDevices();
     for (device in devices) {
         s = [device name];
         metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
     }
+#endif
 
     // Pick and show default Metal device
     device = MTLCreateSystemDefaultDevice();
@@ -139,14 +148,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->n_buffers = 0;
     ctx->concur_list_len = 0;
 
-    ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+    ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
-#if 0
-    // compile from source string and show compile log
+#ifdef GGML_SWIFT
+    // load the default.metallib file
     {
         NSError * error = nil;
 
-        ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
+        NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+        NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
+        NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
+        NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
+        NSURL * libURL = [NSURL fileURLWithPath:libPath];
+
+        // Load the metallib file into a Metal library
+        ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+
         if (error) {
             metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
@@ -161,7 +178,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 
         //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
         NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
-        NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+        NSString * path   = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
         metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
 
         NSString * src  = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
@@ -207,7 +224,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(gelu);
         GGML_METAL_ADD_KERNEL(soft_max);
+        GGML_METAL_ADD_KERNEL(soft_max_4);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
+        GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
+        GGML_METAL_ADD_KERNEL(get_rows_f32);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@@ -219,8 +239,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
+        GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
+        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -229,6 +251,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
@@ -247,13 +270,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    metal_printf("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
     metal_printf("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+#if TARGET_OS_OSX
+    metal_printf("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
     if (ctx->device.maxTransferRate != 0) {
         metal_printf("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
         metal_printf("%s: maxTransferRate               = built-in GPU\n", __func__);
     }
+#endif
 
     return ctx;
 }
@@ -273,7 +298,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(relu);
     GGML_METAL_DEL_KERNEL(gelu);
     GGML_METAL_DEL_KERNEL(soft_max);
+    GGML_METAL_DEL_KERNEL(soft_max_4);
     GGML_METAL_DEL_KERNEL(diag_mask_inf);
+    GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
+    GGML_METAL_DEL_KERNEL(get_rows_f32);
     GGML_METAL_DEL_KERNEL(get_rows_f16);
     GGML_METAL_DEL_KERNEL(get_rows_q4_0);
     GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@@ -285,8 +313,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_q6_K);
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
+    GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
+    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@@ -295,6 +325,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
@@ -365,6 +396,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
 
+        //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
         if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
 
@@ -454,6 +486,7 @@ bool ggml_metal_add_buffer(
             }
         }
 
+#if TARGET_OS_OSX
         metal_printf(", (%8.2f / %8.2f)",
                 ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
                 ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
@@ -463,6 +496,9 @@ bool ggml_metal_add_buffer(
         } else {
             metal_printf("\n");
         }
+#else
+        metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
+#endif
     }
 
     return true;
@@ -698,6 +734,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ADD:
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
+                            GGML_ASSERT(ggml_is_contiguous(src1));
 
                             // utilize float4
                             GGML_ASSERT(ne00 % 4 == 0);
@@ -705,6 +742,7 @@ void ggml_metal_graph_compute(
 
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
+                                GGML_ASSERT(ne11 == 1);
                                 [encoder setComputePipelineState:ctx->pipeline_add_row];
                             } else {
                                 [encoder setComputePipelineState:ctx->pipeline_add];
@@ -721,6 +759,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_MUL:
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
+                            GGML_ASSERT(ggml_is_contiguous(src1));
 
                             // utilize float4
                             GGML_ASSERT(ne00 % 4 == 0);
@@ -728,6 +767,7 @@ void ggml_metal_graph_compute(
 
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
+                                GGML_ASSERT(ne11 == 1);
                                 [encoder setComputePipelineState:ctx->pipeline_mul_row];
                             } else {
                                 [encoder setComputePipelineState:ctx->pipeline_mul];
@@ -743,6 +783,8 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_SCALE:
                         {
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+
                             const float scale = *(const float *) src1->data;
 
                             [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -750,7 +792,7 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
-                            const int64_t n = ggml_nelements(dst);
+                            const int64_t n = ggml_nelements(dst)/4;
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
@@ -762,7 +804,7 @@ void ggml_metal_graph_compute(
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                    const int64_t n = ggml_nelements(dst);
+                                    const int64_t n = ggml_nelements(dst)/4;
 
                                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                                 } break;
@@ -782,7 +824,7 @@ void ggml_metal_graph_compute(
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                    const int64_t n = ggml_nelements(dst);
+                                    const int64_t n = ggml_nelements(dst)/4;
 
                                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                                 } break;
@@ -796,13 +838,16 @@ void ggml_metal_graph_compute(
                         {
                             const int nth = 32;
 
-                            [encoder setComputePipelineState:ctx->pipeline_soft_max];
+                            if (ne00%4 == 0) {
+                                [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_soft_max];
+                            }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                             [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                             [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
@@ -810,14 +855,23 @@ void ggml_metal_graph_compute(
                         {
                             const int n_past = ((int32_t *)(dst->op_params))[0];
 
-                            [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+                            if (ne00%8 == 0) {
+                                [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+                            }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
                             [encoder setBytes:&ne01   length:sizeof(ne01) atIndex:3];
                             [encoder setBytes:&n_past length:sizeof(int)  atIndex:4];
 
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            if (ne00%8 == 0) {
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            }
+                            else {
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            }
                         } break;
                     case GGML_OP_MUL_MAT:
                         {
@@ -830,13 +884,14 @@ void ggml_metal_graph_compute(
 
                             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                            if (ggml_is_contiguous(src0) &&
-                                ggml_is_contiguous(src1) &&
+                            if (!ggml_is_transposed(src0) &&
+                                !ggml_is_transposed(src1) &&
                                 src1t == GGML_TYPE_F32 &&
                                 [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                                 ne00%32 == 0 &&
                                 ne11 > 1) {
                                 switch (src0->type) {
+                                    case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break;
                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
                                     case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
                                     case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
@@ -856,25 +911,38 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5];
                                 [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6];
                                 [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
-                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:8];
-                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:9];
-                                [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:10];
+                                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:8];
+                                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:9];
+                                [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
+                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
+                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
+                                [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:13];
                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                             } else {
                                 int nth0 = 32;
                                 int nth1 = 1;
+                                int nrows = 1;
 
                                 // use custom matrix x vector kernel
                                 switch (src0t) {
+                                    case GGML_TYPE_F32:
+                                        {
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
+                                            nrows = 4;
+                                        } break;
                                     case GGML_TYPE_F16:
                                         {
                                             nth0 = 32;
                                             nth1 = 1;
                                             if (ne11 * ne12 < 4) {
                                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+                                            } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
+                                                nrows = ne11;
                                             } else {
                                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                                nrows = 4;
                                             }
                                         } break;
                                     case GGML_TYPE_Q4_0:
@@ -995,7 +1063,7 @@ void ggml_metal_graph_compute(
                                 else if (src0t == GGML_TYPE_Q6_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
-                                    int64_t ny = (ne11 + 3)/4;
+                                    int64_t ny = (ne11 + nrows - 1)/nrows;
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                             }
@@ -1003,6 +1071,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_GET_ROWS:
                         {
                             switch (src0->type) {
+                                case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f32];  break;
                                 case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                                 case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
@@ -1018,9 +1087,9 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
-                            [encoder setBytes:&(dst->nb[1])  length:sizeof(uint64_t) atIndex:5];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:5];
 
                             const int64_t n = ggml_nelements(src1);
 
index 5070561fba1ace47d5b224d13ec8e9a0a7771773..3087ecda812d90d510ffc6859ff3d4dc52f12e7e 100644 (file)
@@ -38,7 +38,7 @@ kernel void kernel_add_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant   int64_t & nb,
+        constant    int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         constant     float & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
 
 kernel void kernel_silu(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
     dst[tpig] = x / (1.0f + exp(-x));
 }
 
@@ -89,10 +89,10 @@ constant float GELU_COEF_A    = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
-    device const float * src0,
-    device       float * dst,
+    device const float4 * src0,
+    device       float4 * dst,
     uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
 
     // BEWARE !!!
     // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        threadgroup float  * buf [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -119,61 +118,67 @@ kernel void kernel_soft_max(
     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    buf[tpitg[0]] = -INFINITY;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+    float lmax = psrc0[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+        lmax = MAX(lmax, psrc0[i00]);
     }
-
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
-    //               the loop, and when that is done, buf[0] has the correct (synchronized) value
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
-
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float max = buf[0];
+    const float max = simd_max(lmax);
 
     // parallel sum
-    buf[tpitg[0]] = 0.0f;
+    float lsum = 0.0f;
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
-        buf[tpitg[0]] += exp_psrc0;
+        lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
         // whish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] += buf[tpitg[0] + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float sum = simd_sum(lsum);
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] /= sum;
+    }
+}
+
+kernel void kernel_soft_max_4(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+    // parallel max
+    float4 lmax4 = psrc4[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        lmax4 = fmax(lmax4, psrc4[i00]);
     }
+    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-    // broadcast - not needed, see above
-    //// broadcast
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
+    const float max = simd_max(lmax);
 
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
-    const float sum = buf[0];
+    const float sum = simd_sum(lsum);
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        pdst[i00] /= sum;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        pdst4[i00] /= sum;
     }
 }
 
@@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+     }
+}
+
+kernel void kernel_diag_mask_inf_8(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne01,
+        constant        int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+
+    const int64_t i = 2*tpig[0];
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int64_t i4 = 4*i;
+    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        dst[i+1][k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            dst[i][k] = -INFINITY;
+        }
     }
 }
 
@@ -491,6 +523,79 @@ kernel void kernel_mul_mat_q8_0_f32(
     }
 }
 
+#define N_F32_F32 4
+
+kernel void kernel_mul_mat_f32_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F32_F32;
+    const int64_t im = tgpig.z;
+
+    device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const float4 * x4 = (device const float4 *)x;
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
 kernel void kernel_mul_mat_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
@@ -616,6 +721,49 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+// Assumes row size (ne00) is a multiple of 4
+kernel void kernel_mul_mat_f16_f32_l4(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+
+    const int nrows = ne11;
+    const int64_t r0 = tgpig.x;
+    const int64_t im = tgpig.z;
+
+    device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    for (int r1 = 0; r1 < nrows; ++r1) {
+        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+        float sumf = 0;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+        }
+
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
+}
+
 kernel void kernel_alibi_f32(
         device const float * src0,
         device       float * dst,
@@ -1123,31 +1271,40 @@ kernel void kernel_mul_mat_q3_K_f32(
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
-    float yl[16];
+    float yl[32];
 
-    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask1 = 0x3030;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/2;
-    const int ix  = tiisg%2;
-    const int ip  = tid/8;          // 0 or 1
-    const int il  = tid/2 - 4*ip;   // 0...3
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int ip  = tid/4;          // 0 or 1
+    const int il  = 2*((tid%4)/2);  // 0 or 2
     const int ir  = tid%2;
     const int n   = 8;
     const int l0  = n*ir;
 
-    const uint16_t m1 = 1 << (4*ip + il);
-    const uint16_t m2 = m1 << 8;
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tales.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
 
     const int shift = 2*il;
-    const uint16_t qm1 = 0x0003 << shift;
-    const uint16_t qm2 = 0x0300 << shift;
-    const int32_t v1 = 4 << shift;
-    const int32_t v2 = 1024 << shift;
+    const float    v1 = il == 0 ? 4.f : 64.f;
+    const float    v2 = 4.f * v1;
 
     const uint16_t s_shift1 = 4*ip;
-    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
-    const int ik = 4 + (il%2);
+    const uint16_t s_shift2 = s_shift1 + il;
 
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
@@ -1156,12 +1313,19 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     device const float * y1 = yy + ix*QK_K + y_offset;
 
-    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
-    for (int i = ix; i < nb; i += 2) {
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {
 
         for (int l = 0; l < 8; ++l) {
-            yl[l+0] = y1[l+ 0];
-            yl[l+8] = y1[l+16];
+            yl[l+ 0] = y1[l+ 0];
+            yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
         }
 
         device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1172,27 +1336,43 @@ kernel void kernel_mul_mat_q3_K_f32(
         for (int row = 0; row < 2; ++row) {
 
             const float d_all = (float)dh[0];
-            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 
-            float s1 = 0, s2 = 0;
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2];
-                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
             }
-            float d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[0];
-            sumf2[row] += d;
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);
 
-            s1 = s2 = 0;
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2+8];
-                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
-                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
             }
-            d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[1];
-            sumf2[row] += d;
+            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[1] - 32);
+            sumf2[row] += d2 * (scales[3] - 32);
 
             q  += step;
             h  += step;
@@ -1201,15 +1381,17 @@ kernel void kernel_mul_mat_q3_K_f32(
 
         }
 
-        y1 += 2 * QK_K;
+        y1 += 4 * QK_K;
 
     }
 
     for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
-        const float tot = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
 }
@@ -1290,13 +1472,13 @@ kernel void kernel_mul_mat_q4_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne01 [[buffer(4)]],
+        constant   int64_t & ne02 [[buffer(5)]],
+        constant   int64_t & ne10 [[buffer(9)]],
+        constant   int64_t & ne12 [[buffer(11)]],
+        constant   int64_t & ne0  [[buffer(15)]],
+        constant   int64_t & ne1  [[buffer(16)]],
+        constant   uint    & gqa  [[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1564,17 +1746,25 @@ kernel void kernel_mul_mat_q5_K_f32(
             sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
             sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
-            float4 acc = {0.f, 0.f, 0.f, 0.f};
+            float4 acc1 = {0.f};
+            float4 acc2 = {0.f};
             for (int l = 0; l < n; ++l) {
                 uint8_t h = qh[l];
-                acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
-                acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
-                acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
-                acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
             }
             const float dall = dh[0];
             const float dmin = dh[1];
-            sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+            sumf[row] += dall * (sc8[0] * (acc1[0] +  16.f*acc2[0]) +
+                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                 sc8[4] * (acc1[2] +  16.f*acc2[2]) +
+                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
             q1 += step;
@@ -1747,6 +1937,15 @@ kernel void kernel_mul_mat_q6_K_f32(
 
 //============================= templates and their specializations =============================
 
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+    float4x4 temp = *(((device float4x4 *)src));
+    for (int i = 0; i < 16; i++){
+        reg[i/4][i%4] = temp[i/4][i%4];
+    }
+}
+
 template <typename type4x4>
 void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
     half4x4 temp = *(((device half4x4 *)src));
@@ -1758,28 +1957,30 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = il ? ( -8.h * 16.h) : -8.h;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) + m) * d;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
+        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
     }
 }
 
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = xb->m;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float  m = xb->m;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) * d) + m;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
     }
 }
 
@@ -1815,7 +2016,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * q = (device const uint8_t *)xb->qs;
     device const uint8_t * h = (device const uint8_t *)xb->hmask;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1828,16 +2029,18 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
                                  ((il/4)>0 ? 12  : 3);
     uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
     uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
-    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
-                                 (scale_2&kmask2) | ((scale_1&kmask1) << 4);
-    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
+    const half ml = 4.h * dl;
 
-    il = (il/2)%4;
-    float   coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
-    uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    il = (il/2) & 3;
+    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl *= coef;
 
     for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
+        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
     }
 #else
     float    kcoef = il&1 ? 1.f/16.f : 1.f;
@@ -1852,26 +2055,31 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
 #endif
 }
 
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
 template <typename type4x4>
 void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * q = xb->qs;
+    device const uchar * q = xb->qs;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q = q + (il/4) * 32 + 16 * (il&1);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0]   : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d   = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
     device const half2 * dh = (device const half2 *)xb->d;
     const float2 d = (float2)dh[0];
     const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
-    const float ml = il<2 ? d[1] * (s[0]>>4)  : d[1 ]* (s[1]>>4);
+    const float ml = il<2 ? d[1] * (s[0]>>4)  : d[1* (s[1]>>4);
 #endif
     const ushort mask = il<2 ? 0x0F : 0xF0;
     for (int i = 0; i < 16; ++i) {
@@ -1885,19 +2093,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     device const uint8_t * qh = xb->qh;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q  = q + 32 * (il/4) + 16 * (il&1);
     qh = qh + 16 * (il&1);
     uint8_t ul = 1 << (il/2);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0]   : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 
-    const ushort mask   = il<2 ? 0x0F : 0xF0;
-    const float  qh_val = il<2 ? 16.f : 256.f;
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    const half qh_val = il<2 ? 16.h : 256.h;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
@@ -1916,7 +2124,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * ql = (device const uint8_t *)xb->ql;
     device const uint8_t * qh = (device const uint8_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1924,19 +2132,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
 #if QK_K == 256
     ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
     qh = qh + 32*(il/8) + 16*(il&1);
-    float sc = scales[(il%2) + 2 * ((il/2))];
-    il = (il/2)%4;
+    half sc = scales[(il%2) + 2 * ((il/2))];
+    il = (il/2) & 3;
 #else
     ql = ql + 16 * (il&1);
-    float sc = scales[il];
+    half sc = scales[il];
 #endif
+    const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
+    const half        coef = il>1 ? 1.f/16.h          : 1.h;
+    const half ml = d_all * sc * 32.h;
+    const half dl = d_all * sc * coef;
     for (int i = 0; i < 16; ++i) {
-        uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-        uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
-        const float coef = il>1 ? 1.f/16.f          : 1.f;
-        float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
-                         ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
-        reg[i/4][i%4] = d_all * sc * q * coef;
+        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+        reg[i/4][i%4] = dl * q - ml;
     }
 }
 
@@ -1976,22 +2186,25 @@ kernel void kernel_get_rows(
 // each block_q contains 16*nl weights
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm(device const  uchar * src0,
-                           device const  float * src1,
-                           device        float * dst,
-                           constant    int64_t & ne00,
-                           constant    int64_t & ne02,
-                           constant    int64_t & nb01,
-                           constant    int64_t & nb02,
-                           constant    int64_t & ne12,
-                           constant    int64_t & ne0,
-                           constant    int64_t & ne1,
-                           constant    uint & gqa,
-                           threadgroup   uchar * shared_memory [[threadgroup(0)]],
-                           uint3                 tgpig[[threadgroup_position_in_grid]],
-                           uint                  tiitg[[thread_index_in_threadgroup]],
-                           uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    threadgroup half * sa = ((threadgroup half *)shared_memory);
+                          device const  uchar * src1,
+                          device        float * dst,
+                          constant    int64_t & ne00,
+                          constant    int64_t & ne02,
+                          constant    int64_t & nb01,
+                          constant    int64_t & nb02,
+                          constant    int64_t & ne12,
+                          constant    int64_t & nb10,
+                          constant    int64_t & nb11,
+                          constant    int64_t & nb12,
+                          constant    int64_t & ne0,
+                          constant    int64_t & ne1,
+                          constant       uint & gqa,
+                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
+                          uint3                 tgpig[[threadgroup_position_in_grid]],
+                          uint                  tiitg[[thread_index_in_threadgroup]],
+                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
     threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
 
     const uint r0 = tgpig.y;
@@ -2004,7 +2217,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
 
-    simdgroup_half8x8 ma[4];
+    simdgroup_half8x8  ma[4];
     simdgroup_float8x8 mb[2];
     simdgroup_float8x8 c_res[8];
     for (int i = 0; i < 8; i++){
@@ -2012,10 +2225,15 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     }
 
     short il = (tiitg % THREAD_PER_ROW);
-    uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
-    device const block_q  * x = (device const block_q  *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
-    device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
-                             + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
+
+    uint   offset0 = im/gqa*nb02;
+    ushort offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const float   * y = (device const float   *)(src1
+        + nb12 * im
+        + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
         //load data and store to threadgroup memory
@@ -2095,6 +2313,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
+template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
 template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
@@ -2105,14 +2324,28 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
 
-typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
-                             constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
-                             constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
-
-template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
+typedef void (mat_mm_t)(
+        device const  uchar * src0,
+        device const  uchar * src1,
+        device        float * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne02,
+        constant    int64_t & nb01,
+        constant    int64_t & nb02,
+        constant    int64_t & ne12,
+        constant    int64_t & nb10,
+        constant    int64_t & nb11,
+        constant    int64_t & nb12,
+        constant    int64_t & ne0,
+        constant    int64_t & ne1,
+        constant       uint & gqa,
+        threadgroup uchar *, uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_f32_f32")]]  kernel mat_mm_t kernel_mul_mm<float4x4,   1,     dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1,     dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2,     dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2,     dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2,     dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
index f4a9c3eab88ae15acbcdca2e3f1f507202c9fda7..2a806c7ad830733c08f47f259174212adac6834e 100644 (file)
@@ -4303,10 +4303,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 }
 
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
-    size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
-    for (int i = 1; i < GGML_MAX_DIMS; ++i) {
-        nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+    size_t nbytes;
+    size_t blck_size = ggml_blck_size(tensor->type);
+    if (blck_size == 1) {
+        nbytes = ggml_type_size(tensor->type);
+        for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+            nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+        }
+    }
+    else {
+        nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
+        for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+            nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+        }
     }
+
     return nbytes;
 }
 
@@ -17283,10 +17294,18 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         } else {
             // wait for other threads to finish
             const int last = node_n;
-            do {
-                //sched_yield();
+            while (true) {
+                // TODO: this sched_yield can have significant impact on the performance - either positive or negative
+                //       depending on the workload and the operating system.
+                //       since it is not clear what is the best approach, it should potentially become user-configurable
+                //       ref: https://github.com/ggerganov/ggml/issues/291
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                sched_yield();
+#endif
+
                 node_n = atomic_load(&state->shared->node_n);
-            } while (node_n == last);
+                if (node_n != last) break;
+            };
         }
 
         // check if we should stop
@@ -18337,10 +18356,11 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
     for (int i = 0; i < cgraph->n_leafs; i++) {
         struct ggml_tensor * node = cgraph->leafs[i];
 
-        GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
+        GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
                 i,
                 node->ne[0], node->ne[1],
-                ggml_op_name(node->op));
+                ggml_op_name(node->op),
+                ggml_get_name(node));
     }
 
     for (int i = 0; i < GGML_OP_COUNT; i++) {
@@ -20733,6 +20753,14 @@ int ggml_cpu_has_arm_fma(void) {
 #endif
 }
 
+int ggml_cpu_has_metal(void) {
+#if defined(GGML_USE_METAL)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_f16c(void) {
 #if defined(__F16C__)
     return 1;