]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : whisper.cpp (update whisper example + minor) (#613)
authorGeorgi Gerganov <redacted>
Fri, 17 Nov 2023 08:00:11 +0000 (10:00 +0200)
committerGitHub <redacted>
Fri, 17 Nov 2023 08:00:11 +0000 (10:00 +0200)
ggml-ci

examples/common-ggml.cpp
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h
src/CMakeLists.txt
src/ggml-metal.m

index 33ae03ae10ff419dd18b336e5e7d9406afc94d36..e69bd51000c9d3f031ed47f310a114590eb6c4b6 100644 (file)
@@ -9,6 +9,11 @@ static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
     {"q5_0", GGML_FTYPE_MOSTLY_Q5_0},
     {"q5_1", GGML_FTYPE_MOSTLY_Q5_1},
     {"q8_0", GGML_FTYPE_MOSTLY_Q8_0},
+    {"q2_k", GGML_FTYPE_MOSTLY_Q2_K},
+    {"q3_k", GGML_FTYPE_MOSTLY_Q3_K},
+    {"q4_k", GGML_FTYPE_MOSTLY_Q4_K},
+    {"q5_k", GGML_FTYPE_MOSTLY_Q5_K},
+    {"q6_k", GGML_FTYPE_MOSTLY_Q6_K},
 };
 
 void ggml_print_ftypes(FILE * fp) {
@@ -48,15 +53,15 @@ bool ggml_common_quantize_0(
         case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break;
         case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break;
         case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break;
+        case GGML_FTYPE_MOSTLY_Q2_K: qtype = GGML_TYPE_Q2_K; break;
+        case GGML_FTYPE_MOSTLY_Q3_K: qtype = GGML_TYPE_Q3_K; break;
+        case GGML_FTYPE_MOSTLY_Q4_K: qtype = GGML_TYPE_Q4_K; break;
+        case GGML_FTYPE_MOSTLY_Q5_K: qtype = GGML_TYPE_Q5_K; break;
+        case GGML_FTYPE_MOSTLY_Q6_K: qtype = GGML_TYPE_Q6_K; break;
         case GGML_FTYPE_UNKNOWN:
         case GGML_FTYPE_ALL_F32:
         case GGML_FTYPE_MOSTLY_F16:
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
-        case GGML_FTYPE_MOSTLY_Q2_K:
-        case GGML_FTYPE_MOSTLY_Q3_K:
-        case GGML_FTYPE_MOSTLY_Q4_K:
-        case GGML_FTYPE_MOSTLY_Q5_K:
-        case GGML_FTYPE_MOSTLY_Q6_K:
                 {
                     fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
                     return false;
@@ -167,24 +172,17 @@ bool ggml_common_quantize_0(
 
             switch ((ggml_type) ttype) {
                 case GGML_TYPE_Q4_0:
-                    {
-                        cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                    } break;
                 case GGML_TYPE_Q4_1:
-                    {
-                        cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                    } break;
                 case GGML_TYPE_Q5_0:
-                    {
-                        cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                    } break;
                 case GGML_TYPE_Q5_1:
-                    {
-                        cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
-                    } break;
                 case GGML_TYPE_Q8_0:
+                case GGML_TYPE_Q2_K:
+                case GGML_TYPE_Q3_K:
+                case GGML_TYPE_Q4_K:
+                case GGML_TYPE_Q5_K:
+                case GGML_TYPE_Q6_K:
                     {
-                        cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                        cur_size = ggml_quantize_chunk((ggml_type) ttype, data_f32.data(), work.data(), 0, nelements, hist_cur.data());
                     } break;
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
@@ -192,11 +190,6 @@ bool ggml_common_quantize_0(
                 case GGML_TYPE_I16:
                 case GGML_TYPE_I32:
                 case GGML_TYPE_Q8_1:
-                case GGML_TYPE_Q2_K:
-                case GGML_TYPE_Q3_K:
-                case GGML_TYPE_Q4_K:
-                case GGML_TYPE_Q5_K:
-                case GGML_TYPE_Q6_K:
                 case GGML_TYPE_Q8_K:
                 case GGML_TYPE_COUNT:
                     {
index e43dfe3f948f916836fb6022f4379355bd5d2bcf..98af5839ca551fe25f15b65d8c37b6c19d0672f1 100644 (file)
@@ -62,8 +62,8 @@ struct whisper_params {
     int32_t progress_step =  5;
     int32_t max_context  = -1;
     int32_t max_len      =  0;
-    int32_t best_of      =  2;
-    int32_t beam_size    = -1;
+    int32_t best_of      = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
+    int32_t beam_size    = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
 
     float word_thold    =  0.01f;
     float entropy_thold =  2.40f;
@@ -925,9 +925,9 @@ int main(int argc, char ** argv) {
             if (params.detect_language) {
                 params.language = "auto";
             }
-            fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
+            fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
                     __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
-                    params.n_threads, params.n_processors,
+                    params.n_threads, params.n_processors, params.beam_size, params.best_of,
                     params.language.c_str(),
                     params.translate ? "translate" : "transcribe",
                     params.tinydiarize ? "tdrz = 1, " : "",
index 1c7d7e9440dcb82cdedf3172aec31ee01c74e46e..e2197ff262c993ae7ac14ed393e1d116242fe45f 100644 (file)
@@ -20,6 +20,7 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 
+#include <atomic>
 #include <algorithm>
 #include <cassert>
 #define _USE_MATH_DEFINES
@@ -147,7 +148,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
 
 //#define WHISPER_USE_FLASH_ATTN
 //#define WHISPER_USE_FLASH_FF
-#define WHISPER_MAX_DECODERS 16
+#define WHISPER_MAX_DECODERS 8
 #define WHISPER_MAX_NODES 4096
 
 //
@@ -406,6 +407,121 @@ struct whisper_segment {
     bool speaker_turn_next;
 };
 
+struct whisper_batch {
+    int32_t n_tokens;
+
+    whisper_token  *  token;
+    whisper_pos    *  pos;
+    int32_t        *  n_seq_id;
+    whisper_seq_id ** seq_id;   // null terminated
+    int8_t         *  logits;
+};
+
+static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) {
+    whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, };
+
+    batch.token    = (whisper_token *  ) malloc(sizeof(whisper_token)    * (n_tokens));
+    batch.pos      = (whisper_pos *)     malloc(sizeof(whisper_pos)      * (n_tokens));
+    batch.n_seq_id = (int32_t *)         malloc(sizeof(int32_t)          * (n_tokens));
+    batch.seq_id   = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * (n_tokens + 1));
+    for (int i = 0; i < n_tokens; ++i) {
+        batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id)   * n_seq_max);
+    }
+    batch.seq_id[n_tokens] = nullptr;
+    batch.logits   = (int8_t *)          malloc(sizeof(int8_t)           * n_tokens);
+
+    return batch;
+}
+
+static void whisper_batch_free(struct whisper_batch batch) {
+    if (batch.token)    free(batch.token);
+    if (batch.pos)      free(batch.pos);
+    if (batch.n_seq_id) free(batch.n_seq_id);
+    if (batch.seq_id) {
+        for (int i = 0; batch.seq_id[i]; ++i) {
+            free(batch.seq_id[i]);
+        }
+        free(batch.seq_id);
+    }
+    if (batch.logits)   free(batch.logits);
+}
+
+static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) {
+    batch.n_tokens = n_tokens;
+    for (int i = 0; i < n_tokens; ++i) {
+        if (tokens) {
+            batch.token[i] = tokens[i];
+        }
+        batch.pos     [i]    = n_past + i;
+        batch.n_seq_id[i]    = 1;
+        batch.seq_id  [i][0] = seq_id;
+        batch.logits  [i]    = 0;
+    }
+    batch.logits[n_tokens - 1] = 1;
+}
+
+// replace std::pair by using customized pair struct (reason: std::pair is very slow)
+template<typename A, typename B>
+struct whisper_pair {
+    A first;
+    B second;
+
+    // Define a constructor that takes two arguments.
+    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
+    // Define a constructor that takes no argument.
+    whisper_pair() : first(A()), second(B()) {}
+};
+
+// ggml_allocr wrapper for whisper usage
+struct whisper_allocr {
+    ggml_allocr * alloc = nullptr;
+
+    std::vector<uint8_t> meta;
+
+    ggml_backend_buffer_t buffer;
+};
+
+static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
+    return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
+}
+
+// measure the memory usage of a graph and prepare the allocr's internal data buffer
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
+    auto & alloc  = allocr.alloc;
+    auto & meta   = allocr.meta;
+
+    alloc = ggml_allocr_new_measure_from_backend(backend);
+
+    meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
+
+    ggml_allocr_alloc_graph(alloc, get_graph());
+}
+
+static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
+    if (allocr.alloc == nullptr) {
+        // this can be null if we use external encoder like CoreML or OpenVINO
+        return;
+    }
+
+    auto & alloc  = allocr.alloc;
+    auto & buffer = allocr.buffer;
+
+    size_t size = ggml_allocr_max_size(alloc);
+
+    ggml_allocr_free(alloc);
+
+    buffer = ggml_backend_alloc_buffer(backend, size);
+    alloc = ggml_allocr_new_from_buffer(buffer);
+}
+
+static void whisper_allocr_free(struct whisper_allocr & allocr) {
+    if (allocr.alloc) {
+        ggml_allocr_free(allocr.alloc);
+        ggml_backend_buffer_free(allocr.buffer);
+        allocr.alloc = nullptr;
+    }
+}
+
 // medium
 // hparams: {
 // 'n_mels': 80,
@@ -523,15 +639,31 @@ struct whisper_layer_decoder {
     struct ggml_tensor * mlp_1_b;
 };
 
+struct whisper_kv_cell {
+    whisper_pos pos = -1;
+
+    std::set<whisper_seq_id> seq_id;
+
+    bool has_seq_id(const whisper_seq_id & id) const {
+        return seq_id.find(id) != seq_id.end();
+    }
+};
+
 struct whisper_kv_cache {
+    uint32_t head = 0;
+    uint32_t size = 0;
+
+    // computed before each graph build
+    uint32_t n = 0;
+
+    std::vector<whisper_kv_cell> cells;
+
     struct ggml_tensor * k;
     struct ggml_tensor * v;
 
     struct ggml_context * ctx;
 
     ggml_backend_buffer_t buffer;
-
-    int n; // number of tokens currently in the cache
 };
 
 struct whisper_model {
@@ -579,6 +711,25 @@ struct whisper_model {
     std::map<std::string, struct ggml_tensor *> tensors;
 };
 
+struct whisper_partial_utf8 {
+    uint32_t value;    // bit value so far (unshifted)
+    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
+};
+
+struct whisper_grammar {
+    /*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
+    std::vector<std::vector<const whisper_grammar_element *>>   stacks;
+
+    // buffer for partially generated UTF-8 sequence from accepted tokens
+    whisper_partial_utf8 partial_utf8;
+};
+
+struct whisper_grammar_candidate {
+    whisper_token          id;
+    const uint32_t       * code_points;
+    whisper_partial_utf8   partial_utf8;
+};
+
 struct whisper_sequence {
     std::vector<whisper_token_data> tokens;
 
@@ -594,12 +745,13 @@ struct whisper_sequence {
 
 // TAGS: WHISPER_DECODER_INIT
 struct whisper_decoder {
-    // each decoder keeps its own KV-cache
-    whisper_kv_cache kv_self;
-
     // the currently generated sequence of tokens
     whisper_sequence sequence;
 
+    // grammar parse state of generated sequence of tokens
+    whisper_grammar  grammar;
+
+    int i_batch;    // the index of the token in the current batch
     int seek_delta; // the window shift found so far based on the decoded timestamp tokens
 
     bool failed;    // has the current segment failed to decode?
@@ -611,100 +763,40 @@ struct whisper_decoder {
     std::vector<float> logits;
     std::vector<float> logprobs;
 
-    std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
-};
-
-// replace std::pair by using customized pair struct (reason: std::pair is very slow)
-template<typename A, typename B>
-struct whisper_pair {
-    A first;
-    B second;
-
-    // Define a constructor that takes two arguments.
-    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
-    // Define a constructor that takes no argument.
-    whisper_pair() : first(A()), second(B()) {}
-};
-
-// beam-search helpers
-struct kv_buf {
-    std::vector<uint8_t> k;
-    std::vector<uint8_t> v;
-};
-
-// ggml_allocr wrapper for whisper usage
-struct whisper_allocr {
-    ggml_allocr * alloc = nullptr;
-
-    std::vector<uint8_t> meta;
+    // work container used to avoid memory allocations
+    std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
 
-    ggml_backend_buffer_t buffer;
+    mutable std::mt19937 rng; // used for sampling at t > 0.0
 };
 
-static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-    return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
-}
-
-// measure the memory usage of a graph and prepare the allocr's internal data buffer
-static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
-    auto & alloc  = allocr.alloc;
-    auto & meta   = allocr.meta;
-
-    alloc = ggml_allocr_new_measure_from_backend(backend);
-
-    meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
-
-    ggml_allocr_alloc_graph(alloc, get_graph());
-}
-
-static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
-    if (allocr.alloc == nullptr) {
-        // this can be null if we use external encoder like CoreML or OpenVINO
-        return;
-    }
-
-    auto & alloc  = allocr.alloc;
-    auto & buffer = allocr.buffer;
-
-    size_t size = ggml_allocr_max_size(alloc);
-
-    ggml_allocr_free(alloc);
-
-    buffer = ggml_backend_alloc_buffer(backend, size);
-    alloc = ggml_allocr_new_from_buffer(buffer);
-}
-
-static void whisper_allocr_free(struct whisper_allocr & allocr) {
-    if (allocr.alloc) {
-        ggml_allocr_free(allocr.alloc);
-        ggml_backend_buffer_free(allocr.buffer);
-        allocr.alloc = nullptr;
-    }
-}
-
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
     int64_t t_decode_us = 0;
+    int64_t t_batchd_us = 0;
     int64_t t_prompt_us = 0;
     int64_t t_mel_us = 0;
 
     int32_t n_sample = 0; // number of tokens sampled
     int32_t n_encode = 0; // number of encoder calls
-    int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
-    int32_t n_prompt = 0; // number of decoder calls with n_tokens >  1 (prompt encoding)
+    int32_t n_decode = 0; // number of decoder calls with n_tokens == 1  (text-generation)
+    int32_t n_batchd = 0; // number of decoder calls with n_tokens <  16 (batch decoding)
+    int32_t n_prompt = 0; // number of decoder calls with n_tokens >  1  (prompt encoding)
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
+    // unified self-attention KV cache for all decoders
+    whisper_kv_cache kv_self;
+
     // cross-attention KV cache for the decoders
     // shared between all decoders
     whisper_kv_cache kv_cross;
+
     whisper_mel mel;
 
-    whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
+    whisper_batch batch;
 
-    // buffer for swapping KV caches between decoders during beam-search
-    std::vector<kv_buf> kv_swap_bufs;
+    whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
     ggml_backend_t backend = nullptr;
 
@@ -720,8 +812,9 @@ struct whisper_state {
     struct ggml_tensor * embd_conv = nullptr;
     struct ggml_tensor * embd_enc  = nullptr;
 
-    // helper for GPU offloading
+    // helpers for GPU offloading
     std::vector<float> inp_mel;
+    std::vector<float> inp_mask;
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
@@ -729,11 +822,6 @@ struct whisper_state {
     std::vector<whisper_segment> result_all;
     std::vector<whisper_token>   prompt_past;
 
-    // work container used to avoid memory allocations
-    std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
-
-    mutable std::mt19937 rng; // used for sampling at t > 0.0
-
     int lang_id = 0; // english by default
 
     std::string path_model; // populated by whisper_init_from_file_with_params()
@@ -809,6 +897,12 @@ static bool kv_cache_init(
         /*.no_alloc   =*/ true,
     };
 
+    cache.head = 0;
+    cache.size = n_ctx;
+
+    cache.cells.clear();
+    cache.cells.resize(n_ctx);
+
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {
@@ -836,54 +930,129 @@ static bool kv_cache_init(
     return true;
 }
 
-// TODO: remove after batched decoding
-static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) {
-    WHISPER_ASSERT(cache.ctx);
+static void kv_cache_free(struct whisper_kv_cache & cache) {
+    if (cache.ctx) {
+        ggml_free(cache.ctx);
+        ggml_backend_buffer_free(cache.buffer);
+        cache.ctx = nullptr;
+    }
+}
 
-    const int n_elements = ggml_nelements(cache.k);
-    WHISPER_ASSERT(n_elements == ggml_nelements(cache.v));
+static bool whisper_kv_cache_find_slot(
+           struct whisper_kv_cache & cache,
+        const struct whisper_batch & batch) {
+    const uint32_t n_ctx    = cache.size;
+    const uint32_t n_tokens = batch.n_tokens;
 
-    const ggml_type wtype = cache.k->type;
-    WHISPER_ASSERT(wtype == cache.v->type);
+    if (n_tokens > n_ctx) {
+        WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
+        return false;
+    }
 
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
-        /*.mem_buffer =*/ nullptr,
-        /*.no_alloc   =*/ true,
-    };
+    uint32_t n_tested = 0;
 
-    cache.ctx = ggml_init(params);
+    while (true) {
+        if (cache.head + n_tokens > n_ctx) {
+            n_tested += n_ctx - cache.head;
+            cache.head = 0;
+            continue;
+        }
 
-    if (!cache.ctx) {
-        WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
-        return false;
+        bool found = true;
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            if (cache.cells[cache.head + i].pos >= 0) {
+                found = false;
+                cache.head += i + 1;
+                n_tested   += i + 1;
+                break;
+            }
+        }
+
+        if (found) {
+            break;
+        }
+
+        if (n_tested >= n_ctx) {
+            //WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+            return false;
+        }
     }
 
-    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        cache.cells[cache.head + i].pos = batch.pos[i];
 
-    const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v);
+        for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+            cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
+        }
+    }
 
-    cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes);
+    return true;
+}
 
-    // allocate the tensors into the backend buffer
-    {
-        ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer);
+// find how many cells are currently in use
+static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) {
+    for (uint32_t i = cache.size - 1; i > 0; --i) {
+        if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
+            return i + 1;
+        }
+    }
 
-        ggml_allocr_alloc(alloc, cache.k);
-        ggml_allocr_alloc(alloc, cache.v);
+    return 1;
+}
 
-        ggml_allocr_free(alloc);
+static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
+    for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
+        cache.cells[i].pos = -1;
+        cache.cells[i].seq_id.clear();
     }
+    cache.head = 0;
+}
 
-    return true;
+static void whisper_kv_cache_seq_rm(
+        struct whisper_kv_cache & cache,
+                 whisper_seq_id   seq_id,
+                    whisper_pos   p0,
+                    whisper_pos   p1) {
+    uint32_t new_head = cache.size;
+
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            if (seq_id < 0) {
+                cache.cells[i].seq_id.clear();
+            } else if (cache.cells[i].has_seq_id(seq_id)) {
+                cache.cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
+            if (cache.cells[i].seq_id.empty()) {
+                cache.cells[i].pos = -1;
+                if (new_head == cache.size) new_head = i;
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
 }
 
-static void kv_cache_free(struct whisper_kv_cache & cache) {
-    if (cache.ctx) {
-        ggml_free(cache.ctx);
-        ggml_backend_buffer_free(cache.buffer);
-        cache.ctx = nullptr;
+static void whisper_kv_cache_seq_cp(
+        struct whisper_kv_cache & cache,
+                 whisper_seq_id   seq_id_src,
+                 whisper_seq_id   seq_id_dst,
+                    whisper_pos   p0,
+                    whisper_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+    cache.head = 0;
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            cache.cells[i].seq_id.insert(seq_id_dst);
+        }
     }
 }
 
@@ -892,7 +1061,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
 
     // initialize the backends
 #ifdef GGML_USE_CUBLAS
-    if (params.use_gpu) {
+    if (params.use_gpu && ggml_cublas_loaded()) {
         WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
         backend_gpu = ggml_backend_cuda_init();
         if (!backend_gpu) {
@@ -1094,6 +1263,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                     word = "[_EOT_]";
                 } else if (i == vocab.token_sot) {
                     word = "[_SOT_]";
+                } else if (i == vocab.token_translate) {
+                    word = "[_TRANSLATE_]";
+                } else if (i == vocab.token_transcribe) {
+                    word = "[_TRANSCRIBE_]";
                 } else if (i == vocab.token_solm) {
                     word = "[_SOLM_]";
                 } else if (i == vocab.token_prev) {
@@ -1104,6 +1277,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                     word = "[_NOT_]";
                 } else if (i == vocab.token_beg) {
                     word = "[_BEG_]";
+                } else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
+                    word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
                 } else {
                     word = "[_extra_token_" + std::to_string(i) + "]";
                 }
@@ -1347,7 +1522,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         model.buffer = ggml_backend_alloc_buffer(wctx.backend, size_main);
 
-        WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1e6);
     }
 
     ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer);
@@ -1462,12 +1637,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
             }
 
-            //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1024.0/1024.0);
+            //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1e6);
             total_size += ggml_nbytes(tensor);
             model.n_loaded++;
         }
 
-        WHISPER_LOG_INFO("%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
+        WHISPER_LOG_INFO("%s: model size    = %7.2f MB\n", __func__, total_size/1e6);
 
         if (model.n_loaded == 0) {
             WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
@@ -1852,11 +2027,11 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     ////////////////////////////////////////////////////////////////////////////
 
     //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
-    //        ggml_used_mem(ctx0)/1024.0/1024.0,
-    //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
-    //        wstate.get_buf_max_mem(1)/1024.0/1024.0,
-    //        wstate.get_buf_max_mem(2)/1024.0/1024.0,
-    //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
+    //        ggml_used_mem(ctx0)/1e6,
+    //        wstate.get_buf_max_mem(0)/1e6,
+    //        wstate.get_buf_max_mem(1)/1e6,
+    //        wstate.get_buf_max_mem(2)/1e6,
+    //        wstate.get_buf_max_mem(3)/1e6);
 
     ggml_free(ctx0);
 
@@ -2009,26 +2184,28 @@ static bool whisper_encode_internal(
 static struct ggml_cgraph * whisper_build_graph_decoder(
          whisper_context & wctx,
          whisper_state   & wstate,
-         whisper_decoder & decoder,
-     const whisper_token * tokens,
-                   int   n_tokens,
-                   int   n_past) {
+     const whisper_batch & batch) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
-    auto & kv_self = decoder.kv_self;
+    auto & kv_self = wstate.kv_self;
 
     WHISPER_ASSERT(!!kv_self.ctx);
 
-    const int n_ctx   = hparams.n_text_ctx;
+    ggml_allocr * alloc = wstate.alloc_decode.alloc;
+
+    const int n_ctx   = kv_self.size;
     const int n_state = hparams.n_text_state;
     const int n_head  = hparams.n_text_head;
     const int n_layer = hparams.n_text_layer;
 
-    const int N = n_tokens;
-    const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
+    const int n_tokens    = batch.n_tokens;
+    const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
+
+    const int32_t n_kv     = ggml_allocr_is_measure(alloc) ? n_ctx            : kv_self.n;
+    const int32_t kv_head  = ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
 
-    //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
+    //WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
 
     struct ggml_init_params params = {
         /*.mem_size   =*/ wstate.alloc_decode.meta.size(),
@@ -2040,21 +2217,19 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
 
-    ggml_allocr * alloc = wstate.alloc_decode.alloc;
-
-    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
     ggml_allocr_alloc(alloc, embd);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd));
+        ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*ggml_element_size(embd));
     }
 
-    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
     ggml_allocr_alloc(alloc, position);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        for (int i = 0; i < N; ++i) {
-            const int32_t val = n_past + i;
+        for (int i = 0; i < n_tokens; ++i) {
+            const int32_t val = batch.pos[i];
             ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
         }
     }
@@ -2067,18 +2242,43 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
         ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
     }
 
-    // token encoding + position encoding
-    struct ggml_tensor * cur =
-        ggml_add(ctx0,
-                ggml_get_rows(ctx0, model.d_te, embd),
-                ggml_get_rows(ctx0, model.d_pe, position));
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+    ggml_allocr_alloc(alloc, KQ_mask);
 
-    struct ggml_tensor * inpL = cur;
+    if (!ggml_allocr_is_measure(alloc)) {
+        wstate.inp_mask.resize(n_kv*n_tokens);
 
-    for (int il = 0; il < n_layer; ++il) {
-        const auto & layer = model.layers_decoder[il];
+        float * data = wstate.inp_mask.data();
+        memset(data, 0, ggml_nbytes(KQ_mask));
 
-        // norm
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                const whisper_pos    pos    = batch.pos[j];
+                const whisper_seq_id seq_id = batch.seq_id[j][0];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+            }
+        }
+
+        ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
+    }
+
+    // token encoding + position encoding
+    struct ggml_tensor * cur =
+        ggml_add(ctx0,
+                ggml_get_rows(ctx0, model.d_te, embd),
+                ggml_get_rows(ctx0, model.d_pe, position));
+
+    struct ggml_tensor * inpL = cur;
+
+    for (int il = 0; il < n_layer; ++il) {
+        const auto & layer = model.layers_decoder[il];
+
+        // norm
         {
             cur = ggml_norm(ctx0, inpL, hparams.eps);
 
@@ -2119,12 +2319,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                             Vcur,
                             layer.attn_v_b);
 
-                Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
+                Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
 
-                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
                         (   n_ctx)*ggml_element_size(kv_self.v),
-                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
 
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
@@ -2134,12 +2334,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_state/n_head, n_past + N, n_head,
+                        n_state/n_head, n_kv, n_head,
                         ggml_element_size(kv_self.k)*n_state,
                         ggml_element_size(kv_self.k)*n_state/n_head,
                         ggml_element_size(kv_self.k)*n_state*n_ctx*il);
@@ -2149,16 +2349,17 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
 
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
+            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
+            struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask);
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_past + N, n_state/n_head, n_head,
+                        n_kv, n_state/n_head, n_head,
                         n_ctx*ggml_element_size(kv_self.v),
                         n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
-                        il*n_ctx*ggml_element_size(kv_self.v)*n_state);
+                        n_ctx*ggml_element_size(kv_self.v)*n_state*il);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
@@ -2166,7 +2367,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             cur = ggml_cpy(ctx0,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
         }
 
         // projection
@@ -2210,33 +2411,33 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
                 ggml_view_3d(ctx0, wstate.kv_cross.k,
-                        n_state/n_head, M, n_head,
+                        n_state/n_head, n_audio_ctx, n_head,
                         ggml_element_size(wstate.kv_cross.k)*n_state,
                         ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
-                        ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
+                        ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
 
             //struct ggml_tensor * Vcross =
             //    ggml_reshape_3d(ctx0,
-            //            ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
-            //            n_state/n_head, n_head, M);
+            //            ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
+            //            n_state/n_head, n_head, n_audio_ctx);
 
             //struct ggml_tensor * V_trans =
             //    ggml_cpy(ctx0,
             //            ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
-            //            ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
+            //            ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, wstate.kv_cross.v,
-                        M, n_state/n_head, n_head,
-                        M*ggml_element_size(wstate.kv_cross.v),
-                        M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
-                        il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
+                        n_audio_ctx, n_state/n_head, n_head,
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
 
             // ------
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
             // K * Q
@@ -2257,10 +2458,10 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            // cur = KQV_merged.contiguous().view(n_state, N)
+            // cur = KQV_merged.contiguous().view(n_state, n_tokens)
             cur = ggml_cpy(ctx0,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
         }
 
         // projection
@@ -2332,9 +2533,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     }
 
     // compute logits only for the last token
-    // comment this line to compute logits for all tokens
+    // comment this line to compute logits for all n_tokens
     // might be useful in the future
-    cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
+    //cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
@@ -2358,10 +2559,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 static bool whisper_decode_internal(
         whisper_context & wctx,
           whisper_state & wstate,
-        whisper_decoder & decoder,
-    const whisper_token * tokens,
-              const int   n_tokens,
-              const int   n_past,
+    const whisper_batch & batch,
               const int   n_threads,
  whisper_abort_callback   abort_callback,
                    void * abort_callback_data) {
@@ -2370,19 +2568,33 @@ static bool whisper_decode_internal(
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
-    const int n_vocab = hparams.n_vocab;
+    const int n_vocab  = hparams.n_vocab;
+    const int n_tokens = batch.n_tokens;
 
     auto & logits_out = wstate.logits;
 
     struct ggml_tensor * logits;
 
+    // find KV slot for the batch
+    {
+        auto & kv_self = wstate.kv_self;
+
+        if (!whisper_kv_cache_find_slot(kv_self, batch)) {
+            return false;
+        }
+
+        kv_self.n = whisper_kv_cache_cell_max(kv_self);
+        //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
+        //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
+    }
+
     // decoder
     {
         auto & alloc = wstate.alloc_decode.alloc;
 
         ggml_allocr_reset(alloc);
 
-        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
+        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch);
 
         ggml_allocr_alloc_graph(alloc, gf);
 
@@ -2391,37 +2603,37 @@ static bool whisper_decode_internal(
         ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
-    // extract logits for all N tokens
-    //logits_out.resize(n_tokens*n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
-    //ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab);
-
-    // extract logits only for the last token
-    logits_out.resize(n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
-    ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
+    logits_out.resize(n_tokens*n_vocab);
+    for (int i = 0; i < n_tokens; i++) {
+        if (batch.logits[i] == 0) {
+            continue;
+        }
+        ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
+    }
 
-    if (n_tokens > 1) {
+    if (batch.n_tokens > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
-        //        ggml_used_mem(ctx0)/1024.0/1024.0,
-        //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
-        //        wstate.get_buf_max_mem(1)/1024.0/1024.0,
-        //        wstate.get_buf_max_mem(2)/1024.0/1024.0,
-        //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
+        //        ggml_used_mem(ctx0)/1e6,
+        //        wstate.get_buf_max_mem(0)/1e6,
+        //        wstate.get_buf_max_mem(1)/1e6,
+        //        wstate.get_buf_max_mem(2)/1e6,
+        //        wstate.get_buf_max_mem(3)/1e6);
     }
 
-    if (n_tokens == 1) {
+    if (batch.n_tokens == 1) {
         wstate.t_decode_us += ggml_time_us() - t_start_us;
         wstate.n_decode++;
+    } else if (batch.n_tokens < 16) {
+        wstate.t_batchd_us += ggml_time_us() - t_start_us;
+        wstate.n_batchd += n_tokens;
     } else {
         wstate.t_prompt_us += ggml_time_us() - t_start_us;
-        wstate.n_prompt++;
+        wstate.n_prompt += n_tokens;
     }
 
     return !(abort_callback && abort_callback(abort_callback_data));
 }
 
-
 //  500 -> 00:05.000
 // 6000 -> 01:00.000
 static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2833,15 +3045,19 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->backend = whisper_backend_init(ctx->params);
 
-    if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
+    // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
+    // in theory, there can be a case where this is not enough, but in practice it should always be enough
+    const int factor = 3;
+
+    if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
         return nullptr;
     }
 
     {
-        const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
-        WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+        const size_t memory_size = ggml_nbytes(state->kv_self.k) + ggml_nbytes(state->kv_self.v);
+        WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
     if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
@@ -2852,7 +3068,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     {
         const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
-        WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
 #ifdef WHISPER_USE_COREML
@@ -2875,14 +3091,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
 
-    state->logits_id.reserve(ctx->model.hparams.n_vocab);
+    state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
 
     // TAGS: WHISPER_DECODER_INIT
     state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
 
-    state->decoders[0].probs.reserve   (ctx->vocab.n_vocab);
-    state->decoders[0].logits.reserve  (ctx->vocab.n_vocab);
-    state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
+    state->decoders[0].probs.reserve    (ctx->vocab.n_vocab);
+    state->decoders[0].logits.reserve   (ctx->vocab.n_vocab);
+    state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab);
+    state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab);
+
+    state->decoders[0].rng = std::mt19937(0);
 
     // conv allocator
     {
@@ -2891,7 +3110,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
                     return whisper_build_graph_conv(*ctx, *state, 0);
                 });
 
-        WHISPER_LOG_INFO("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6);
     }
 
     // encoder allocator
@@ -2901,7 +3120,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
                     return whisper_build_graph_encoder(*ctx, *state);
                 });
 
-        WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6);
     }
 
     // cross allocator
@@ -2911,7 +3130,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
                     return whisper_build_graph_cross(*ctx, *state);
                 });
 
-        WHISPER_LOG_INFO("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6);
     }
 
     // decoder allocator
@@ -2924,10 +3143,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
                     const int n_tokens = hparams.n_text_ctx;
                     const int n_past   = 0;
 
-                    return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
+                    whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
+
+                    return whisper_build_graph_decoder(*ctx, *state, state->batch);
                 });
 
-        WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6);
     }
 
     whisper_allocr_graph_realloc(state->alloc_conv,   ctx->backend);
@@ -2935,8 +3156,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     whisper_allocr_graph_realloc(state->alloc_cross,  ctx->backend);
     whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
 
-    state->rng = std::mt19937(0);
-
     return state;
 }
 
@@ -3161,12 +3380,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
 void whisper_free_state(struct whisper_state * state)
 {
     if (state) {
+        kv_cache_free(state->kv_self);
         kv_cache_free(state->kv_cross);
 
-        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
-            kv_cache_free(state->decoders[i].kv_self);
-        }
-
 #ifdef WHISPER_USE_COREML
         if (state->ctx_coreml != nullptr) {
             whisper_coreml_free(state->ctx_coreml);
@@ -3181,6 +3397,8 @@ void whisper_free_state(struct whisper_state * state)
         }
 #endif
 
+        whisper_batch_free(state->batch);
+
         whisper_allocr_free(state->alloc_conv);
         whisper_allocr_free(state->alloc_encode);
         whisper_allocr_free(state->alloc_cross);
@@ -3307,9 +3525,11 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 }
 
 int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    const int selected_decoder_id = 0;
+    whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
 
-    if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
+    whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1);
+
+    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -3318,15 +3538,16 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
 }
 
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    // TODO: add selected_decoder_id to state
-    const int selected_decoder_id = 0;
-
     if (ctx->state == nullptr) {
         WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
         return false;
     }
 
-    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
+    whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1);
+
+    whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past, 0);
+
+    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->batch, n_threads, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -3414,7 +3635,7 @@ int whisper_lang_auto_detect_with_state(
         return -7;
     }
 
-    auto & logits_id = state->logits_id;
+    auto & logits_id = state->decoders[0].logits_id;
     logits_id.clear();
 
     for (const auto & kv : g_lang) {
@@ -3617,6 +3838,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
         const int32_t n_sample = std::max(1, ctx->state->n_sample);
         const int32_t n_encode = std::max(1, ctx->state->n_encode);
         const int32_t n_decode = std::max(1, ctx->state->n_decode);
+        const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
         const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
 
         WHISPER_LOG_INFO("%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
@@ -3624,6 +3846,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
         WHISPER_LOG_INFO("%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
         WHISPER_LOG_INFO("%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
         WHISPER_LOG_INFO("%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+        WHISPER_LOG_INFO("%s:   batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
         WHISPER_LOG_INFO("%s:   prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
     }
     WHISPER_LOG_INFO("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
@@ -3640,6 +3863,7 @@ void whisper_reset_timings(struct whisper_context * ctx) {
         ctx->state->n_sample = 0;
         ctx->state->n_encode = 0;
         ctx->state->n_decode = 0;
+        ctx->state->n_batchd = 0;
         ctx->state->n_prompt = 0;
     }
 }
@@ -3685,6 +3909,424 @@ const char * whisper_print_system_info(void) {
     return s.c_str();
 }
 
+//////////////////////////////////
+// Grammar - ported from llama.cpp
+//////////////////////////////////
+
+// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
+// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`.
+std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
+        const char         * src,
+        whisper_partial_utf8   partial_start) {
+    static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
+    const char          * pos      = src;
+    std::vector<uint32_t> code_points;
+    uint32_t              value    = partial_start.value;
+    int                   n_remain = partial_start.n_remain;
+
+    // continue previous decode, if applicable
+    while (*pos != 0 && n_remain > 0) {
+        uint8_t next_byte = static_cast<uint8_t>(*pos);
+        if ((next_byte >> 6) != 2) {
+            // invalid sequence, abort
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 });
+        }
+        value = (value << 6) + (next_byte & 0x3F);
+        ++pos;
+        --n_remain;
+    }
+
+    if (partial_start.n_remain > 0 && n_remain == 0) {
+        code_points.push_back(value);
+    }
+
+    // decode any subsequent utf-8 sequences, which may end in an incomplete one
+    while (*pos != 0) {
+        uint8_t  first_byte = static_cast<uint8_t>(*pos);
+        uint8_t  highbits   = first_byte >> 4;
+                 n_remain   = lookup[highbits] - 1;
+
+        if (n_remain < 0) {
+            // invalid sequence, abort
+            code_points.clear();
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain });
+        }
+
+        uint8_t  mask       = (1 << (7 - n_remain)) - 1;
+                 value      = first_byte & mask;
+        ++pos;
+        while (*pos != 0 && n_remain > 0) {
+            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+            ++pos;
+            --n_remain;
+        }
+        if (n_remain == 0) {
+            code_points.push_back(value);
+        }
+    }
+    code_points.push_back(0);
+
+    return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain });
+}
+
+// returns true iff pos points to the end of one of the definitions of a rule
+static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
+    switch (pos->type) {
+        case WHISPER_GRETYPE_END: return true;  // NOLINT
+        case WHISPER_GRETYPE_ALT: return true;  // NOLINT
+        default:                return false;
+    }
+}
+
+// returns true iff chr satisfies the char range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
+        const whisper_grammar_element * pos,
+        const uint32_t                chr) {
+
+    bool found            = false;
+    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+
+    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT
+
+    do {
+        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            found = found || (pos->value <= chr && chr <= pos[1].value);
+            pos += 2;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            found = found || pos->value == chr;
+            pos += 1;
+        }
+    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+    return std::make_pair(found == is_positive_char, pos);
+}
+
+// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
+// range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static bool whisper_grammar_match_partial_char(
+        const whisper_grammar_element * pos,
+        const whisper_partial_utf8      partial_utf8) {
+
+    bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR;
+    WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT);
+
+    uint32_t partial_value = partial_utf8.value;
+    int      n_remain      = partial_utf8.n_remain;
+
+    // invalid sequence or 7-bit char split across 2 bytes (overlong)
+    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
+        return false;
+    }
+
+    // range of possible code points this partial UTF-8 sequence could complete to
+    uint32_t low  = partial_value << (n_remain * 6);
+    uint32_t high = low | ((1 << (n_remain * 6)) - 1);
+
+    if (low == 0) {
+        if (n_remain == 2) {
+            low = 1 << 11;
+        } else if (n_remain == 3) {
+            low = 1 << 16;
+        }
+    }
+
+    do {
+        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            if (pos->value <= high && low <= pos[1].value) {
+                return is_positive_char;
+            }
+            pos += 2;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            if (low <= pos->value && pos->value <= high) {
+                return is_positive_char;
+            }
+            pos += 1;
+        }
+    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);
+
+    return !is_positive_char;
+}
+
+
+// transforms a grammar pushdown stack into N possible stacks, all ending
+// at a character range (terminal element)
+static void whisper_grammar_advance_stack(
+        const std::vector<std::vector<whisper_grammar_element>>   & rules,
+        const std::vector<const whisper_grammar_element *>        & stack,
+        std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
+
+    if (stack.empty()) {
+        new_stacks.push_back(stack);
+        return;
+    }
+
+    const whisper_grammar_element * pos = stack.back();
+
+    switch (pos->type) {
+        case WHISPER_GRETYPE_RULE_REF: {
+            const size_t                  rule_id = static_cast<size_t>(pos->value);
+            const whisper_grammar_element * subpos  = rules[rule_id].data();
+            do {
+                // init new stack without the top (pos)
+                std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+                if (!whisper_grammar_is_end_of_sequence(pos + 1)) {
+                    // if this rule ref is followed by another element, add that to stack
+                    new_stack.push_back(pos + 1);
+                }
+                if (!whisper_grammar_is_end_of_sequence(subpos)) {
+                    // if alternate is nonempty, add to stack
+                    new_stack.push_back(subpos);
+                }
+                whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+                while (!whisper_grammar_is_end_of_sequence(subpos)) {
+                    // scan to end of alternate def
+                    subpos++;
+                }
+                if (subpos->type == WHISPER_GRETYPE_ALT) {
+                    // there's another alternate def of this rule to process
+                    subpos++;
+                } else {
+                    break;
+                }
+            } while (true);
+            break;
+        }
+        case WHISPER_GRETYPE_CHAR:
+        case WHISPER_GRETYPE_CHAR_NOT:
+            new_stacks.push_back(stack);
+            break;
+        default:
+            // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
+            // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
+            // those
+            WHISPER_ASSERT(false);
+    }
+}
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `whisper_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
+        const std::vector<std::vector<whisper_grammar_element>>         & rules,
+        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+        const uint32_t                                                  chr) {
+
+    std::vector<std::vector<const whisper_grammar_element *>> new_stacks;
+
+    for (const auto & stack : stacks) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        auto match = whisper_grammar_match_char(stack.back(), chr);
+        if (match.first) {
+            const whisper_grammar_element * pos = match.second;
+
+            // update top of stack to next element, if any
+            std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+            if (!whisper_grammar_is_end_of_sequence(pos)) {
+                new_stack.push_back(pos);
+            }
+            whisper_grammar_advance_stack(rules, new_stack, new_stacks);
+        }
+    }
+
+    return new_stacks;
+}
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
+        const std::vector<std::vector<whisper_grammar_element>>         & rules,
+        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+        const std::vector<whisper_grammar_candidate>                    & candidates);
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_for_stack(
+        const std::vector<std::vector<whisper_grammar_element>> & rules,
+        const std::vector<const whisper_grammar_element *>      & stack,
+        const std::vector<whisper_grammar_candidate>            & candidates) {
+
+    std::vector<whisper_grammar_candidate> rejects;
+
+    if (stack.empty()) {
+        for (auto tok : candidates) {
+            if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
+                rejects.push_back(tok);
+            }
+        }
+        return rejects;
+    }
+
+    const whisper_grammar_element * stack_pos = stack.back();
+
+    std::vector<whisper_grammar_candidate> next_candidates;
+    for (auto tok : candidates) {
+        if (*tok.code_points == 0) {
+            // reached end of full codepoints in token, reject iff it ended in a partial sequence
+            // that cannot satisfy this position in grammar
+            if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
+                rejects.push_back(tok);
+            }
+        } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
+            next_candidates.push_back({ tok.id, tok.code_points + 1, tok.partial_utf8 });
+        } else {
+            rejects.push_back(tok);
+        }
+    }
+
+    const auto * stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second;
+
+    // update top of stack to next element, if any
+    std::vector<const whisper_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
+    if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) {
+        stack_after.push_back(stack_pos_after);
+    }
+    std::vector<std::vector<const whisper_grammar_element *>> next_stacks;
+    whisper_grammar_advance_stack(rules, stack_after, next_stacks);
+
+    auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates);
+    for (auto tok : next_rejects) {
+        rejects.push_back({ tok.id, tok.code_points - 1, tok.partial_utf8 });
+    }
+
+    return rejects;
+}
+
+static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
+        const std::vector<std::vector<whisper_grammar_element>>         & rules,
+        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
+        const std::vector<whisper_grammar_candidate>                    & candidates) {
+    if (candidates.empty() || stacks.empty()) {
+        return std::vector<whisper_grammar_candidate>();
+    }
+
+    auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+        rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+    }
+    return rejects;
+}
+
+static struct whisper_grammar whisper_grammar_init(
+            const whisper_grammar_element ** rules,
+                                 size_t      n_rules,
+                                 size_t      i_start_rule) {
+    const whisper_grammar_element * pos;
+
+    // copy rule definitions into vectors
+    std::vector<std::vector<whisper_grammar_element>> vec_rules(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) {
+            vec_rules[i].push_back(*pos);
+        }
+        vec_rules[i].push_back({WHISPER_GRETYPE_END, 0});
+    }
+
+    // loop over alternates of start rule to build initial stacks
+    std::vector<std::vector<const whisper_grammar_element *>> stacks;
+    pos = rules[i_start_rule];
+    do {
+        std::vector<const whisper_grammar_element *> stack;
+        if (!whisper_grammar_is_end_of_sequence(pos)) {
+            // if alternate is nonempty, add to stack
+            stack.push_back(pos);
+        }
+        whisper_grammar_advance_stack(vec_rules, stack, stacks);
+        while (!whisper_grammar_is_end_of_sequence(pos)) {
+            // scan to end of alternate def
+            pos++;
+        }
+        if (pos->type == WHISPER_GRETYPE_ALT) {
+            // there's another alternate def of this rule to process
+            pos++;
+        } else {
+            break;
+        }
+    } while (true);
+
+    return { std::move(vec_rules), std::move(stacks), {} };
+}
+
+static void whisper_suppress_invalid_grammar(
+             whisper_context  & ctx,
+    const whisper_full_params & params,
+           std::vector<float> & logits,
+    const     whisper_grammar & grammar) {
+
+    if (grammar.rules.empty() || grammar.stacks.empty()) {
+        return;
+    }
+
+    //bool allow_eot = false;
+    //for (const auto & stack : grammar.stacks) {
+    //    if (stack.empty()) {
+    //        allow_eot = true;
+    //        break;
+    //    }
+    //}
+
+    const whisper_token eot = whisper_token_eot(&ctx);
+
+    std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
+    std::vector<whisper_grammar_candidate>                              candidates_grammar;
+
+    for (whisper_token id = 0; id < eot; ++id) {
+        const std::string & text = ctx.vocab.id_to_token[id];
+        if (!text.empty()) {
+            candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
+            candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second });
+        }
+    }
+
+    const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
+
+    for (const auto & reject : rejects) {
+        logits[reject.id] -= params.grammar_penalty;
+    }
+
+    // when the grammar allows a continuation, we penalize the end-of-text token
+    //if (!allow_eot) {
+    //    logits[eot] -= params.grammar_penalty;
+    //}
+    //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
+}
+
+static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
+    if (grammar.rules.empty() || grammar.stacks.empty()) {
+        return;
+    }
+
+    //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
+
+    const std::string & text = ctx.vocab.id_to_token[token];
+
+    if (text.rfind("[_", 0) == 0) {
+        // fprintf(stderr, " (skipped)\n");
+        return;
+    }
+    // fprintf(stderr, "\n");
+
+    // Note terminating 0 in decoded string
+    const auto   decoded     = decode_utf8(text.c_str(), grammar.partial_utf8);
+    const auto & code_points = decoded.first;
+    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+        grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, *it);
+    }
+    grammar.partial_utf8 = decoded.second;
+}
+
+//////////////
+// END grammar
+//////////////
+
 ////////////////////////////////////////////////////////////////////////////
 
 struct whisper_context_params * whisper_context_default_params_by_ref() {
@@ -3714,6 +4356,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
         /*.translate         =*/ false,
         /*.no_context        =*/ true,
+        /*.no_timestamps     =*/ false,
         /*.single_segment    =*/ false,
         /*.print_special     =*/ false,
         /*.print_progress    =*/ true,
@@ -3747,7 +4390,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.max_initial_ts    =*/  1.0f,
         /*.length_penalty    =*/ -1.0f,
 
-        /*.temperature_inc   =*/  0.4f,
+        /*.temperature_inc   =*/  0.2f,
         /*.entropy_thold     =*/  2.4f,
         /*.logprob_thold     =*/ -1.0f,
         /*.no_speech_thold   =*/  0.6f,
@@ -3776,19 +4419,24 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
         /*.logits_filter_callback           =*/ nullptr,
         /*.logits_filter_callback_user_data =*/ nullptr,
+
+        /*.grammar_rules   =*/ nullptr,
+        /*.n_grammar_rules =*/ 0,
+        /*.i_start_rule    =*/ 0,
+        /*.grammar_penalty =*/ 100.0f,
     };
 
     switch (strategy) {
         case WHISPER_SAMPLING_GREEDY:
             {
                 result.greedy = {
-                    /*.best_of   =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
+                    /*.best_of   =*/ 5,
                 };
             } break;
         case WHISPER_SAMPLING_BEAM_SEARCH:
             {
                 result.beam_search = {
-                    /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
+                    /*.beam_size =*/ 5,
 
                     /*.patience  =*/ -1.0f,
                 };
@@ -3878,11 +4526,12 @@ static const std::vector<std::string> non_speech_tokens = {
 // process the logits for the selected decoder
 // - applies logit filters
 // - computes logprobs and probs
+// TODO: optimize
 static void whisper_process_logits(
               struct whisper_context & ctx,
                struct whisper_state  & state,
-    const struct whisper_full_params   params,
               struct whisper_decoder & decoder,
+    const struct whisper_full_params   params,
                                float   temperature) {
     const auto & vocab      = ctx.vocab;
     const auto & tokens_cur = decoder.sequence.tokens;
@@ -3899,7 +4548,7 @@ static void whisper_process_logits(
     auto & logprobs = decoder.logprobs;
     {
         logits.resize(n_logits);
-        memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
+        memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));
 
         if (temperature > 0.0f) {
             for (int i = 0; i < n_logits; i++) {
@@ -3927,6 +4576,11 @@ static void whisper_process_logits(
         // suppress <|notimestamps|> token
         // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
         logits[vocab.token_not] = -INFINITY;
+        if (params.no_timestamps) {
+            for (int i = vocab.token_beg; i < n_logits; ++i) {
+                logits[i] = -INFINITY;
+            }
+        }
 
         // suppress sot and nosp tokens
         logits[vocab.token_sot]  = -INFINITY;
@@ -3942,6 +4596,14 @@ static void whisper_process_logits(
         logits[vocab.token_transcribe] = -INFINITY;
         logits[vocab.token_prev]       = -INFINITY;
 
+        // suppress lang tokens
+        for (size_t i = 0; i < g_lang.size(); ++i) {
+            logits[whisper_token_lang(&ctx, i)] = -INFINITY;
+        }
+
+        // suppress prev token
+        logits[vocab.token_prev] = -INFINITY;
+
         if (params.logits_filter_callback) {
             params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
         }
@@ -4056,6 +4718,30 @@ static void whisper_process_logits(
                     logits[i]   = -INFINITY;
                     logprobs[i] = -INFINITY;
                 }
+            } else {
+                if (params.n_grammar_rules > 0) {
+                    whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
+
+                    // populate the logprobs array (log_softmax)
+                    {
+                        const float logit_max = *std::max_element(logits.begin(), logits.end());
+                        float logsumexp = 0.0f;
+                        for (int i = 0; i < n_logits; ++i) {
+                            if (logits[i] > -INFINITY) {
+                                logsumexp += expf(logits[i] - logit_max);
+                            }
+                        }
+                        logsumexp = logf(logsumexp) + logit_max;
+
+                        for (int i = 0; i < n_logits; ++i) {
+                            if (logits[i] > -INFINITY) {
+                                logprobs[i] = logits[i] - logsumexp;
+                            } else {
+                                logprobs[i] = -INFINITY;
+                            }
+                        }
+                    }
+                }
             }
         }
     }
@@ -4073,38 +4759,60 @@ static void whisper_process_logits(
 
 #if 0
     // print first 100 logits - token string : logit
-    for (int i = 0; i < 100; i++) {
-        const auto token   = vocab.id_to_token.at(i);
-        const auto prob    = probs[i];
-        const auto logit   = logits[i];
-        const auto logprob = logprobs[i];
-        printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+    //for (int i = 0; i < 10; i++) {
+    //    const auto token   = vocab.id_to_token.at(i);
+    //    const auto prob    = probs[i];
+    //    const auto logit   = logits[i];
+    //    const auto logprob = logprobs[i];
+    //    printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+    //}
+
+    // print sorted
+    {
+        std::vector<std::pair<float, int>> pairs;
+
+        for (int i = 0; i < n_logits; ++i) {
+            pairs.push_back(std::make_pair(probs[i], i));
+        }
+
+        std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
+            return a.first > b.first;
+        });
+
+        for (int i = 0; i < 10; i++) {
+            const auto token   = vocab.id_to_token.at(pairs[i].second);
+            const auto prob    = pairs[i].first;
+            const auto logit   = logits[pairs[i].second];
+            const auto logprob = logprobs[pairs[i].second];
+            printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
+        }
+
+        printf("----------------\n");
     }
 
     // "And", "and", " And", " and"
-    printf("logits[\"and\"]  = %f\n", logits[vocab.token_to_id.at("and")]);
-    printf("logits[\"And\"]  = %f\n", logits[vocab.token_to_id.at("And")]);
-    printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
-    printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
-    printf("logits[\" so\"]  = %f\n", logits[vocab.token_to_id.at(" so")]);
-
-    printf("logprobs[\"and\"]  = %f\n", logprobs[vocab.token_to_id.at("and")]);
-    printf("logprobs[\"And\"]  = %f\n", logprobs[vocab.token_to_id.at("And")]);
-    printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
-    printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
-    printf("logprobs[\" so\"]  = %f\n", logprobs[vocab.token_to_id.at(" so")]);
-
-    printf("probs[\"and\"]  = %f\n", probs[vocab.token_to_id.at("and")]);
-    printf("probs[\"And\"]  = %f\n", probs[vocab.token_to_id.at("And")]);
-    printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
-    printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
-    printf("probs[\" so\"]  = %f\n", probs[vocab.token_to_id.at(" so")]);
+    //printf("logits[\"and\"]  = %f\n", logits[vocab.token_to_id.at("and")]);
+    //printf("logits[\"And\"]  = %f\n", logits[vocab.token_to_id.at("And")]);
+    //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
+    //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
+    //printf("logits[\" so\"]  = %f\n", logits[vocab.token_to_id.at(" so")]);
+
+    //printf("logprobs[\"and\"]  = %f\n", logprobs[vocab.token_to_id.at("and")]);
+    //printf("logprobs[\"And\"]  = %f\n", logprobs[vocab.token_to_id.at("And")]);
+    //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
+    //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
+    //printf("logprobs[\" so\"]  = %f\n", logprobs[vocab.token_to_id.at(" so")]);
+
+    //printf("probs[\"and\"]  = %f\n", probs[vocab.token_to_id.at("and")]);
+    //printf("probs[\"And\"]  = %f\n", probs[vocab.token_to_id.at("And")]);
+    //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
+    //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
+    //printf("probs[\" so\"]  = %f\n", probs[vocab.token_to_id.at(" so")]);
 #endif
 }
 
 static whisper_token_data whisper_sample_token(
             whisper_context & ctx,
-              whisper_state & state,
       const whisper_decoder & decoder,
                        bool   best) {
     whisper_token_data result = {
@@ -4149,7 +4857,7 @@ static whisper_token_data whisper_sample_token(
     } else {
         std::discrete_distribution<> dist(probs.begin(), probs.end());
 
-        result.id   = dist(state.rng);
+        result.id   = dist(decoder.rng);
         result.p    = probs[result.id];
         result.plog = logprobs[result.id];
     }
@@ -4159,15 +4867,12 @@ static whisper_token_data whisper_sample_token(
         result.pt  = result.p;
     }
 
-    state.n_sample++;
-
     return result;
 }
 
 static std::vector<whisper_token_data> whisper_sample_token_topk(
             whisper_context & ctx,
-              whisper_state & state,
-      const whisper_decoder & decoder,
+            whisper_decoder & decoder,
                         int   k) {
     const auto & vocab = ctx.vocab;
 
@@ -4177,7 +4882,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
 
     const int n_logits = vocab.n_vocab;
 
-    auto & logits_id = state.logits_id;
+    auto & logits_id = decoder.logits_id;
 
     logits_id.resize(n_logits);
     for (int i = 0; i < n_logits; ++i) {
@@ -4223,8 +4928,11 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         ptsum = sum_ts;
     }
 
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+
     for (int i = 0; i < k; ++i) {
-        const auto id = logits_id[i].second;
+        const auto id = dist(decoder.rng);
+        //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
 
         result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
 
@@ -4234,8 +4942,6 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         }
     }
 
-    state.n_sample++;
-
     return result;
 }
 
@@ -4288,125 +4994,6 @@ static void whisper_sequence_score(
     }
 }
 
-static bool whisper_kv_swap_fast(
-                   std::vector<int> & view,
-                    whisper_decoder   src[],
-                std::vector<kv_buf> & kv_swap_bufs,
-                          const int & n_decoders) {
-    WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
-
-    // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
-    std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
-
-    // (buffer->decoder or decoder->decoder)
-    std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
-
-    // (decoder<->decoder)
-    std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
-    std::vector<whisper_pair<int, int>> p_swap_vec;
-    p_swap_vec.reserve(n_decoders);
-
-    // see https://github.com/ggerganov/whisper.cpp/wiki
-    for (int i = 0; i < n_decoders; i++) {
-        // zero-copy (no modification)
-        if (i == view[i] || view[i] < 0) {
-            continue;
-        }
-
-        bool is_one_copy = true;
-        // since we modify data sequentially, we only consider decoder indices after current index
-        for (int j = i + 1; j < n_decoders; j++) {
-            if (i == view[j]) {
-                // detect symmetric diagram
-                if (j == view[i]) {
-                    p_swap_set.insert(i);
-                    p_swap_set.insert(j);
-                    p_swap_vec.emplace_back(i, j);
-                } else {
-                    two_copy.insert(i);
-                    is_one_copy = false;
-                }
-                break;
-            }
-        }
-        if (is_one_copy) {
-            one_copy.insert(i);
-        }
-    }
-
-    kv_swap_bufs.resize(n_decoders);
-
-    for (int i = 0; i < n_decoders; i++) {
-        kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k));
-        kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v));
-    }
-
-    for (auto & i : two_copy) {
-        // make a copy of KV caches
-        WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
-        //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
-        //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
-        ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
-        ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
-    }
-
-    // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
-    for (auto & i : two_copy) {
-        // skip the decoder indices that require pointer swapping
-        if (p_swap_set.find(i) != p_swap_set.end()) {
-            continue;
-        }
-
-        if (two_copy.find(view[i]) != two_copy.end()) {
-            // modify KV caches of decoder using data from kv_swap_bufs
-            WHISPER_PRINT_DEBUG("%s: two-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
-            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
-            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
-        } else {
-            // modify KV caches of decoder using data from correspond decoder KV caches directly
-            WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers:      %d  -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
-            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
-            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
-        }
-    }
-
-    // then modify one-copy decoder KV caches
-    for (auto & i : one_copy) {
-        // skip the decoder indices that require pointer swapping
-        if (p_swap_set.find(i) != p_swap_set.end()) {
-            continue;
-        }
-
-        if (two_copy.find(view[i]) != two_copy.end()) {
-            // modify KV caches of decoder using data from kv_swap_bufs
-            WHISPER_PRINT_DEBUG("%s: one-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
-            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
-            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
-        } else {
-            // modify KV caches of decoder using data from correspond decoder KV caches directly
-            WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers:      %d  -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
-            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
-            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
-        }
-    }
-
-    // swap the pointers
-    for (auto & i : p_swap_vec) {
-        WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
-        std::swap(src[i.first].kv_self, src[i.second].kv_self);
-    }
-
-    return true;
-}
-
 int whisper_full_with_state(
         struct whisper_context * ctx,
           struct whisper_state * state,
@@ -4496,25 +5083,23 @@ int whisper_full_with_state(
 
     n_decoders = std::max(1, n_decoders);
 
+    if (n_decoders > WHISPER_MAX_DECODERS) {
+        WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS);
+        return -4;
+    }
+
     // TAGS: WHISPER_DECODER_INIT
     for (int j = 1; j < n_decoders; j++) {
         auto & decoder = state->decoders[j];
 
-        if (decoder.kv_self.ctx == nullptr) {
-            decoder.kv_self = state->decoders[0].kv_self;
-            if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) {
-                WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
-                return -4;
-            }
-
-            WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
+        decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
 
-            decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
+        decoder.probs.resize   (ctx->vocab.n_vocab);
+        decoder.logits.resize  (ctx->vocab.n_vocab);
+        decoder.logprobs.resize(ctx->vocab.n_vocab);
+        decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
 
-            decoder.probs.resize   (ctx->vocab.n_vocab);
-            decoder.logits.resize  (ctx->vocab.n_vocab);
-            decoder.logprobs.resize(ctx->vocab.n_vocab);
-        }
+        decoder.rng = std::mt19937(0);
     }
 
     // the accumulated text context so far
@@ -4553,7 +5138,7 @@ int whisper_full_with_state(
     state->exp_n_audio_ctx = params.audio_ctx;
 
     // these tokens determine the task that will be performed
-    std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
+    std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
 
     if (whisper_is_multilingual(ctx)) {
         const int lang_id = whisper_lang_id(params.language);
@@ -4566,17 +5151,19 @@ int whisper_full_with_state(
         }
     }
 
+    // distilled models require the "no_timestamps" token
     {
         const bool is_distil = ctx->model.hparams.n_text_layer == 2;
-
-        // distilled models require the "no_timestamps" token
-        // TODO: add input parameter (#1229)
-        if (is_distil) {
+        if (is_distil && !params.no_timestamps) {
             WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
-            prompt_init.push_back(whisper_token_not(ctx));
+            params.no_timestamps = true;
         }
     }
 
+    if (params.no_timestamps) {
+        prompt_init.push_back(whisper_token_not(ctx));
+    }
+
     int seek = seek_start;
 
     std::vector<whisper_token> prompt;
@@ -4589,8 +5176,10 @@ int whisper_full_with_state(
         bool has_ts;
 
         whisper_sequence sequence;
+        whisper_grammar grammar;
     };
 
+    std::vector<std::vector<beam_candidate>> bc_per_dec(n_decoders);
     std::vector<beam_candidate> beam_candidates;
 
     // main loop
@@ -4652,14 +5241,12 @@ int whisper_full_with_state(
 
             n_decoders_cur = std::max(1, n_decoders_cur);
 
-            WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
+            WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
 
             // TAGS: WHISPER_DECODER_INIT
             for (int j = 0; j < n_decoders_cur; ++j) {
                 auto & decoder = state->decoders[j];
 
-                decoder.kv_self.n = 0;
-
                 decoder.sequence.tokens.clear();
                 decoder.sequence.result_len       = 0;
                 decoder.sequence.sum_logprobs_all = 0.0;
@@ -4673,10 +5260,16 @@ int whisper_full_with_state(
                 decoder.failed    = false;
                 decoder.completed = false;
                 decoder.has_ts    = false;
+
+                if (params.grammar_rules != nullptr) {
+                    decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
+                } else {
+                    decoder.grammar = {};
+                }
             }
 
             // init prompt and kv cache for the current iteration
-            // run whisper_decoder() only for decoder 0 and copy the results for the other decoders
+            // TODO: do not recompute the prompt if it is the same as previous time
             {
                 prompt.clear();
 
@@ -4698,7 +5291,11 @@ int whisper_full_with_state(
                 }
                 WHISPER_PRINT_DEBUG("\n\n");
 
-                if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                whisper_kv_cache_clear(state->kv_self);
+
+                whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
+
+                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                     WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                     return -7;
                 }
@@ -4706,20 +5303,14 @@ int whisper_full_with_state(
                 {
                     const int64_t t_start_sample_us = ggml_time_us();
 
-                    whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
+                    state->decoders[0].i_batch = prompt.size() - 1;
 
-                    state->decoders[0].kv_self.n += prompt.size();
+                    whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
 
                     for (int j = 1; j < n_decoders_cur; ++j) {
                         auto & decoder = state->decoders[j];
 
-                        // TODO: fix CUDA
-                        //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
-                        //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
-                        ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
-                        ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);
-
-                        decoder.kv_self.n += prompt.size();
+                        whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1);
 
                         memcpy(decoder.probs.data(),    state->decoders[0].probs.data(),    decoder.probs.size()*sizeof(decoder.probs[0]));
                         memcpy(decoder.logits.data(),   state->decoders[0].logits.data(),   decoder.logits.size()*sizeof(decoder.logits[0]));
@@ -4734,41 +5325,81 @@ int whisper_full_with_state(
                 const int64_t t_start_sample_us = ggml_time_us();
 
                 if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
-                    beam_candidates.clear();
+                    for (auto & bc : bc_per_dec) {
+                        bc.clear();
+                    }
                 }
 
-                // generate new sequence candidates for each decoder
-                for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = state->decoders[j];
+                // sampling
+                // TODO: avoid memory allocations, optimize, avoid threads?
+                {
+                    std::atomic<int> j_cur(0);
 
-                    if (decoder.completed || decoder.failed) {
-                        continue;
-                    }
+                    auto process = [&]() {
+                        while (true) {
+                            const int j = j_cur.fetch_add(1);
 
-                    switch (params.strategy) {
-                        case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
-                            {
-                                if (t_cur < 1e-6f) {
-                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
-                                } else {
-                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
-                                }
+                            if (j >= n_decoders_cur) {
+                                break;
+                            }
 
-                                decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
-                            } break;
-                        case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
-                            {
-                                const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
+                            auto & decoder = state->decoders[j];
 
-                                for (const auto & token : tokens_new) {
-                                    beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
-                                    beam_candidates.back().sequence.tokens.push_back(token);
-                                    beam_candidates.back().sequence.sum_logprobs_all += token.plog;
+                            if (decoder.completed || decoder.failed) {
+                                continue;
+                            }
 
-                                    //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
-                                }
-                            } break;
+                            switch (params.strategy) {
+                                case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+                                    {
+                                        if (t_cur < 1e-6f) {
+                                            decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
+                                        } else {
+                                            decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
+                                        }
+
+                                        decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
+                                    } break;
+                                case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+                                    {
+                                        const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+
+                                        for (const auto & token : tokens_new) {
+                                            bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
+                                            bc_per_dec[j].back().sequence.tokens.push_back(token);
+                                            bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
+                                        }
+                                    } break;
+                            };
+                        }
                     };
+
+                    const int n_threads = std::min(params.n_threads, n_decoders_cur);
+
+                    if (n_threads == 1) {
+                        process();
+                    } else {
+                        std::vector<std::thread> threads(n_threads - 1);
+
+                        for (int t = 0; t < n_threads - 1; ++t) {
+                            threads[t] = std::thread(process);
+                        }
+
+                        process();
+
+                        for (int t = 0; t < n_threads - 1; ++t) {
+                            threads[t].join();
+                        }
+                    }
+                }
+
+                beam_candidates.clear();
+                for (const auto & bc : bc_per_dec) {
+                    beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
+
+                    if (!bc.empty()) {
+                        state->n_sample += 1;
+                    }
                 }
 
                 // for beam-search, choose the top candidates and update the KV caches
@@ -4781,7 +5412,6 @@ int whisper_full_with_state(
                     });
 
                     uint32_t cur_c = 0;
-                    std::vector<int> decoder_idx(n_decoders_cur, -1);
 
                     for (int j = 0; j < n_decoders_cur; ++j) {
                         auto & decoder = state->decoders[j];
@@ -4790,23 +5420,38 @@ int whisper_full_with_state(
                             continue;
                         }
 
+                        if (cur_c >= beam_candidates.size()) {
+                            cur_c = 0;
+                        }
+
                         auto & cur = beam_candidates[cur_c++];
 
                         while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
                             ++cur_c;
                         }
 
-                        decoder.sequence   = cur.sequence;
                         decoder.seek_delta = cur.seek_delta;
                         decoder.has_ts     = cur.has_ts;
+                        decoder.sequence   = cur.sequence;
+                        decoder.grammar    = cur.grammar;
+
+                        whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
 
-                        decoder_idx[j] = cur.decoder_idx;
                         WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
                                 __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
                     }
 
-                    // update KV caches
-                    whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
+                    for (int j = 0; j < n_decoders_cur; ++j) {
+                        auto & decoder = state->decoders[j];
+
+                        if (decoder.completed || decoder.failed) {
+                            continue;
+                        }
+
+                        whisper_kv_cache_seq_rm(state->kv_self, j,                           -1, -1);
+                        whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1);
+                        whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j,    -1, -1);
+                    }
                 }
 
                 // update the decoder state
@@ -4844,6 +5489,8 @@ int whisper_full_with_state(
                             has_ts = true;
                         }
 
+                        whisper_grammar_accept_token(*ctx, decoder.grammar, token.id);
+
 #ifdef WHISPER_DEBUG
                         {
                             const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
@@ -4913,32 +5560,83 @@ int whisper_full_with_state(
                 state->t_sample_us += ggml_time_us() - t_start_sample_us;
 
                 // obtain logits for the next token
-                for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = state->decoders[j];
+                {
+                    auto & batch = state->batch;
 
-                    if (decoder.failed || decoder.completed) {
-                        continue;
-                    }
+                    batch.n_tokens = 0;
 
-                    decoder.tokens_tmp.resize(1);
-                    decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
+                    const int n_past = prompt.size() + i;
 
-                    //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
+                    for (int j = 0; j < n_decoders_cur; ++j) {
+                        auto & decoder = state->decoders[j];
 
-                    if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                        if (decoder.failed || decoder.completed) {
+                            continue;
+                        }
+
+                        //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
+
+                        decoder.i_batch = batch.n_tokens;
+
+                        batch.token   [batch.n_tokens]    = decoder.sequence.tokens.back().id;
+                        batch.pos     [batch.n_tokens]    = n_past;
+                        batch.n_seq_id[batch.n_tokens]    = 1;
+                        batch.seq_id  [batch.n_tokens][0] = j;
+                        batch.logits  [batch.n_tokens]    = 1;
+                        batch.n_tokens++;
+                    }
+
+                    assert(batch.n_tokens > 0);
+
+                    if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                         WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                         return -8;
                     }
 
+                    const int64_t t_start_sample_us = ggml_time_us();
+
+                    // TODO: avoid memory allocations, optimize, avoid threads?
                     {
-                        const int64_t t_start_sample_us = ggml_time_us();
+                        std::atomic<int> j_cur(0);
+
+                        auto process = [&]() {
+                            while (true) {
+                                const int j = j_cur.fetch_add(1);
 
-                        whisper_process_logits(*ctx, *state, params, decoder, t_cur);
+                                if (j >= n_decoders_cur) {
+                                    break;
+                                }
 
-                        ++decoder.kv_self.n;
+                                auto & decoder = state->decoders[j];
+
+                                if (decoder.failed || decoder.completed) {
+                                    continue;
+                                }
+
+                                whisper_process_logits(*ctx, *state, decoder, params, t_cur);
+                            }
+                        };
 
-                        state->t_sample_us += ggml_time_us() - t_start_sample_us;
+                        const int n_threads = std::min(params.n_threads, n_decoders_cur);
+
+                        if (n_threads == 1) {
+                            process();
+                        } else {
+                            std::vector<std::thread> threads(n_threads - 1);
+
+                            for (int t = 0; t < n_threads - 1; ++t) {
+                                threads[t] = std::thread(process);
+                            }
+
+                            process();
+
+                            for (int t = 0; t < n_threads - 1; ++t) {
+                                threads[t].join();
+                            }
+                        }
                     }
+
+                    state->t_sample_us += ggml_time_us() - t_start_sample_us;
                 }
             }
 
@@ -5235,11 +5933,13 @@ int whisper_full_parallel(
         ctx->state->t_sample_us += states[i]->t_sample_us;
         ctx->state->t_encode_us += states[i]->t_encode_us;
         ctx->state->t_decode_us += states[i]->t_decode_us;
+        ctx->state->t_batchd_us += states[i]->t_batchd_us;
         ctx->state->t_prompt_us += states[i]->t_prompt_us;
 
         ctx->state->n_sample += states[i]->n_sample;
         ctx->state->n_encode += states[i]->n_encode;
         ctx->state->n_decode += states[i]->n_decode;
+        ctx->state->n_batchd += states[i]->n_batchd;
         ctx->state->n_prompt += states[i]->n_prompt;
 
         whisper_free_state(states[i]);
@@ -5372,8 +6072,8 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
     size_t n    = 20;
     size_t arr  = n_threads > 0 ? 1024llu : n_threads; // trick to avoid compiler optimizations
 
-    // 1GB MB array
-    const size_t size = arr*1024llu*1024llu;
+    // 1GB array
+    const size_t size = arr*1e6;
 
     // single-thread
     {
@@ -5399,7 +6099,7 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
             src[rand() % size] = rand() % 256;
         }
 
-        snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
+        snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s (1 thread)\n", (double) (n*size)/(tsum*1e9));
         s += strbuf;
 
         // needed to prevent the compiler from optimizing the memcpy away
index 0ea5237e5f2f7c0d51a2432439f8d425e23b78e6..84540989d2ed339fb1ef4e834fa222af44e49e7b 100644 (file)
@@ -78,7 +78,9 @@ extern "C" {
     struct whisper_state;
     struct whisper_full_params;
 
-    typedef int whisper_token;
+    typedef int32_t whisper_pos;
+    typedef int32_t whisper_token;
+    typedef int32_t whisper_seq_id;
 
     struct whisper_context_params {
         bool  use_gpu;
@@ -109,6 +111,37 @@ extern "C" {
         void  (*close)(void * ctx);
     } whisper_model_loader;
 
+    // grammar element type
+    enum whisper_gretype {
+        // end of rule definition
+        WHISPER_GRETYPE_END            = 0,
+
+        // start of alternate definition for rule
+        WHISPER_GRETYPE_ALT            = 1,
+
+        // non-terminal element: reference to rule
+        WHISPER_GRETYPE_RULE_REF       = 2,
+
+        // terminal element: character (code point)
+        WHISPER_GRETYPE_CHAR           = 3,
+
+        // inverse char(s) ([^a], [^a-b] [^abc])
+        WHISPER_GRETYPE_CHAR_NOT       = 4,
+
+        // modifies a preceding WHISPER_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
+        // be an inclusive range ([a-z])
+        WHISPER_GRETYPE_CHAR_RNG_UPPER = 5,
+
+        // modifies a preceding WHISPER_GRETYPE_CHAR or
+        // WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
+        WHISPER_GRETYPE_CHAR_ALT       = 6,
+    };
+
+    typedef struct whisper_grammar_element {
+        enum whisper_gretype type;
+        uint32_t             value; // Unicode code point or rule ID
+    } whisper_grammar_element;
+
     // Various functions for loading a ggml whisper model.
     // Allocate (almost) all memory needed for the model.
     // Return NULL on failure
@@ -402,6 +435,7 @@ extern "C" {
 
         bool translate;
         bool no_context;        // do not use past transcription (if any) as initial prompt for the decoder
+        bool no_timestamps;     // do not generate timestamps
         bool single_segment;    // force single segment output (useful for streaming)
         bool print_special;     // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
         bool print_progress;    // print progress information
@@ -479,6 +513,11 @@ extern "C" {
         // called by each decoder to filter obtained logits
         whisper_logits_filter_callback logits_filter_callback;
         void * logits_filter_callback_user_data;
+
+        const whisper_grammar_element ** grammar_rules;
+        size_t                           n_grammar_rules;
+        size_t                           i_start_rule;
+        float                            grammar_penalty;
     };
 
     // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
index d65a58e548c183e00816fdb23585e2ac8a748e9f..d250000ded2c44ae4b28b3767f1becd850a60a8c 100644 (file)
@@ -26,6 +26,15 @@ if (NOT UNAME_M)
 endif()
 #message(STATUS "UNAME_S: ${UNAME_S}  UNAME_P: ${UNAME_P}  UNAME_M: ${UNAME_M}")
 
+# this version of Apple ld64 is buggy
+execute_process(
+    COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v
+    ERROR_VARIABLE output
+)
+if (output MATCHES "dyld-1015\.7")
+    add_compile_definitions(HAVE_BUGGY_APPLE_LINKER)
+endif()
+
 # Mac OS + Arm can report x86_64
 # ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
 if (UNAME_S MATCHES "Darwin")
index 3d22b0b27e444f2a1995a7e729ee7171faa189a3..4fe9cc487d29f6eb2cac0c4fb0a172d6b6d28136 100644 (file)
@@ -346,9 +346,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     }
 
     GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
-    GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
     if (ctx->device.maxTransferRate != 0) {
-        GGML_METAL_LOG_INFO("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+        GGML_METAL_LOG_INFO("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
     } else {
         GGML_METAL_LOG_INFO("%s: maxTransferRate               = built-in GPU\n", __func__);
     }
@@ -541,11 +541,11 @@ bool ggml_metal_add_buffer(
             ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
             if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+                GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1e6);
                 return false;
             }
 
-            GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+            GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1e6);
 
             ++ctx->n_buffers;
         } else {
@@ -565,11 +565,11 @@ bool ggml_metal_add_buffer(
                 ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
                 if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                    GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+                    GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1e6);
                     return false;
                 }
 
-                GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+                GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1e6, i);
                 if (i + size_step < size) {
                     GGML_METAL_LOG_INFO("\n");
                 }
@@ -580,8 +580,8 @@ bool ggml_metal_add_buffer(
 
 #if TARGET_OS_OSX
         GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
-                ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
-                ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+                ctx->device.currentAllocatedSize / 1e6,
+                ctx->device.recommendedMaxWorkingSetSize / 1e6);
 
         if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
             GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
@@ -589,7 +589,7 @@ bool ggml_metal_add_buffer(
             GGML_METAL_LOG_INFO("\n");
         }
 #else
-        GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
+        GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1e6);
 #endif
     }