]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : whisper.cpp (whisper full GPU, fix warnings) (#606)
authorGeorgi Gerganov <redacted>
Sun, 12 Nov 2023 14:35:03 +0000 (16:35 +0200)
committerGitHub <redacted>
Sun, 12 Nov 2023 14:35:03 +0000 (16:35 +0200)
* sync : whisper.cpp (whisper full GPU, fix warnings)

ggml-ci

* ci : enable CUDA / Metal

ggml-ci

* cuda : fallback to CPU for mul mat ne03 != ne13 (fix SAM + CUDA)

ggml-ci

ci/run.sh
examples/common.h
examples/whisper/CMakeLists.txt
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h
src/ggml-cuda.cu
src/ggml-metal.h
src/ggml-metal.m
tests/test-conv1d.cpp
tests/test-conv2d.cpp

index 7afe8304fe402170a694f02c5de229d62ca8b1bd..a593090f37b80e6ddeac6739b76c276c61829d0c 100644 (file)
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -30,6 +30,16 @@ sd=`dirname $0`
 cd $sd/../
 SRC=`pwd`
 
+CMAKE_EXTRA=""
+
+if [ ! -z ${GG_BUILD_CUDA} ]; then
+    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUBLAS=ON"
+fi
+
+if [ ! -z ${GG_BUILD_METAL} ]; then
+    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON"
+fi
+
 ## helpers
 
 # download a file if it does not exist or if it is outdated
@@ -81,8 +91,8 @@ function gg_run_ctest_debug {
 
     set -e
 
-    (time cmake -DCMAKE_BUILD_TYPE=Debug ..     ) 2>&1 | tee -a $OUT/${ci}-cmake.log
-    (time make -j                               ) 2>&1 | tee -a $OUT/${ci}-make.log
+    (time cmake -DCMAKE_BUILD_TYPE=Debug ${CMAKE_EXTRA} ..     ) 2>&1 | tee -a $OUT/${ci}-cmake.log
+    (time make -j                                              ) 2>&1 | tee -a $OUT/${ci}-make.log
 
     (time ctest --output-on-failure -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log
 
@@ -109,8 +119,8 @@ function gg_run_ctest_release {
 
     set -e
 
-    (time cmake -DCMAKE_BUILD_TYPE=Release ..   ) 2>&1 | tee -a $OUT/${ci}-cmake.log
-    (time make -j                               ) 2>&1 | tee -a $OUT/${ci}-make.log
+    (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} ..   ) 2>&1 | tee -a $OUT/${ci}-cmake.log
+    (time make -j                                              ) 2>&1 | tee -a $OUT/${ci}-make.log
 
     if [ -z $GG_BUILD_LOW_PERF ]; then
         (time ctest --output-on-failure ) 2>&1 | tee -a $OUT/${ci}-ctest.log
index 9a94bab7a7099c569f6f559ca1e6a15547ce4f3e..54f0b00d0ef41fd4e65c969663536c85bcb7680c 100644 (file)
@@ -181,7 +181,7 @@ private:
     // It is assumed that PCM data is normalized to a range from -1 to 1
     bool write_audio(const float * data, size_t length) {
         for (size_t i = 0; i < length; ++i) {
-            const auto intSample = static_cast<const int16_t>(data[i] * 32767);
+            const int16_t intSample = data[i] * 32767;
             file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
             dataSize += sizeof(int16_t);
         }
index a58251d781addf41a8dc887f0259b0f187bc3652..63f8cd40891d66ae94bcd6bb7ead2afe61f00b01 100644 (file)
@@ -13,6 +13,7 @@ set(TEST_TARGET whisper)
 add_executable(${TEST_TARGET} main.cpp)
 target_link_libraries(${TEST_TARGET} PRIVATE whisper-cpp common)
 target_include_directories(${TEST_TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)
+target_include_directories(${TEST_TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include/ggml)
 
 #
 # whisper-quantize
index bed0789f9ad06ad4b4cd04d2b8bdff586abb1c98..e43dfe3f948f916836fb6022f4379355bd5d2bcf 100644 (file)
@@ -90,6 +90,7 @@ struct whisper_params {
     bool print_progress  = false;
     bool no_timestamps   = false;
     bool log_score       = false;
+    bool use_gpu         = true;
 
     std::string language  = "en";
     std::string prompt;
@@ -165,6 +166,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-f"    || arg == "--file")            { params.fname_inp.emplace_back(argv[++i]); }
         else if (arg == "-oved" || arg == "--ov-e-device")     { params.openvino_encode_device = argv[++i]; }
         else if (arg == "-ls"   || arg == "--log-score")       { params.log_score = true; }
+        else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu = false; }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
@@ -221,6 +223,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -f FNAME,  --file FNAME        [%-7s] input WAV file path\n",                            "");
     fprintf(stderr, "  -oved D,   --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n",  params.openvino_encode_device.c_str());
     fprintf(stderr, "  -ls,       --log-score         [%-7s] log best decoder scores of tokens\n",              params.log_score?"true":"false");
+    fprintf(stderr, "  -ng,       --no-gpu            [%-7s] disable GPU\n",                                    params.use_gpu ? "false" : "true");
     fprintf(stderr, "\n");
 }
 
@@ -877,7 +880,10 @@ int main(int argc, char ** argv) {
 
     // whisper init
 
-    struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
+    struct whisper_context_params cparams;
+    cparams.use_gpu = params.use_gpu;
+
+    struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 
     if (ctx == nullptr) {
         fprintf(stderr, "error: failed to initialize whisper context\n");
index 17ef4d9e8abbeb1a7b4eb8e5dc2ced88463aeaf6..1c7d7e9440dcb82cdedf3172aec31ee01c74e46e 100644 (file)
@@ -1,10 +1,15 @@
 #include "whisper.h"
+
 #ifdef WHISPER_USE_COREML
 #include "coreml/whisper-encoder.h"
 #endif
 
 #ifdef GGML_USE_METAL
-#  include "ggml-metal.h"
+#include "ggml-metal.h"
+#endif
+
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
 #endif
 
 #ifdef WHISPER_USE_OPENVINO
@@ -13,6 +18,7 @@
 
 #include "ggml.h"
 #include "ggml-alloc.h"
+#include "ggml-backend.h"
 
 #include <algorithm>
 #include <cassert>
@@ -97,10 +103,32 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 #define BYTESWAP_TENSOR(t) do {} while (0)
 #endif
 
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define WHISPER_ATTRIBUTE_FORMAT(...)
+#endif
+
+//
+// logging
+//
+
+WHISPER_ATTRIBUTE_FORMAT(2, 3)
+static void whisper_log_internal        (ggml_log_level level, const char * format, ...);
+static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data);
+
+#define WHISPER_LOG_INFO(...)  whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
+#define WHISPER_LOG_WARN(...)  whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
+#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
 #define WHISPER_ASSERT(x) \
     do { \
         if (!(x)) { \
-            log("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
             abort(); \
         } \
     } while (0)
@@ -127,8 +155,8 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 //
 
 static void ggml_graph_compute_helper(
+          struct ggml_cgraph * graph,
         std::vector<uint8_t> & buf,
-                 ggml_cgraph * graph,
                          int   n_threads,
       whisper_abort_callback   abort_callback,
                         void * abort_callback_data) {
@@ -145,6 +173,21 @@ static void ggml_graph_compute_helper(
     ggml_graph_compute(graph, &plan);
 }
 
+static void ggml_graph_compute_helper(
+       struct ggml_backend * backend,
+        struct ggml_cgraph * graph,
+                       int   n_threads) {
+    if (ggml_backend_is_cpu(backend)) {
+        ggml_backend_cpu_set_n_threads(backend, n_threads);
+    }
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(backend)) {
+        ggml_backend_metal_set_n_cb(backend, n_threads);
+    }
+#endif
+    ggml_backend_graph_compute(backend, graph);
+}
+
 // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
 // the idea is to represent the original matrix multiplication:
 //
@@ -179,6 +222,7 @@ static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct g
 }
 
 // TODO: check if other platforms can benefit from this optimization
+// TODO: CUDA is currently broken - seems ggml_mul_mat does not handle views correctly
 #if defined(GGML_USE_METAL)
 #define ggml_mul_mat ggml_mul_mat_pad
 #endif
@@ -193,6 +237,15 @@ enum e_model {
     MODEL_LARGE,
 };
 
+static const std::map<e_model, std::string> g_model_name = {
+    { MODEL_UNKNOWN,  "unknown"  },
+    { MODEL_TINY,     "tiny"     },
+    { MODEL_BASE,     "base"     },
+    { MODEL_SMALL,    "small"    },
+    { MODEL_MEDIUM,   "medium"   },
+    { MODEL_LARGE,    "large"    },
+};
+
 static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "en",  { 0,  "english",         } },
     { "zh",  { 1,  "chinese",         } },
@@ -293,75 +346,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "ba",  { 96,  "bashkir",        } },
     { "jw",  { 97,  "javanese",       } },
     { "su",  { 98,  "sundanese",      } },
-};
-
-static const size_t MB = 1ull*1024*1024;
-
-// TODO: avoid using GGUF
-static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
-    { GGML_TYPE_F32,
-        {
-            { MODEL_TINY,     74ull*MB },
-            { MODEL_BASE,    142ull*MB },
-            { MODEL_SMALL,   466ull*MB },
-            { MODEL_MEDIUM, 1464ull*MB },
-            { MODEL_LARGE,  2952ull*MB },
-        },
-    },
-    { GGML_TYPE_F16,
-        {
-            { MODEL_TINY,     74ull*MB },
-            { MODEL_BASE,    142ull*MB },
-            { MODEL_SMALL,   466ull*MB },
-            { MODEL_MEDIUM, 1464ull*MB },
-            { MODEL_LARGE,  2952ull*MB },
-        },
-    },
-    { GGML_TYPE_Q4_0,
-        {
-            { MODEL_TINY,     26ull*MB },
-            { MODEL_BASE,     50ull*MB },
-            { MODEL_SMALL,   154ull*MB },
-            { MODEL_MEDIUM,  470ull*MB },
-            { MODEL_LARGE,   940ull*MB },
-        },
-    },
-    { GGML_TYPE_Q4_1,
-        {
-            { MODEL_TINY,     32ull*MB },
-            { MODEL_BASE,     58ull*MB },
-            { MODEL_SMALL,   182ull*MB },
-            { MODEL_MEDIUM,  562ull*MB },
-            { MODEL_LARGE,  1124ull*MB },
-        },
-    },
-    { GGML_TYPE_Q5_0,
-        {
-            { MODEL_TINY,     30ull*MB },
-            { MODEL_BASE,     54ull*MB },
-            { MODEL_SMALL,   170ull*MB },
-            { MODEL_MEDIUM,  516ull*MB },
-            { MODEL_LARGE,  1034ull*MB },
-        },
-    },
-    { GGML_TYPE_Q5_1,
-        {
-            { MODEL_TINY,     32ull*MB },
-            { MODEL_BASE,     58ull*MB },
-            { MODEL_SMALL,   182ull*MB },
-            { MODEL_MEDIUM,  562ull*MB },
-            { MODEL_LARGE,  1124ull*MB },
-        },
-    },
-    { GGML_TYPE_Q8_0,
-        {
-            { MODEL_TINY,     45ull*MB },
-            { MODEL_BASE,     84ull*MB },
-            { MODEL_SMALL,   268ull*MB },
-            { MODEL_MEDIUM,  834ull*MB },
-            { MODEL_LARGE,  1674ull*MB },
-        },
-    },
+    { "yue", { 99,  "cantonese",      } },
 };
 
 struct whisper_mel {
@@ -402,7 +387,11 @@ struct whisper_vocab {
     id token_beg        = 50363; // begin timestamps
 
     bool is_multilingual() const {
-        return n_vocab == 51865;
+        return n_vocab >= 51865;
+    }
+
+    int num_languages() const {
+        return n_vocab - 51765 - (is_multilingual() ? 1 : 0);
     }
 };
 
@@ -540,8 +529,7 @@ struct whisper_kv_cache {
 
     struct ggml_context * ctx;
 
-    // buf points to the memory allocated for both ggml_tensor 'k' and 'v' (see kv_cache_init)
-    std::vector<uint8_t> buf;
+    ggml_backend_buffer_t buffer;
 
     int n; // number of tokens currently in the cache
 };
@@ -580,11 +568,11 @@ struct whisper_model {
     std::vector<whisper_layer_encoder> layers_encoder;
     std::vector<whisper_layer_decoder> layers_decoder;
 
-    // context
+    // ggml context that contains all the meta information about the model tensors
     struct ggml_context * ctx;
 
-    // the model memory buffer is read-only and can be shared between processors
-    std::vector<uint8_t> * buf;
+    // the model backend data is read-only and can be shared between processors
+    struct ggml_backend_buffer * buffer;
 
     // tensors
     int n_loaded;
@@ -649,37 +637,47 @@ struct whisper_allocr {
     ggml_allocr * alloc = nullptr;
 
     std::vector<uint8_t> meta;
-    std::vector<uint8_t> data;
+
+    ggml_backend_buffer_t buffer;
 };
 
 static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-    return allocr.meta.size() + allocr.data.size();
+    return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
 }
 
 // measure the memory usage of a graph and prepare the allocr's internal data buffer
-static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct ggml_cgraph *()> && get_graph) {
-    const int tensor_alignment = 32;
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
+    auto & alloc  = allocr.alloc;
+    auto & meta   = allocr.meta;
 
-    auto & alloc = allocr.alloc;
-    auto & meta  = allocr.meta;
-    auto & data  = allocr.data;
+    alloc = ggml_allocr_new_measure_from_backend(backend);
 
     meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
 
-    alloc = ggml_allocr_new_measure(tensor_alignment);
+    ggml_allocr_alloc_graph(alloc, get_graph());
+}
 
-    const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
+static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
+    if (allocr.alloc == nullptr) {
+        // this can be null if we use external encoder like CoreML or OpenVINO
+        return;
+    }
 
-    ggml_allocr_free(alloc);
+    auto & alloc  = allocr.alloc;
+    auto & buffer = allocr.buffer;
+
+    size_t size = ggml_allocr_max_size(alloc);
 
-    data.resize(alloc_size);
+    ggml_allocr_free(alloc);
 
-    alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+    buffer = ggml_backend_alloc_buffer(backend, size);
+    alloc = ggml_allocr_new_from_buffer(buffer);
 }
 
 static void whisper_allocr_free(struct whisper_allocr & allocr) {
     if (allocr.alloc) {
         ggml_allocr_free(allocr.alloc);
+        ggml_backend_buffer_free(allocr.buffer);
         allocr.alloc = nullptr;
     }
 }
@@ -708,8 +706,7 @@ struct whisper_state {
     // buffer for swapping KV caches between decoders during beam-search
     std::vector<kv_buf> kv_swap_bufs;
 
-    // reusable buffer for `struct ggml_graph_plan.work_data`
-    std::vector<uint8_t> work_buffer;
+    ggml_backend_t backend = nullptr;
 
     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
@@ -723,6 +720,9 @@ struct whisper_state {
     struct ggml_tensor * embd_conv = nullptr;
     struct ggml_tensor * embd_enc  = nullptr;
 
+    // helper for GPU offloading
+    std::vector<float> inp_mel;
+
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
 
@@ -736,23 +736,22 @@ struct whisper_state {
 
     int lang_id = 0; // english by default
 
-    std::string path_model; // populated by whisper_init_from_file()
+    std::string path_model; // populated by whisper_init_from_file_with_params()
+
 #ifdef WHISPER_USE_COREML
     whisper_coreml_context * ctx_coreml = nullptr;
 #endif
 
-#ifdef GGML_USE_METAL
-    ggml_metal_context * ctx_metal = nullptr;
-#endif
-
 #ifdef WHISPER_USE_OPENVINO
     whisper_openvino_context * ctx_openvino = nullptr;
 #endif
 
     // [EXPERIMENTAL] token-level timestamps data
-    int64_t t_beg = 0;
+    int64_t t_beg  = 0;
     int64_t t_last = 0;
+
     whisper_token tid_last;
+
     std::vector<float> energy; // PCM signal energy
 
     // [EXPERIMENTAL] speed-up techniques
@@ -766,34 +765,25 @@ struct whisper_context {
     ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
     ggml_type itype = ggml_type::GGML_TYPE_F16; // intermediate type (FP32 or FP16)
 
+    whisper_context_params params;
+
     whisper_model model;
     whisper_vocab vocab;
+
     whisper_state * state = nullptr;
 
-    std::string path_model; // populated by whisper_init_from_file()
-};
+    ggml_backend_t backend = nullptr;
 
-static void whisper_default_log(const char * text) {
-    fprintf(stderr, "%s", text);
-}
+    std::string path_model; // populated by whisper_init_from_file_with_params()
+};
 
-static whisper_log_callback whisper_log = whisper_default_log;
+struct whisper_global {
+    // We save the log callback globally
+    ggml_log_callback log_callback = whisper_log_callback_default;
+    void * log_callback_user_data = nullptr;
+};
 
-#ifdef __GNUC__
-#ifdef __MINGW32__
-__attribute__((gnu_format(printf, 1, 2)))
-#else
-__attribute__((format(printf, 1, 2)))
-#endif
-#endif
-static void log(const char * fmt, ...) {
-    if (!whisper_log) return;
-    char buf[1024];
-    va_list args;
-    va_start(args, fmt);
-    vsnprintf(buf, sizeof(buf), fmt, args);
-    whisper_log(buf);
-}
+static whisper_global g_state;
 
 template<typename T>
 static void read_safe(whisper_model_loader * loader, T & dest) {
@@ -804,6 +794,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
 static bool kv_cache_init(
         const struct whisper_hparams & hparams,
              struct whisper_kv_cache & cache,
+                      ggml_backend_t   backend,
                            ggml_type   wtype,
                                  int   n_ctx) {
     const int64_t n_text_state = hparams.n_text_state;
@@ -812,30 +803,41 @@ static bool kv_cache_init(
     const int64_t n_mem      = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;
 
-    const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead());
-
-    cache.buf.resize(mem_bytes);
-
     struct ggml_init_params params = {
-        /*.mem_size   =*/ cache.buf.size(),
-        /*.mem_buffer =*/ cache.buf.data(),
-        /*.no_alloc   =*/ false,
+        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
     };
 
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {
-        log("%s: failed to allocate memory for kv cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
+    const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v);
+
+    cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes);
+
+    // allocate the tensors into the backend buffer
+    {
+        ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer);
+
+        ggml_allocr_alloc(alloc, cache.k);
+        ggml_allocr_alloc(alloc, cache.v);
+
+        ggml_allocr_free(alloc);
+    }
+
     return true;
 }
 
-static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
+// TODO: remove after batched decoding
+static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) {
     WHISPER_ASSERT(cache.ctx);
 
     const int n_elements = ggml_nelements(cache.k);
@@ -844,34 +846,78 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
     const ggml_type wtype = cache.k->type;
     WHISPER_ASSERT(wtype == cache.v->type);
 
-    WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_sizef(wtype));
-
     struct ggml_init_params params = {
-        /*.mem_size   =*/ cache.buf.size(),
-        /*.mem_buffer =*/ cache.buf.data(),
-        /*.no_alloc   =*/ false,
+        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
     };
 
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {
-        log("%s: failed to allocate memory for kv cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
+    const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v);
+
+    cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes);
+
+    // allocate the tensors into the backend buffer
+    {
+        ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer);
+
+        ggml_allocr_alloc(alloc, cache.k);
+        ggml_allocr_alloc(alloc, cache.v);
+
+        ggml_allocr_free(alloc);
+    }
+
     return true;
 }
 
 static void kv_cache_free(struct whisper_kv_cache & cache) {
     if (cache.ctx) {
         ggml_free(cache.ctx);
+        ggml_backend_buffer_free(cache.buffer);
         cache.ctx = nullptr;
     }
 }
 
+static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
+    ggml_backend_t backend_gpu = NULL;
+
+    // initialize the backends
+#ifdef GGML_USE_CUBLAS
+    if (params.use_gpu) {
+        WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
+        backend_gpu = ggml_backend_cuda_init();
+        if (!backend_gpu) {
+            WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
+        }
+    }
+#endif
+
+#ifdef GGML_USE_METAL
+    if (params.use_gpu) {
+        WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
+        ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
+        backend_gpu = ggml_backend_metal_init();
+        if (!backend_gpu) {
+            WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
+        }
+    }
+#endif
+
+    if (backend_gpu) {
+        return backend_gpu;
+    }
+    return ggml_backend_cpu_init();
+}
+
 // load the model from a ggml file
 //
 // file format:
@@ -884,7 +930,7 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
 // see the convert-pt-to-ggml.py script for details
 //
 static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
-    log("%s: loading model\n", __func__);
+    WHISPER_LOG_INFO("%s: loading model\n", __func__);
 
     const int64_t t_start_us = ggml_time_us();
 
@@ -898,7 +944,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         uint32_t magic;
         read_safe(loader, magic);
         if (magic != GGML_FILE_MAGIC) {
-            log("%s: invalid model data (bad magic)\n", __func__);
+            WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
             return false;
         }
     }
@@ -921,6 +967,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         assert(hparams.n_text_state == hparams.n_audio_state);
 
+        std::string mver = "";
+
         if (hparams.n_audio_layer == 4) {
             model.type = e_model::MODEL_TINY;
         }
@@ -939,6 +987,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         if (hparams.n_audio_layer == 32) {
             model.type = e_model::MODEL_LARGE;
+
+            if (hparams.n_vocab == 51866) {
+                mver = " v3";
+            }
         }
 
         const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
@@ -949,41 +1001,23 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         // in order to save memory and also to speed up the computation
         wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
         if (wctx.wtype == GGML_TYPE_COUNT) {
-            log("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
+            WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
             return false;
         }
 
-        const size_t scale = model.hparams.ftype ? 1 : 2;
-
-        log("%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
-        log("%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
-        log("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
-        log("%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
-        log("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
-        log("%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
-        log("%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
-        log("%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
-        log("%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
-        log("%s: n_mels        = %d\n", __func__, hparams.n_mels);
-        log("%s: ftype         = %d\n", __func__, model.hparams.ftype);
-        log("%s: qntvr         = %d\n", __func__, qntvr);
-        log("%s: type          = %d\n", __func__, model.type);
-
-        // print memory requirements
-        {
-            // TODO
-            //log("%s: mem required  = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
-            //        mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
-        }
-
-        // initialize all memory buffers
-        // always have at least one decoder
-
-        wctx.model.buf = new std::vector<uint8_t>();
-        wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type));
-
-        // we skip initialization of the state until it is needed
-        // because it might be that state will always be provided externally.
+        WHISPER_LOG_INFO("%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
+        WHISPER_LOG_INFO("%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
+        WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
+        WHISPER_LOG_INFO("%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
+        WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
+        WHISPER_LOG_INFO("%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
+        WHISPER_LOG_INFO("%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
+        WHISPER_LOG_INFO("%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
+        WHISPER_LOG_INFO("%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
+        WHISPER_LOG_INFO("%s: n_mels        = %d\n", __func__, hparams.n_mels);
+        WHISPER_LOG_INFO("%s: ftype         = %d\n", __func__, model.hparams.ftype);
+        WHISPER_LOG_INFO("%s: qntvr         = %d\n", __func__, qntvr);
+        WHISPER_LOG_INFO("%s: type          = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
     }
 
     // load mel filters
@@ -1004,7 +1038,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         read_safe(loader, n_vocab);
 
         //if (n_vocab != model.hparams.n_vocab) {
-        //    log("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+        //    WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
         //            __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
         //    return false;
         //}
@@ -1024,7 +1058,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 word.assign(&tmp[0], tmp.size());
             } else {
                 // seems like we have an empty-string token in multi-language models (i = 50256)
-                //log("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
+                //WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
                 word = "";
             }
 
@@ -1038,17 +1072,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         if (vocab.is_multilingual()) {
             vocab.token_eot++;
             vocab.token_sot++;
-            vocab.token_translate++;
-            vocab.token_transcribe++;
-            vocab.token_solm++;
-            vocab.token_prev++;
-            vocab.token_nosp++;
-            vocab.token_not++;
-            vocab.token_beg++;
+
+            // account for variable number of language tokens
+            const int dt = vocab.num_languages() - 98;
+
+            vocab.token_translate  += dt;
+            vocab.token_transcribe += dt;
+            vocab.token_solm       += dt;
+            vocab.token_prev       += dt;
+            vocab.token_nosp       += dt;
+            vocab.token_not        += dt;
+            vocab.token_beg        += dt;
         }
 
         if (n_vocab < model.hparams.n_vocab) {
-            log("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
+            WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
             for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
                 if (i > vocab.token_beg) {
                     word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
@@ -1073,139 +1111,36 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 vocab.id_to_token[i] = word;
             }
         }
-    }
 
-    size_t ctx_size = 0;
+        WHISPER_LOG_INFO("%s: n_langs       = %d\n", __func__, vocab.num_languages());
+    }
 
     const ggml_type wtype = wctx.wtype;
     const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type
 
+    // create the ggml context
     {
         const auto & hparams = model.hparams;
 
-        const int n_vocab = hparams.n_vocab;
-
-        const int n_audio_ctx   = hparams.n_audio_ctx;
-        const int n_audio_state = hparams.n_audio_state;
         const int n_audio_layer = hparams.n_audio_layer;
+        const int n_text_layer  = hparams.n_text_layer;
 
-        const int n_text_ctx   = hparams.n_text_ctx;
-        const int n_text_state = hparams.n_text_state;
-        const int n_text_layer = hparams.n_text_layer;
-
-        const int n_mels = hparams.n_mels;
-
-        // encoder
-        {
-            ctx_size += n_audio_ctx*n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_pe;
-
-            ctx_size += 3*n_mels*n_audio_state*ggml_type_sizef(vtype);         // e_conv_1_w
-            ctx_size +=          n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_1_b
-
-            ctx_size += 3*n_audio_state*n_audio_state*ggml_type_sizef(vtype);         // e_conv_2_w
-            ctx_size +=                 n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_2_b
-
-            ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_w;
-            ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_b;
-        }
-
-        // decoder
-        {
-            ctx_size += n_text_ctx*n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_pe;
-
-            ctx_size += n_vocab*n_text_state*ggml_type_sizef(wtype); // d_te;
-
-            ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_w;
-            ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_b;
-        }
-
-        // encoder layers
-        {
-            ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
-            ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
-
-            ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype));         // mlp_0_w
-            ctx_size += n_audio_layer*(              4*n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
-
-            ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype));         // mlp_1_w
-            ctx_size += n_audio_layer*(                n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
-
-            ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
-            ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype));         // attn_q_w
-            ctx_size += n_audio_layer*(              n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_k_w
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype));         // attn_v_w
-            ctx_size += n_audio_layer*(              n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype));         // attn_ln_1_w
-            ctx_size += n_audio_layer*(              n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
-        }
-
-        // decoder layers
-        {
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
-
-            ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype));         // mlp_0_w
-            ctx_size += n_text_layer*(             4*n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
-
-            ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype));         // mlp_1_w
-            ctx_size += n_text_layer*(               n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
-
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype));         // attn_q_w
-            ctx_size += n_text_layer*(             n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_k_w
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype));         // attn_v_w
-            ctx_size += n_text_layer*(             n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype));         // attn_ln_1_w
-            ctx_size += n_text_layer*(             n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
-                                                                                                //
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_w
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype));         // cross_attn_q_w
-            ctx_size += n_text_layer*(             n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_q_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_k_w
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype));         // cross_attn_v_w
-            ctx_size += n_text_layer*(             n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_v_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype));         // cross_attn_ln_1_w
-            ctx_size += n_text_layer*(             n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_1_b
-        }
-
-        ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead
-
-        log("%s: model ctx     = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
-    }
+        const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
 
-    // create the ggml context
-    {
         struct ggml_init_params params = {
-            /*.mem_size   =*/ wctx.model.buf->size(),
-            /*.mem_buffer =*/ wctx.model.buf->data(),
-            /*.no_alloc   =*/ false,
+            /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
         };
 
         model.ctx = ggml_init(params);
         if (!model.ctx) {
-            log("%s: ggml_init() failed\n", __func__);
+            WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__);
             return false;
         }
     }
 
-    // prepare memory for the weights
+    // prepare tensors for the weights
     {
         auto & ctx = model.ctx;
 
@@ -1228,16 +1163,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
         // encoder
         {
-            model.e_pe       = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
+            model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
 
-            model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype,         3, n_mels, n_audio_state);
-            model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
+            model.e_conv_1_w     = ggml_new_tensor_3d(ctx, vtype,         3, n_mels,     n_audio_state);
+            model.e_conv_1_b     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_audio_ctx, n_audio_state);
 
-            model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype,         3, n_audio_state, n_audio_state);
-            model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
+            model.e_conv_2_w     = ggml_new_tensor_3d(ctx, vtype,         3, n_audio_state, n_audio_state);
+            model.e_conv_2_b     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,    n_audio_ctx,   n_audio_state);
 
-            model.e_ln_w     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-            model.e_ln_b     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+            model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+            model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
 
             // map by name
             model.tensors["encoder.positional_embedding"] = model.e_pe;
@@ -1401,12 +1336,37 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
+    wctx.backend = whisper_backend_init(wctx.params);
+
+    {
+        size_t size_main = 0;
+
+        for (const auto & t : model.tensors) {
+            size_main += ggml_nbytes(t.second) + ggml_tensor_overhead();
+        }
+
+        model.buffer = ggml_backend_alloc_buffer(wctx.backend, size_main);
+
+        WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1024.0 / 1024.0);
+    }
+
+    ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer);
+
+    // allocate tensors in the backend buffers
+    {
+        for (const auto & t : model.tensors) {
+            ggml_allocr_alloc(alloc, t.second);
+        }
+    }
+
     // load weights
     {
         size_t total_size = 0;
 
         model.n_loaded = 0;
 
+        std::vector<char> read_buf;
+
         while (true) {
             int32_t n_dims;
             int32_t length;
@@ -1433,50 +1393,92 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name) == model.tensors.end()) {
-                log("%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
                 return false;
             }
 
             auto tensor = model.tensors[name.data()];
-            if (ggml_nelements(tensor) != nelements) {
-                log("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                log("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
-                        __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
-                return false;
-            }
 
-            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
-                log("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
-                        __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
-                return false;
-            }
+            const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias");
 
-            const size_t bpe = ggml_type_size(ggml_type(ttype));
+            if (!is_conv_bias) {
+                if (ggml_nelements(tensor) != nelements) {
+                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                    WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
+                            __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
+                    return false;
+                }
 
-            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
-                log("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
-                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
-                return false;
+                if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
+                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                            __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
+                    return false;
+                }
+
+                const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+                if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                            __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                    return false;
+                }
             }
 
-            loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
-            BYTESWAP_TENSOR(tensor);
+            ggml_backend_t backend = wctx.backend;
+
+            //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
+
+            if ((ggml_backend_is_cpu(backend)
+#ifdef GGML_USE_METAL
+                || ggml_backend_is_metal(backend)
+#endif
+                ) && !is_conv_bias) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
+                BYTESWAP_TENSOR(tensor);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(ggml_nbytes(tensor));
+
+                // we repeat the 2 bias tensors along dim 0:
+                // [1, 512] -> [3000, 512] (conv1.bias)
+                // [1, 512] -> [1500, 512] (conv2.bias)
+                if (is_conv_bias) {
+                    loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]);
+
+                    float * data_f32 = (float *) read_buf.data();
+                    for (int64_t y = 0; y < tensor->ne[1]; ++y) {
+                        const int64_t yy = tensor->ne[1] - y - 1;
+                        const float val = data_f32[yy];
+
+                        for (int64_t x = 0; x < tensor->ne[0]; ++x) {
+                            data_f32[yy*tensor->ne[0] + x] = val;
+                        }
+                    }
+                } else {
+                    loader->read(loader->context, read_buf.data(), read_buf.size());
+                }
+
+                ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
+            }
 
             //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1024.0/1024.0);
             total_size += ggml_nbytes(tensor);
             model.n_loaded++;
         }
 
-        log("%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
+        WHISPER_LOG_INFO("%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
 
         if (model.n_loaded == 0) {
-            log("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+            WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
         } else if (model.n_loaded != (int) model.tensors.size()) {
-            log("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+            WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
             return false;
         }
     }
 
+    ggml_allocr_free(alloc);
+
     wctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
@@ -1532,10 +1534,12 @@ static struct ggml_cgraph * whisper_build_graph_conv(
     if (!ggml_allocr_is_measure(alloc)) {
         assert(mel_inp.n_mel == n_mels);
 
-        float * dst = (float *) mel->data;
+        wstate.inp_mel.resize(ggml_nelements(mel));
+
+        float * dst = wstate.inp_mel.data();
         memset(dst, 0, ggml_nbytes(mel));
 
-        const int i0 = std::min(mel_offset, mel_inp.n_len);
+        const int i0 = std::min(mel_offset,           mel_inp.n_len);
         const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
 
         for (int j = 0; j < mel_inp.n_mel; ++j) {
@@ -1543,6 +1547,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
                 dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
             }
         }
+
+        ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
     }
 
     struct ggml_tensor * cur = nullptr;
@@ -1551,24 +1557,27 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         // convolution + gelu
         {
             cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        model.e_conv_1_b,
-                        cur),
-                    cur);
+            cur = ggml_add(ctx0, cur, model.e_conv_1_b);
+            //cur = ggml_add(ctx0,
+            //        ggml_repeat(ctx0,
+            //            model.e_conv_1_b,
+            //            cur),
+            //        cur);
 
             cur = ggml_gelu(ctx0, cur);
 
             cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        model.e_conv_2_b,
-                        cur),
-                    cur);
+            cur = ggml_add(ctx0, cur, model.e_conv_2_b);
+            //cur = ggml_add(ctx0,
+            //        ggml_repeat(ctx0,
+            //            model.e_conv_2_b,
+            //            cur),
+            //        cur);
 
             cur = ggml_gelu(ctx0, cur);
         }
 
+        ggml_set_name(cur, "embd_conv");
         wstate.embd_conv = cur;
     } else {
 #ifdef WHISPER_USE_COREML
@@ -1576,7 +1585,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         ggml_allocr_alloc(alloc, cur);
 
         if (!ggml_allocr_is_measure(alloc)) {
-            whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
+            whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
         }
 #endif
 #ifdef WHISPER_USE_OPENVINO
@@ -1588,6 +1597,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         }
 #endif
 
+        ggml_set_name(cur, "embd_enc");
         wstate.embd_enc = cur;
     }
 
@@ -1621,15 +1631,22 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
     ggml_allocr * alloc = wstate.alloc_encode.alloc;
 
+    //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state);
+    //ggml_allocr_alloc(alloc, cur);
+
+    //if (!ggml_allocr_is_measure(alloc)) {
+    //    ggml_backend_tensor_copy(wstate.embd_conv, cur);
+    //}
+    struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
+
     struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(alloc, KQscale);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
+        const float val = 1.0f/sqrtf(float(n_state)/n_head);
+        ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
     }
 
-    struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
-
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
     //static int iter = -1;
@@ -1648,7 +1665,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
 
     struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
-
     cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
 
     // ===================================================================
@@ -1870,13 +1886,20 @@ static struct ggml_cgraph * whisper_build_graph_cross(
 
     ggml_allocr * alloc = wstate.alloc_cross.alloc;
 
+    //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+    //ggml_allocr_alloc(alloc, cur);
+
+    //if (!ggml_allocr_is_measure(alloc)) {
+    //    ggml_backend_tensor_copy(wstate.embd_enc, cur);
+    //}
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
 
     struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(alloc, Kscale);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25));
+        const float val = pow(float(n_state) / n_head, -0.25);
+        ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float));
     }
 
     for (int il = 0; il < model.hparams.n_text_layer; ++il) {
@@ -1947,7 +1970,7 @@ static bool whisper_encode_internal(
         ggml_allocr_alloc_graph(alloc, gf);
 
         if (!whisper_encode_external(wstate)) {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
+            ggml_graph_compute_helper(wstate.backend, gf, n_threads);
         }
     }
 
@@ -1961,16 +1984,7 @@ static bool whisper_encode_internal(
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-#ifdef GGML_USE_METAL
-        if (wstate.ctx_metal) {
-            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
-            ggml_metal_graph_compute(wstate.ctx_metal, gf);
-        } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-        }
-#else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-#endif
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     // cross
@@ -1983,24 +1997,13 @@ static bool whisper_encode_internal(
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-#ifdef GGML_USE_METAL
-        if (wstate.ctx_metal) {
-            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
-            ggml_metal_graph_compute(wstate.ctx_metal, gf);
-        } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-        }
-#else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-#endif
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
-    // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
-
     wstate.t_encode_us += ggml_time_us() - t_start_us;
     wstate.n_encode++;
 
-    return true;
+    return !(abort_callback && abort_callback(abort_callback_data));
 }
 
 static struct ggml_cgraph * whisper_build_graph_decoder(
@@ -2043,7 +2046,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     ggml_allocr_alloc(alloc, embd);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        memcpy(embd->data, tokens, N*ggml_element_size(embd));
+        ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd));
     }
 
     struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
@@ -2051,7 +2054,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     if (!ggml_allocr_is_measure(alloc)) {
         for (int i = 0; i < N; ++i) {
-            ((int32_t *) position->data)[i] = n_past + i;
+            const int32_t val = n_past + i;
+            ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
         }
     }
 
@@ -2059,7 +2063,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     ggml_allocr_alloc(alloc, KQscale);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25));
+        const float val = pow(float(n_state)/n_head, -0.25);
+        ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
     }
 
     // token encoding + position encoding
@@ -2383,25 +2388,18 @@ static bool whisper_decode_internal(
 
         logits = gf->nodes[gf->n_nodes - 1];
 
-#ifdef GGML_USE_METAL
-        if (wstate.ctx_metal) {
-            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
-            ggml_metal_graph_compute(wstate.ctx_metal, gf);
-        } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-        }
-#else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-#endif
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     // extract logits for all N tokens
     //logits_out.resize(n_tokens*n_vocab);
     //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
+    //ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab);
 
     // extract logits only for the last token
     logits_out.resize(n_vocab);
-    memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
+    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
+    ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
 
     if (n_tokens > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
@@ -2420,7 +2418,7 @@ static bool whisper_decode_internal(
         wstate.n_prompt++;
     }
 
-    return true;
+    return !(abort_callback && abort_callback(abort_callback_data));
 }
 
 
@@ -2767,7 +2765,7 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
                 --j;
             }
             if (!found) {
-                log("unknown token\n");
+                WHISPER_LOG_ERROR("unknown token\n");
                 ++i;
             }
         }
@@ -2830,45 +2828,48 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
 
 struct whisper_state * whisper_init_state(whisper_context * ctx) {
     fill_sin_cos_table();
+
     whisper_state * state = new whisper_state;
 
-    if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
-        log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
+    state->backend = whisper_backend_init(ctx->params);
+
+    if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
+        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
         return nullptr;
     }
 
     {
         const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
-        log("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
-    if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
-        log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+    if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
+        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
         delete state;
         return nullptr;
     }
 
     {
         const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
-        log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
 #ifdef WHISPER_USE_COREML
     const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
 
-    log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
-    log("%s: first run on a device may take a while ...\n", __func__);
+    WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
+    WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
 
     state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
     if (!state->ctx_coreml) {
-        log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
+        WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
 #ifndef WHISPER_COREML_ALLOW_FALLBACK
         delete state;
         return nullptr;
 #endif
     } else {
-        log("%s: Core ML model loaded\n", __func__);
+        WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__);
     }
 #endif
 
@@ -2885,37 +2886,37 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // conv allocator
     {
-        whisper_allocr_graph_init(state->alloc_conv,
+        whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
                 [&]() {
                     return whisper_build_graph_conv(*ctx, *state, 0);
                 });
 
-        log("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
     }
 
     // encoder allocator
     if (!whisper_encode_external(*state)) {
-        whisper_allocr_graph_init(state->alloc_encode,
+        whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
                 [&]() {
                     return whisper_build_graph_encoder(*ctx, *state);
                 });
 
-        log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
     }
 
     // cross allocator
     {
-        whisper_allocr_graph_init(state->alloc_cross,
+        whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
                 [&]() {
                     return whisper_build_graph_cross(*ctx, *state);
                 });
 
-        log("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
+        WHISPER_LOG_INFO("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
     }
 
     // decoder allocator
     {
-        whisper_allocr_graph_init(state->alloc_decode,
+        whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
                 [&]() {
                     const auto & hparams = ctx->model.hparams;
 
@@ -2926,64 +2927,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
                     return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
                 });
 
-        log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
-    }
-
-#ifdef GGML_USE_METAL
-    state->ctx_metal = ggml_metal_init(1);
-    if (!state->ctx_metal) {
-        log("%s: ggml_metal_init() failed\n", __func__);
-        delete state;
-        return nullptr;
-    }
-
-    log("%s: Metal context initialized\n", __func__);
-
-    // this allocates all Metal resources and memory buffers
-
-    void * data_ptr  = NULL;
-    size_t data_size = 0;
-
-    // TODO: add mmap support
-    //if (params.use_mmap) {
-    //    data_ptr  = ctx->model.mapping->addr;
-    //    data_size = ctx->model.mapping->size;
-    //} else {
-    //    data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
-    //    data_size = ggml_get_mem_size  (ctx->model.ctx);
-    //}
-
-    data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
-    data_size = ggml_get_mem_size  (ctx->model.ctx);
-
-    const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
-
-    log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
-
-#define WHISPER_METAL_CHECK_BUF(result)              \
-    if (!(result)) {                                 \
-        log("%s: failed to add metal buffer\n", __func__); \
-        delete state;                                \
-        return nullptr;                              \
+        WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
     }
 
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
-
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv",   state->alloc_conv.meta.data(),   state->alloc_conv.meta.size(),   0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross",  state->alloc_cross.meta.data(),  state->alloc_cross.meta.size(),  0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
-
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv",   state->alloc_conv.data.data(),   state->alloc_conv.data.size(),   0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross",  state->alloc_cross.data.data(),  state->alloc_cross.data.size(),  0));
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
-
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross",  state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
-
-    WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
-#undef WHISPER_METAL_CHECK_BUF
-#endif
+    whisper_allocr_graph_realloc(state->alloc_conv,   ctx->backend);
+    whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
+    whisper_allocr_graph_realloc(state->alloc_cross,  ctx->backend);
+    whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
 
     state->rng = std::mt19937(0);
 
@@ -3004,7 +2954,7 @@ int whisper_ctx_init_openvino_encoder(
     return 1;
 #else
     if (!model_path && ctx->path_model.empty()) {
-        log("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
+        WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
         return 1;
     }
 
@@ -3024,27 +2974,34 @@ int whisper_ctx_init_openvino_encoder(
         path_cache = cache_dir;
     }
 
-    log("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
-    log("%s: first run on a device may take a while ...\n", __func__);
+    WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
+    WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
 
     ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
     if (!ctx->state->ctx_openvino) {
-        log("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
+        WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
         return 1;
     } else {
-        log("%s: OpenVINO model loaded\n", __func__);
+        WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__);
     }
 
     return 0;
 #endif
 }
 
-struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
-    log("%s: loading model from '%s'\n", __func__, path_model);
+struct whisper_context_params whisper_context_default_params() {
+    struct whisper_context_params result = {
+        /*.use_gpu    =*/ true,
+    };
+    return result;
+}
+
+struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
+    WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
 
     auto fin = std::ifstream(path_model, std::ios::binary);
     if (!fin) {
-        log("%s: failed to open '%s'\n", __func__, path_model);
+        WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
         return nullptr;
     }
 
@@ -3068,7 +3025,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
         fin->close();
     };
 
-    auto ctx = whisper_init_no_state(&loader);
+    auto ctx = whisper_init_with_params_no_state(&loader, params);
 
     if (ctx) {
         ctx->path_model = path_model;
@@ -3077,7 +3034,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
     return ctx;
 }
 
-struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
+struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params) {
     struct buf_context {
         uint8_t* buffer;
         size_t size;
@@ -3086,7 +3043,7 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t
 
     buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
 
-    log("%s: loading model from buffer\n", __func__);
+    WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__);
 
     whisper_model_loader loader = {};
 
@@ -3111,17 +3068,18 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t
 
     loader.close = [](void * /*ctx*/) { };
 
-    return whisper_init_no_state(&loader);
+    return whisper_init_with_params_no_state(&loader, params);
 }
 
-struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
+struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
     ggml_time_init();
 
     whisper_context * ctx = new whisper_context;
+    ctx->params = params;
 
     if (!whisper_model_load(loader, *ctx)) {
         loader->close(loader->context);
-        log("%s: failed to load model\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to load model\n", __func__);
         delete ctx;
         return nullptr;
     }
@@ -3131,8 +3089,8 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
     return ctx;
 }
 
-struct whisper_context * whisper_init_from_file(const char * path_model) {
-    whisper_context * ctx = whisper_init_from_file_no_state(path_model);
+struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params) {
+    whisper_context * ctx = whisper_init_from_file_with_params_no_state(path_model, params);
     if (!ctx) {
         return nullptr;
     }
@@ -3146,8 +3104,8 @@ struct whisper_context * whisper_init_from_file(const char * path_model) {
     return ctx;
 }
 
-struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
-    whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size);
+struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params) {
+    whisper_context * ctx = whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, params);
     if (!ctx) {
         return nullptr;
     }
@@ -3161,8 +3119,8 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s
     return ctx;
 }
 
-struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
-    whisper_context * ctx = whisper_init_no_state(loader);
+struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params) {
+    whisper_context * ctx = whisper_init_with_params_no_state(loader, params);
     if (!ctx) {
         return nullptr;
     }
@@ -3176,6 +3134,30 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
     return ctx;
 }
 
+struct whisper_context * whisper_init_from_file(const char * path_model) {
+    return whisper_init_from_file_with_params(path_model, whisper_context_default_params());
+}
+
+struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
+    return whisper_init_from_buffer_with_params(buffer, buffer_size, whisper_context_default_params());
+}
+
+struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
+    return whisper_init_with_params(loader, whisper_context_default_params());
+}
+
+struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
+    return whisper_init_from_file_with_params_no_state(path_model, whisper_context_default_params());
+}
+
+struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
+    return whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, whisper_context_default_params());
+}
+
+struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
+    return whisper_init_with_params_no_state(loader, whisper_context_default_params());
+}
+
 void whisper_free_state(struct whisper_state * state)
 {
     if (state) {
@@ -3192,13 +3174,6 @@ void whisper_free_state(struct whisper_state * state)
         }
 #endif
 
-#ifdef GGML_USE_METAL
-        if (state->ctx_metal) {
-            ggml_metal_free(state->ctx_metal);
-            state->ctx_metal = nullptr;
-        }
-#endif
-
 #ifdef WHISPER_USE_OPENVINO
         if (state->ctx_openvino != nullptr) {
             whisper_openvino_free(state->ctx_openvino);
@@ -3207,9 +3182,11 @@ void whisper_free_state(struct whisper_state * state)
 #endif
 
         whisper_allocr_free(state->alloc_conv);
-        whisper_allocr_free(state->alloc_decode);
-        whisper_allocr_free(state->alloc_cross);
         whisper_allocr_free(state->alloc_encode);
+        whisper_allocr_free(state->alloc_cross);
+        whisper_allocr_free(state->alloc_decode);
+
+        ggml_backend_free(state->backend);
 
         delete state;
     }
@@ -3220,16 +3197,25 @@ void whisper_free(struct whisper_context * ctx) {
         if (ctx->model.ctx) {
             ggml_free(ctx->model.ctx);
         }
-        if (ctx->model.buf) {
-            delete ctx->model.buf;
+
+        if (ctx->model.buffer) {
+            ggml_backend_buffer_free(ctx->model.buffer);
         }
 
         whisper_free_state(ctx->state);
 
+        ggml_backend_free(ctx->backend);
+
         delete ctx;
     }
 }
 
+void whisper_free_context_params(struct whisper_context_params * params) {
+    if (params) {
+        delete params;
+    }
+}
+
 void whisper_free_params(struct whisper_full_params * params) {
     if (params) {
         delete params;
@@ -3237,8 +3223,8 @@ void whisper_free_params(struct whisper_full_params * params) {
 }
 
 int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
-    if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
-        log("%s: failed to compute mel spectrogram\n", __func__);
+    if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
+        WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
 
@@ -3251,8 +3237,8 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
 
 // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
 int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
-    if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
-        log("%s: failed to compute mel spectrogram\n", __func__);
+    if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
+        WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
 
@@ -3274,13 +3260,13 @@ int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float *
 // TODO
 
 int whisper_set_mel_with_state(
-        struct whisper_context * /*ctx*/,
+        struct whisper_context * ctx,
           struct whisper_state * state,
                    const float * data,
                            int   n_len,
                            int   n_mel) {
-    if (n_mel != WHISPER_N_MEL) {
-        log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
+    if (n_mel != ctx->model.filters.n_mel) {
+        WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
         return -1;
     }
 
@@ -3304,7 +3290,7 @@ int whisper_set_mel(
 
 int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
     if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
-        log("%s: failed to eval\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return -1;
     }
 
@@ -3313,7 +3299,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
 
 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
     if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
-        log("%s: failed to eval\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return -1;
     }
 
@@ -3324,7 +3310,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
     const int selected_decoder_id = 0;
 
     if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
-        log("%s: failed to eval\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
 
@@ -3336,12 +3322,12 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
     const int selected_decoder_id = 0;
 
     if (ctx->state == nullptr) {
-        log("%s: ERROR state was not loaded.\n", __func__);
+        WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
         return false;
     }
 
     if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
-        log("%s: failed to eval\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
 
@@ -3352,7 +3338,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
     const auto res = tokenize(ctx->vocab, text);
 
     if (n_max_tokens < (int) res.size()) {
-        log("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
+        WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
         return -1;
     }
 
@@ -3380,7 +3366,7 @@ int whisper_lang_id(const char * lang) {
             }
         }
 
-        log("%s: unknown language '%s'\n", __func__, lang);
+        WHISPER_LOG_ERROR("%s: unknown language '%s'\n", __func__, lang);
         return -1;
     }
     return g_lang.at(lang).first;
@@ -3393,7 +3379,7 @@ const char * whisper_lang_str(int id) {
         }
     }
 
-    log("%s: unknown language id %d\n", __func__, id);
+    WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
     return nullptr;
 }
 
@@ -3406,25 +3392,25 @@ int whisper_lang_auto_detect_with_state(
     const int seek = offset_ms/10;
 
     if (seek < 0) {
-        log("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
+        WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
         return -1;
     }
 
     if (seek >= state->mel.n_len_org) {
-        log("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
+        WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
         return -2;
     }
 
     // run the encoder
     if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
-        log("%s: failed to encode\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
         return -6;
     }
 
     const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
 
     if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
-        log("%s: failed to decode\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
         return -7;
     }
 
@@ -3624,8 +3610,8 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = ggml_time_us();
 
-    log("\n");
-    log("%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
+    WHISPER_LOG_INFO("\n");
+    WHISPER_LOG_INFO("%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
     if (ctx->state != nullptr) {
 
         const int32_t n_sample = std::max(1, ctx->state->n_sample);
@@ -3633,18 +3619,20 @@ void whisper_print_timings(struct whisper_context * ctx) {
         const int32_t n_decode = std::max(1, ctx->state->n_decode);
         const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
 
-        log("%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
-        log("%s:      mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
-        log("%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
-        log("%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
-        log("%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
-        log("%s:   prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
+        WHISPER_LOG_INFO("%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
+        WHISPER_LOG_INFO("%s:      mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
+        WHISPER_LOG_INFO("%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
+        WHISPER_LOG_INFO("%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
+        WHISPER_LOG_INFO("%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+        WHISPER_LOG_INFO("%s:   prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
     }
-    log("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
+    WHISPER_LOG_INFO("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
 
 void whisper_reset_timings(struct whisper_context * ctx) {
+    ctx->t_start_us = ggml_time_us();
     if (ctx->state != nullptr) {
+        ctx->state->t_mel_us = 0;
         ctx->state->t_sample_us = 0;
         ctx->state->t_encode_us = 0;
         ctx->state->t_decode_us = 0;
@@ -3690,6 +3678,7 @@ const char * whisper_print_system_info(void) {
     s += "SSE3 = "      + std::to_string(ggml_cpu_has_sse3())      + " | ";
     s += "SSSE3 = "     + std::to_string(ggml_cpu_has_ssse3())     + " | ";
     s += "VSX = "       + std::to_string(ggml_cpu_has_vsx())       + " | ";
+    s += "CUDA = "      + std::to_string(ggml_cpu_has_cublas())    + " | ";
     s += "COREML = "    + std::to_string(whisper_has_coreml())     + " | ";
     s += "OPENVINO = "  + std::to_string(whisper_has_openvino())   + " | ";
 
@@ -3698,6 +3687,14 @@ const char * whisper_print_system_info(void) {
 
 ////////////////////////////////////////////////////////////////////////////
 
+struct whisper_context_params * whisper_context_default_params_by_ref() {
+    struct whisper_context_params params = whisper_context_default_params();
+
+    struct whisper_context_params* result = new whisper_context_params();
+    *result = params;
+    return result;
+}
+
 struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) {
     struct whisper_full_params params = whisper_full_default_params(strategy);
 
@@ -3774,6 +3771,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.encoder_begin_callback           =*/ nullptr,
         /*.encoder_begin_callback_user_data =*/ nullptr,
 
+        /*.abort_callback                   =*/ nullptr,
+        /*.abort_callback_user_data         =*/ nullptr,
+
         /*.logits_filter_callback           =*/ nullptr,
         /*.logits_filter_callback_user_data =*/ nullptr,
     };
@@ -3940,6 +3940,7 @@ static void whisper_process_logits(
         // suppress task tokens
         logits[vocab.token_translate]  = -INFINITY;
         logits[vocab.token_transcribe] = -INFINITY;
+        logits[vocab.token_prev]       = -INFINITY;
 
         if (params.logits_filter_callback) {
             params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
@@ -3972,7 +3973,7 @@ static void whisper_process_logits(
             const bool last_was_timestamp        = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
             const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
 
-            //log("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
+            //WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
 
             if (last_was_timestamp) {
                 if (penultimate_was_timestamp) {
@@ -4048,7 +4049,7 @@ static void whisper_process_logits(
 
             const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
 
-            //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
+            //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
 
             if (timestamp_logprob > max_text_token_logprob) {
                 for (int i = 0; i < vocab.token_beg; ++i) {
@@ -4343,8 +4344,10 @@ static bool whisper_kv_swap_fast(
     for (auto & i : two_copy) {
         // make a copy of KV caches
         WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
-        memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
-        memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
+        //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
+        //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
+        ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
+        ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
     }
 
     // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
@@ -4357,13 +4360,17 @@ static bool whisper_kv_swap_fast(
         if (two_copy.find(view[i]) != two_copy.end()) {
             // modify KV caches of decoder using data from kv_swap_bufs
             WHISPER_PRINT_DEBUG("%s: two-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
+            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
+            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
         } else {
             // modify KV caches of decoder using data from correspond decoder KV caches directly
             WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers:      %d  -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
+            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
+            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
         }
     }
 
@@ -4377,13 +4384,17 @@ static bool whisper_kv_swap_fast(
         if (two_copy.find(view[i]) != two_copy.end()) {
             // modify KV caches of decoder using data from kv_swap_bufs
             WHISPER_PRINT_DEBUG("%s: one-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
+            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
+            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
         } else {
             // modify KV caches of decoder using data from correspond decoder KV caches directly
             WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers:      %d  -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
+            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
+            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
         }
     }
 
@@ -4411,11 +4422,11 @@ int whisper_full_with_state(
         // compute log mel spectrogram
         if (params.speed_up) {
             // TODO: Replace PV with more advanced algorithm
-            log("%s: failed to compute log mel spectrogram\n", __func__);
+            WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
             return -1;
         } else {
             if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
-                log("%s: failed to compute log mel spectrogram\n", __func__);
+                WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
                 return -2;
             }
         }
@@ -4427,13 +4438,13 @@ int whisper_full_with_state(
 
         const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
         if (lang_id < 0) {
-            log("%s: failed to auto-detect language\n", __func__);
+            WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
             return -3;
         }
         state->lang_id = lang_id;
         params.language = whisper_lang_str(lang_id);
 
-        log("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+        WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
         if (params.detect_language) {
             return 0;
         }
@@ -4491,8 +4502,8 @@ int whisper_full_with_state(
 
         if (decoder.kv_self.ctx == nullptr) {
             decoder.kv_self = state->decoders[0].kv_self;
-            if (!kv_cache_reinit(decoder.kv_self)) {
-                log("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
+            if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) {
+                WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
                 return -4;
             }
 
@@ -4503,21 +4514,6 @@ int whisper_full_with_state(
             decoder.probs.resize   (ctx->vocab.n_vocab);
             decoder.logits.resize  (ctx->vocab.n_vocab);
             decoder.logprobs.resize(ctx->vocab.n_vocab);
-
-            // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
-#ifdef GGML_USE_METAL
-#define WHISPER_METAL_CHECK_BUF(result)              \
-            if (!(result)) {                                 \
-                log("%s: failed to add metal buffer\n", __func__); \
-                return 0;                              \
-            }
-
-            const std::string kv_name = "kv_self_" + std::to_string(j);
-            auto & kv_self = decoder.kv_self;
-
-            WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
-#undef WHISPER_METAL_CHECK_BUF
-#endif
         }
     }
 
@@ -4551,13 +4547,14 @@ int whisper_full_with_state(
 
     // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
     if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
-        log("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+        WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
         return -5;
     }
     state->exp_n_audio_ctx = params.audio_ctx;
 
     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
+
     if (whisper_is_multilingual(ctx)) {
         const int lang_id = whisper_lang_id(params.language);
         state->lang_id = lang_id;
@@ -4569,6 +4566,17 @@ int whisper_full_with_state(
         }
     }
 
+    {
+        const bool is_distil = ctx->model.hparams.n_text_layer == 2;
+
+        // distilled models require the "no_timestamps" token
+        // TODO: add input parameter (#1229)
+        if (is_distil) {
+            WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
+            prompt_init.push_back(whisper_token_not(ctx));
+        }
+    }
+
     int seek = seek_start;
 
     std::vector<whisper_token> prompt;
@@ -4601,14 +4609,14 @@ int whisper_full_with_state(
 
         if (params.encoder_begin_callback) {
             if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
-                log("%s: encoder_begin_callback returned false - aborting\n", __func__);
+                WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
                 break;
             }
         }
 
         // encode audio features starting at offset seek
         if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-            log("%s: failed to encode\n", __func__);
+            WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
             return -6;
         }
 
@@ -4691,7 +4699,7 @@ int whisper_full_with_state(
                 WHISPER_PRINT_DEBUG("\n\n");
 
                 if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-                    log("%s: failed to decode\n", __func__);
+                    WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                     return -7;
                 }
 
@@ -4705,8 +4713,11 @@ int whisper_full_with_state(
                     for (int j = 1; j < n_decoders_cur; ++j) {
                         auto & decoder = state->decoders[j];
 
-                        memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
-                        memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
+                        // TODO: fix CUDA
+                        //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
+                        //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
+                        ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
+                        ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);
 
                         decoder.kv_self.n += prompt.size();
 
@@ -4915,7 +4926,7 @@ int whisper_full_with_state(
                     //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
 
                     if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-                        log("%s: failed to decode\n", __func__);
+                        WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                         return -8;
                     }
 
@@ -5241,12 +5252,12 @@ int whisper_full_parallel(
     ctx->state->t_decode_us /= n_processors;
 
     // print information about the audio boundaries
-    log("\n");
-    log("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
+    WHISPER_LOG_WARN("\n");
+    WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
     for (int i = 0; i < n_processors - 1; ++i) {
-        log("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
+        WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
     }
-    log("%s: the transcription quality may be degraded near these boundaries\n", __func__);
+    WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__);
 
     return ret;
 }
@@ -5488,12 +5499,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
             double tsum = 0.0;
 
             // heat-up
-            ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
+            ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
 
             for (int i = 0; i < n_max; ++i) {
                 const int64_t t0 = ggml_time_us();
 
-                ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
+                ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
 
                 const int64_t t1 = ggml_time_us();
 
@@ -5611,7 +5622,7 @@ static void whisper_exp_compute_token_level_timestamps(
     const int n_samples = state.energy.size();
 
     if (n_samples == 0) {
-        log("%s: no signal data available\n", __func__);
+        WHISPER_LOG_ERROR("%s: no signal data available\n", __func__);
         return;
     }
 
@@ -5832,6 +5843,32 @@ static void whisper_exp_compute_token_level_timestamps(
     //}
 }
 
-void whisper_set_log_callback(whisper_log_callback callback) {
-    whisper_log = callback;
+void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
+    g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
+    g_state.log_callback_user_data = user_data;
+}
+
+GGML_ATTRIBUTE_FORMAT(2, 3)
+static void whisper_log_internal(ggml_log_level level, const char * format, ...) {
+    va_list args;
+    va_start(args, format);
+    char buffer[1024];
+    int len = vsnprintf(buffer, 1024, format, args);
+    if (len < 1024) {
+        g_state.log_callback(level, buffer, g_state.log_callback_user_data);
+    } else {
+        char* buffer2 = new char[len+1];
+        vsnprintf(buffer2, len+1, format, args);
+        buffer2[len] = 0;
+        g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
+        delete[] buffer2;
+    }
+    va_end(args);
+}
+
+static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) user_data;
+    fputs(text, stderr);
+    fflush(stderr);
 }
index c3118c9c99b7ef0d6b217f4f6ff8ceae1d0a2b8f..0ea5237e5f2f7c0d51a2432439f8d425e23b78e6 100644 (file)
@@ -1,10 +1,20 @@
 #ifndef WHISPER_H
 #define WHISPER_H
 
+#include "ggml.h"
+
 #include <stddef.h>
 #include <stdint.h>
 #include <stdbool.h>
 
+#ifdef __GNUC__
+#    define WHISPER_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
+#elif defined(_MSC_VER)
+#    define WHISPER_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
+#else
+#    define WHISPER_DEPRECATED(func, hint) func
+#endif
+
 #ifdef WHISPER_SHARED
 #    ifdef _WIN32
 #        ifdef WHISPER_BUILD
@@ -21,7 +31,6 @@
 
 #define WHISPER_SAMPLE_RATE 16000
 #define WHISPER_N_FFT       400
-#define WHISPER_N_MEL       80
 #define WHISPER_HOP_LENGTH  160
 #define WHISPER_CHUNK_SIZE  30
 
@@ -71,6 +80,10 @@ extern "C" {
 
     typedef int whisper_token;
 
+    struct whisper_context_params {
+        bool  use_gpu;
+    };
+
     typedef struct whisper_token_data {
         whisper_token id;  // token id
         whisper_token tid; // forced timestamp token id
@@ -99,15 +112,40 @@ extern "C" {
     // Various functions for loading a ggml whisper model.
     // Allocate (almost) all memory needed for the model.
     // Return NULL on failure
-    WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
-    WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
-    WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
+    WHISPER_API struct whisper_context * whisper_init_from_file_with_params  (const char * path_model,              struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size,    struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_with_params            (struct whisper_model_loader * loader, struct whisper_context_params params);
 
     // These are the same as the above, but the internal state of the context is not allocated automatically
     // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
-    WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model);
-    WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size);
-    WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader);
+    WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state  (const char * path_model,              struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size,    struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_with_params_no_state            (struct whisper_model_loader * loader, struct whisper_context_params params);
+
+    WHISPER_DEPRECATED(
+        WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model),
+        "use whisper_init_from_file_with_params instead"
+    );
+    WHISPER_DEPRECATED(
+        WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size),
+        "use whisper_init_from_buffer_with_params instead"
+    );
+    WHISPER_DEPRECATED(
+        WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader),
+        "use whisper_init_with_params instead"
+    );
+    WHISPER_DEPRECATED(
+        WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model),
+        "use whisper_init_from_file_with_params_no_state instead"
+    );
+    WHISPER_DEPRECATED(
+        WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size),
+        "use whisper_init_from_buffer_with_params_no_state instead"
+    );
+    WHISPER_DEPRECATED(
+        WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader),
+        "use whisper_init_with_params_no_state instead"
+    );
 
     WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
 
@@ -132,6 +170,7 @@ extern "C" {
     WHISPER_API void whisper_free      (struct whisper_context * ctx);
     WHISPER_API void whisper_free_state(struct whisper_state * state);
     WHISPER_API void whisper_free_params(struct whisper_full_params * params);
+    WHISPER_API void whisper_free_context_params(struct whisper_context_params * params);
 
     // Convert RAW PCM audio to log mel spectrogram.
     // The resulting spectrogram is stored inside the default state of the provided whisper context.
@@ -442,7 +481,9 @@ extern "C" {
         void * logits_filter_callback_user_data;
     };
 
-    // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params()
+    // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
+    WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref();
+    WHISPER_API struct whisper_context_params whisper_context_default_params(void);
     WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
     WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
 
@@ -531,8 +572,7 @@ extern "C" {
 
     // Control logging output; default behavior is to print to stderr
 
-    typedef void (*whisper_log_callback)(const char * line);
-    WHISPER_API void whisper_set_log_callback(whisper_log_callback callback);
+    WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);
 
 #ifdef __cplusplus
 }
index ce4feeece7f631c5b07cb1c9b68de55e6397454c..058011a48f0da5189223145aab5c4aa26d1b3b3f 100644 (file)
@@ -4741,7 +4741,7 @@ static  __global__ void im2col_f32_f16(
         int ofs0, int ofs1, int IW, int IH, int CHW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
     const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
-    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+       const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
 
     const int offset_dst =
         (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
@@ -7962,6 +7962,15 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         return false;
     }
 
+    if (tensor->op == GGML_OP_MUL_MAT) {
+        if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+#ifndef NDEBUG
+            fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %d, src1->ne[3] = %d - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
+#endif
+            return false;
+        }
+    }
+
     switch (tensor->op) {
         case GGML_OP_REPEAT:
             func = ggml_cuda_repeat;
index 096b844e32c6fef91e18033c4638acecc2f322c8..be2731f8ba476728c50b4971a67a74a9437b981d 100644 (file)
@@ -26,7 +26,7 @@
 #include <stdbool.h>
 
 // max memory buffers that can be mapped to the device
-#define GGML_METAL_MAX_BUFFERS 16
+#define GGML_METAL_MAX_BUFFERS 64
 #define GGML_METAL_MAX_COMMAND_BUFFERS 32
 
 struct ggml_tensor;
index 148c12b14f8f211de78fb8d0eff109c18b2e2877..3d22b0b27e444f2a1995a7e729ee7171faa189a3 100644 (file)
@@ -128,7 +128,7 @@ struct ggml_metal_context {
 // MSL code
 // TODO: move the contents here when ready
 //       for now it is easier to work in a separate file
-static NSString * const msl_library_source = @"see metal.metal";
+//static NSString * const msl_library_source = @"see metal.metal";
 
 // Here to assist with NSBundle Path Hack
 @interface GGMLMetalClass : NSObject
@@ -144,7 +144,8 @@ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_dat
     ggml_metal_log_user_data = user_data;
 }
 
-static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
+GGML_ATTRIBUTE_FORMAT(2, 3)
+static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
     if (ggml_metal_log_callback != NULL) {
         va_list args;
         va_start(args, format);
@@ -339,7 +340,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
     for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
         if ([ctx->device supportsFamily:i]) {
-            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
+            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
             break;
         }
     }
@@ -479,6 +480,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
 
     const int64_t tsize = ggml_nbytes(t);
 
+    if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
+        ctx = t->buffer->backend->context;
+    }
+
     // find the view that contains the tensor fully
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
@@ -579,7 +584,7 @@ bool ggml_metal_add_buffer(
                 ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
         if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
-            GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
+            GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
         } else {
             GGML_METAL_LOG_INFO("\n");
         }
index 3067f9dff2e12ab1606c9b015e1a997398074610..a5a418d129a2ab43da98e026ac11b1adadb6f09f 100644 (file)
@@ -35,8 +35,8 @@ void load_model(test_model & model, bool use_gpu = false) {
     int IL = 8, N = 1;
 
     // Initialize adata
-    float* adata = new float[K * IC * OC];
-    for (size_t i = 0; i < K * IC * OC; i++) {
+    float * adata = new float[K * IC * OC];
+    for (int i = 0; i < K * IC * OC; i++) {
         adata[i] = 4.5f;
     }
 
@@ -45,8 +45,8 @@ void load_model(test_model & model, bool use_gpu = false) {
     ggml_fp32_to_fp16_row(adata, hadata.data(), K * IC * OC);
 
     // Initialize bdata
-    float* bdata =  new float[IL * IC * N];
-    for (size_t i = 0; i < IL * IC * N; i++) {
+    float * bdata =  new float[IL * IC * N];
+    for (int i = 0; i < IL * IC * N; i++) {
         bdata[i] = 2.5f;
     }
 
@@ -278,7 +278,7 @@ int main(void)
         }
     }
 
-    printf("ggml_im2col (%i): %s\n", ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
+    printf("ggml_im2col (%d): %s\n", (int) ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
 
     passed = true;
     for(int i = 0; i < n_conv1d_test; i++) {
@@ -288,7 +288,7 @@ int main(void)
         }
     }
 
-    printf("ggml_conv1d (%i): %s\n", ggml_nelements(conv1d_res), passed && (ggml_nelements(conv1d_res) == n_conv1d_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
+    printf("ggml_conv1d (%d): %s\n", (int) ggml_nelements(conv1d_res), passed && (ggml_nelements(conv1d_res) == n_conv1d_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
     ggml_free(model.ctx);
 
     ggml_backend_buffer_free(model.buffer);
index 4ad830b418185b76cd0792e9adca61e4f7b7f12f..df0e2ccdb7b49eb3e9b12aa4683c2b24eace8361 100644 (file)
@@ -35,8 +35,8 @@ void load_model(test_model & model, bool use_gpu = false) {
     int IW = 8, IH = 6, N = 1;
 
     // Initialize adata
-    float* adata = new float[KW * KH * IC * OC];
-    for (size_t i = 0; i < KW * KH * IC * OC; i++) {
+    float * adata = new float[KW * KH * IC * OC];
+    for (int i = 0; i < KW * KH * IC * OC; i++) {
         adata[i] = 2.5f;
     }
 
@@ -45,8 +45,8 @@ void load_model(test_model & model, bool use_gpu = false) {
     ggml_fp32_to_fp16_row(adata, hadata.data(), KW * KH * IC * OC);
 
     // Initialize bdata
-    float* bdata =  new float[IW * IH * IC * N];
-    for (size_t i = 0; i < IW * IH * IC * N; i++) {
+    float * bdata =  new float[IW * IH * IC * N];
+    for (int i = 0; i < IW * IH * IC * N; i++) {
         bdata[i] = 1.5f;
     }