]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add support for backends with multiple ggml_backend_buffer_type (#2863)
authorDan Johansson <redacted>
Wed, 26 Mar 2025 14:54:02 +0000 (15:54 +0100)
committerGitHub <redacted>
Wed, 26 Mar 2025 14:54:02 +0000 (16:54 +0200)
* whisper : add support for ggml_backend_buffer_type

Signed-off-by: Dan Johansson <redacted>
* fix compile error when building on Ubuntu

Signed-off-by: Dan Johansson <redacted>
* remove copyright header from include file

Signed-off-by: Dan Johansson <redacted>
---------

Signed-off-by: Dan Johansson <redacted>
src/CMakeLists.txt
src/whisper-arch.h [new file with mode: 0644]
src/whisper.cpp

index 527d38b8ce1910075fc6036b7ac540e8f3079fd3..a091e66a25f24774da16256da17ab747179a2f29 100644 (file)
@@ -102,6 +102,7 @@ endif()
 
 add_library(whisper
             ../include/whisper.h
+            whisper-arch.h
             whisper.cpp
             )
 
diff --git a/src/whisper-arch.h b/src/whisper-arch.h
new file mode 100644 (file)
index 0000000..ea2cfd6
--- /dev/null
@@ -0,0 +1,141 @@
+#pragma once
+
+#include "ggml.h"
+
+#include <map>
+
+enum asr_tensor {
+    ASR_TENSOR_ENC_POS_EMBD,
+    ASR_TENSOR_DEC_POS_EMBD,
+    ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT,
+    ASR_TENSOR_LN_WEIGHT,
+    ASR_TENSOR_LN_BIAS,
+    ASR_TENSOR_CONV1_WEIGHT,
+    ASR_TENSOR_CONV1_BIAS,
+    ASR_TENSOR_CONV2_WEIGHT,
+    ASR_TENSOR_CONV2_BIAS,
+    ASR_TENSOR_LN_POST_WEIGHT,
+    ASR_TENSOR_LN_POST_BIAS,
+    ASR_TENSOR_MLP_LN_WEIGHT,
+    ASR_TENSOR_MLP_LN_BIAS,
+    ASR_TENSOR_MLP_0_WEIGHT,
+    ASR_TENSOR_MLP_0_BIAS,
+    ASR_TENSOR_MLP_2_WEIGHT,
+    ASR_TENSOR_MLP_2_BIAS,
+    ASR_TENSOR_ATTN_LN_WEIGHT,
+    ASR_TENSOR_ATTN_LN_BIAS,
+    ASR_TENSOR_ATTN_QUERY_WEIGHT,
+    ASR_TENSOR_ATTN_QUERY_BIAS,
+    ASR_TENSOR_ATTN_KEY_WEIGHT,
+    ASR_TENSOR_ATTN_VALUE_WEIGHT,
+    ASR_TENSOR_ATTN_VALUE_BIAS,
+    ASR_TENSOR_ATTN_OUT_WEIGHT,
+    ASR_TENSOR_ATTN_OUT_BIAS,
+};
+
+enum asr_system {
+    ASR_SYSTEM_ENCODER,
+    ASR_SYSTEM_DECODER,
+    ASR_SYSTEM_CROSS
+};
+
+static const std::map<asr_system, std::map<asr_tensor, const char *>> ASR_TENSOR_NAMES = {
+    {
+        ASR_SYSTEM_ENCODER,
+        {
+            {ASR_TENSOR_ENC_POS_EMBD, "encoder.positional_embedding"},
+            {ASR_TENSOR_CONV1_WEIGHT, "encoder.conv1.weight"},
+            {ASR_TENSOR_CONV1_BIAS, "encoder.conv1.bias"},
+            {ASR_TENSOR_CONV2_WEIGHT, "encoder.conv2.weight"},
+            {ASR_TENSOR_CONV2_BIAS, "encoder.conv2.bias"},
+            {ASR_TENSOR_LN_WEIGHT, "encoder.ln_post.weight"},
+            {ASR_TENSOR_LN_POST_BIAS, "encoder.ln_post.bias"},
+            {ASR_TENSOR_MLP_LN_WEIGHT, "encoder.blocks.%d.mlp_ln.weight"},
+            {ASR_TENSOR_MLP_LN_BIAS, "encoder.blocks.%d.mlp_ln.bias"},
+            {ASR_TENSOR_MLP_0_WEIGHT, "encoder.blocks.%d.mlp.0.weight"},
+            {ASR_TENSOR_MLP_0_BIAS, "encoder.blocks.%d.mlp.0.bias"},
+            {ASR_TENSOR_MLP_2_WEIGHT, "encoder.blocks.%d.mlp.2.weight"},
+            {ASR_TENSOR_MLP_2_BIAS, "encoder.blocks.%d.mlp.2.bias"},
+            {ASR_TENSOR_ATTN_LN_WEIGHT, "encoder.blocks.%d.attn_ln.weight"},
+            {ASR_TENSOR_ATTN_LN_BIAS, "encoder.blocks.%d.attn_ln.bias"},
+            {ASR_TENSOR_ATTN_QUERY_WEIGHT, "encoder.blocks.%d.attn.query.weight"},
+            {ASR_TENSOR_ATTN_QUERY_BIAS, "encoder.blocks.%d.attn.query.bias"},
+            {ASR_TENSOR_ATTN_KEY_WEIGHT, "encoder.blocks.%d.attn.key.weight"},
+            {ASR_TENSOR_ATTN_VALUE_WEIGHT, "encoder.blocks.%d.attn.value.weight"},
+            {ASR_TENSOR_ATTN_VALUE_BIAS, "encoder.blocks.%d.attn.value.bias"},
+            {ASR_TENSOR_ATTN_OUT_WEIGHT, "encoder.blocks.%d.attn.out.weight"},
+            {ASR_TENSOR_ATTN_OUT_BIAS, "encoder.blocks.%d.attn.out.bias"},
+        },
+    },
+    {
+        ASR_SYSTEM_DECODER,
+        {
+            {ASR_TENSOR_DEC_POS_EMBD, "decoder.positional_embedding"},
+            {ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, "decoder.token_embedding.weight"},
+            {ASR_TENSOR_LN_WEIGHT, "decoder.ln.weight"},
+            {ASR_TENSOR_LN_BIAS, "decoder.ln.bias"},
+
+            {ASR_TENSOR_MLP_LN_WEIGHT, "decoder.blocks.%d.mlp_ln.weight"},
+            {ASR_TENSOR_MLP_LN_BIAS, "decoder.blocks.%d.mlp_ln.bias"},
+            {ASR_TENSOR_MLP_0_WEIGHT, "decoder.blocks.%d.mlp.0.weight"},
+            {ASR_TENSOR_MLP_0_BIAS, "decoder.blocks.%d.mlp.0.bias"},
+            {ASR_TENSOR_MLP_2_WEIGHT, "decoder.blocks.%d.mlp.2.weight"},
+            {ASR_TENSOR_MLP_2_BIAS, "decoder.blocks.%d.mlp.2.bias"},
+            {ASR_TENSOR_ATTN_LN_WEIGHT, "decoder.blocks.%d.attn_ln.weight"},
+            {ASR_TENSOR_ATTN_LN_BIAS, "decoder.blocks.%d.attn_ln.bias"},
+            {ASR_TENSOR_ATTN_QUERY_WEIGHT, "decoder.blocks.%d.attn.query.weight"},
+            {ASR_TENSOR_ATTN_QUERY_BIAS, "decoder.blocks.%d.attn.query.bias"},
+            {ASR_TENSOR_ATTN_KEY_WEIGHT, "decoder.blocks.%d.attn.key.weight"},
+            {ASR_TENSOR_ATTN_VALUE_WEIGHT, "decoder.blocks.%d.attn.value.weight"},
+            {ASR_TENSOR_ATTN_VALUE_BIAS, "decoder.blocks.%d.attn.value.bias"},
+            {ASR_TENSOR_ATTN_OUT_WEIGHT, "decoder.blocks.%d.attn.out.weight"},
+            {ASR_TENSOR_ATTN_OUT_BIAS, "decoder.blocks.%d.attn.out.bias"},
+        },
+    },
+    {
+        ASR_SYSTEM_CROSS,
+        {
+            {ASR_TENSOR_ATTN_LN_WEIGHT, "decoder.blocks.%d.cross_attn_ln.weight"},
+            {ASR_TENSOR_ATTN_LN_BIAS, "decoder.blocks.%d.cross_attn_ln.bias"},
+            {ASR_TENSOR_ATTN_QUERY_WEIGHT, "decoder.blocks.%d.cross_attn.query.weight"},
+            {ASR_TENSOR_ATTN_QUERY_BIAS, "decoder.blocks.%d.cross_attn.query.bias"},
+            {ASR_TENSOR_ATTN_KEY_WEIGHT, "decoder.blocks.%d.cross_attn.key.weight"},
+            {ASR_TENSOR_ATTN_VALUE_WEIGHT, "decoder.blocks.%d.cross_attn.value.weight"},
+            {ASR_TENSOR_ATTN_VALUE_BIAS, "decoder.blocks.%d.cross_attn.value.bias"},
+            {ASR_TENSOR_ATTN_OUT_WEIGHT, "decoder.blocks.%d.cross_attn.out.weight"},
+            {ASR_TENSOR_ATTN_OUT_BIAS, "decoder.blocks.%d.cross_attn.out.bias"},
+        },
+    },
+};
+
+static const std::map<asr_tensor, ggml_op> ASR_TENSOR_INFO = {
+    {ASR_TENSOR_ENC_POS_EMBD,          GGML_OP_ADD},
+    {ASR_TENSOR_DEC_POS_EMBD,          GGML_OP_GET_ROWS},
+    // Note: ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT is also used by GGML_OP_MAT_MUL. Need to figure out a way how to handle
+    // weight tensors that are used by multiple different operators when extra_buffer_type implementations accelerate
+    // more than just GGML_OP_MUL_MAT.
+    {ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, GGML_OP_GET_ROWS},
+    {ASR_TENSOR_LN_WEIGHT,             GGML_OP_MUL},
+    {ASR_TENSOR_LN_BIAS,               GGML_OP_ADD},
+    {ASR_TENSOR_CONV1_WEIGHT,          GGML_OP_IM2COL},
+    {ASR_TENSOR_CONV1_BIAS,            GGML_OP_ADD},
+    {ASR_TENSOR_CONV2_WEIGHT,          GGML_OP_IM2COL},
+    {ASR_TENSOR_CONV2_BIAS,            GGML_OP_ADD},
+    {ASR_TENSOR_LN_POST_WEIGHT,        GGML_OP_MUL},
+    {ASR_TENSOR_LN_POST_BIAS,          GGML_OP_ADD},
+    {ASR_TENSOR_MLP_LN_WEIGHT,         GGML_OP_MUL},
+    {ASR_TENSOR_MLP_LN_BIAS,           GGML_OP_ADD},
+    {ASR_TENSOR_MLP_0_WEIGHT,          GGML_OP_MUL_MAT},
+    {ASR_TENSOR_MLP_0_BIAS,            GGML_OP_ADD},
+    {ASR_TENSOR_MLP_2_WEIGHT,          GGML_OP_MUL_MAT},
+    {ASR_TENSOR_MLP_2_BIAS,            GGML_OP_ADD},
+    {ASR_TENSOR_ATTN_LN_WEIGHT,        GGML_OP_MUL},
+    {ASR_TENSOR_ATTN_LN_BIAS,          GGML_OP_ADD},
+    {ASR_TENSOR_ATTN_QUERY_WEIGHT,     GGML_OP_MUL_MAT},
+    {ASR_TENSOR_ATTN_QUERY_BIAS,       GGML_OP_ADD},
+    {ASR_TENSOR_ATTN_KEY_WEIGHT,       GGML_OP_MUL_MAT},
+    {ASR_TENSOR_ATTN_VALUE_WEIGHT,     GGML_OP_MUL_MAT},
+    {ASR_TENSOR_ATTN_VALUE_BIAS,       GGML_OP_ADD},
+    {ASR_TENSOR_ATTN_OUT_WEIGHT,       GGML_OP_MUL_MAT},
+    {ASR_TENSOR_ATTN_OUT_BIAS,         GGML_OP_ADD},
+};
index 547dcf235bb21a5579c84a8a9c0ff14e76b70890..c633765ed7a2cc1d22aa213cf0e100c162304614 100644 (file)
@@ -1,4 +1,5 @@
 #include "whisper.h"
+#include "whisper-arch.h"
 
 #include "ggml.h"
 #include "ggml-cpp.h"
@@ -18,6 +19,7 @@
 #include <cassert>
 #define _USE_MATH_DEFINES
 #include <cmath>
+#include <climits>
 #include <codecvt>
 #include <cstdarg>
 #include <cstdio>
@@ -143,6 +145,21 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
 #define WHISPER_MAX_DECODERS 8
 #define WHISPER_MAX_NODES 4096
 
+static std::string format(const char * fmt, ...) {
+    va_list ap;
+    va_list ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
+    int size = vsnprintf(NULL, 0, fmt, ap);
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    std::vector<char> buf(size + 1);
+    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+    GGML_ASSERT(size2 == size);
+    va_end(ap2);
+    va_end(ap);
+    return std::string(buf.data(), size);
+}
+
 //
 // ggml helpers
 //
@@ -778,10 +795,10 @@ struct whisper_model {
     std::vector<whisper_layer_decoder> layers_decoder;
 
     // ggml context that contains all the meta information about the model tensors
-    struct ggml_context * ctx = nullptr;
+    std::vector<ggml_context *> ctxs;
 
     // the model backend data is read-only and can be shared between processors
-    ggml_backend_buffer_t buffer = nullptr;
+    std::vector<ggml_backend_buffer_t> buffers;
 
     // tensors
     int n_loaded;
@@ -1364,28 +1381,109 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
     return result;
 }
 
-static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
-    ggml_backend_buffer_type_t result = ggml_backend_cpu_buffer_type();
+using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
 
-    if (!params.use_gpu) {
-        return result;
-    }
+static buft_list_t make_buft_list(whisper_context_params & params) {
+    // Prio order: GPU -> CPU Extra -> CPU
+    buft_list_t buft_list;
 
-    int cnt = 0;
-    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
-            if (cnt == 0 || cnt == params.gpu_device) {
-                result = ggml_backend_dev_buffer_type(dev);
+    // GPU
+    if (params.use_gpu) {
+        int cnt = 0;
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+                if (cnt == 0 || cnt == params.gpu_device) {
+                    auto * buft = ggml_backend_dev_buffer_type(dev);
+                    if (buft) {
+                        buft_list.emplace_back(dev, buft);
+                    }
+                }
+
+                if (++cnt > params.gpu_device) {
+                    break;
+                }
             }
+        }
+    }
+
+    // CPU Extra
+    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+    auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
+    auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
+        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
+    if (get_extra_bufts_fn) {
+        ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev);
+        while (extra_bufts && *extra_bufts) {
+            buft_list.emplace_back(cpu_dev, *extra_bufts);
+            ++extra_bufts;
+        }
+    }
+
+    // CPU
+    buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type());
+
+    return buft_list;
+}
+
+static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
+    bool op_supported = true;
+
+    if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
+        (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
+        // GPU and default CPU backend support all operators
+        op_supported = true;
+    } else {
+        switch (op) {
+            // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
+            case GGML_OP_MUL_MAT: {
+                ggml_init_params params = {
+                    /*.mem_size   =*/ 2 * ggml_tensor_overhead(),
+                    /*.mem_buffer =*/ nullptr,
+                    /*.no_alloc   =*/ true,
+                };
+
+                ggml_context_ptr ctx_ptr { ggml_init(params) };
+                if (!ctx_ptr) {
+                    throw std::runtime_error("failed to create ggml context");
+                }
+                ggml_context * ctx = ctx_ptr.get();
 
-            if (++cnt > params.gpu_device) {
+                ggml_tensor * op_tensor = nullptr;
+
+                int64_t n_ctx = hparams.n_audio_ctx;
+                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
+                op_tensor = ggml_mul_mat(ctx, w, b);
+
+                // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
+                GGML_ASSERT(w->buffer == nullptr);
+                w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
+                op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
+                ggml_backend_buffer_free(w->buffer);
+                w->buffer = nullptr;
+                break;
+            }
+            default: {
+                op_supported = false;
                 break;
             }
+        };
+    }
+
+    return op_supported;
+}
+
+static ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
+    GGML_ASSERT(!buft_list.empty());
+    for (const auto & p : buft_list) {
+        ggml_backend_dev_t dev = p.first;
+        ggml_backend_buffer_type_t buft = p.second;
+        if (weight_buft_supported(hparams, w, op, buft, dev)) {
+            return buft;
         }
     }
 
-    return result;
+    return nullptr;
 }
 
 // load the model from a ggml file
@@ -1594,31 +1692,65 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
     const ggml_type wtype = wctx.wtype;
     const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type
 
-    // create the ggml context
-    {
-        const auto & hparams = model.hparams;
+    const auto & hparams = model.hparams;
 
-        const int n_audio_layer = hparams.n_audio_layer;
-        const int n_text_layer  = hparams.n_text_layer;
+    const int n_audio_layer = hparams.n_audio_layer;
+    const int n_text_layer  = hparams.n_text_layer;
 
-        const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
+    const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
 
-        struct ggml_init_params params = {
-            /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
-            /*.mem_buffer =*/ nullptr,
-            /*.no_alloc   =*/ true,
-        };
+    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            ggml_init_params params = {
+                /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
+                /*.mem_buffer =*/ nullptr,
+                /*.no_alloc   =*/ true,
+            };
 
-        model.ctx = ggml_init(params);
-        if (!model.ctx) {
-            WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__);
-            return false;
+            ggml_context * ctx = ggml_init(params);
+            if (!ctx) {
+                throw std::runtime_error("failed to create ggml context");
+            }
+
+            ctx_map[buft] = ctx;
+            model.ctxs.emplace_back(ctx);
+
+            return ctx;
         }
-    }
+
+        return it->second;
+    };
+
+    // Create a list of available bufts, in priority order
+    buft_list_t buft_list = make_buft_list(wctx.params);
+
+    auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * {
+        ggml_op op = ASR_TENSOR_INFO.at(type);
+        ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
+        if (!buft) {
+            throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", ASR_TENSOR_NAMES.at(system).at(type)));
+        }
+
+        ggml_context * ctx = get_ctx(buft);
+        ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
+
+        model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
+
+        return tensor;
+    };
+
 
     // prepare tensors for the weights
     {
-        auto & ctx = model.ctx;
+        ggml_init_params params = {
+            /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+
+        ggml_context * ctx = ggml_init(params);
 
         const auto & hparams = model.hparams;
 
@@ -1638,189 +1770,108 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         model.layers_decoder.resize(n_text_layer);
 
         // encoder
-        {
-            model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
-
-            model.e_conv_1_w     = ggml_new_tensor_3d(ctx, vtype,         3, n_mels,     n_audio_state);
-            model.e_conv_1_b     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,         1,     n_audio_state);
-
-            model.e_conv_2_w     = ggml_new_tensor_3d(ctx, vtype,         3, n_audio_state, n_audio_state);
-            model.e_conv_2_b     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,                1, n_audio_state);
-
-            model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-            model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-            // map by name
-            model.tensors["encoder.positional_embedding"] = model.e_pe;
-
-            model.tensors["encoder.conv1.weight"]         = model.e_conv_1_w;
-            model.tensors["encoder.conv1.bias"]           = model.e_conv_1_b;
-
-            model.tensors["encoder.conv2.weight"]         = model.e_conv_2_w;
-            model.tensors["encoder.conv2.bias"]           = model.e_conv_2_b;
+        model.e_pe = create_tensor(ASR_TENSOR_ENC_POS_EMBD, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx));
 
-            model.tensors["encoder.ln_post.weight"]       = model.e_ln_w;
-            model.tensors["encoder.ln_post.bias"]         = model.e_ln_b;
+        model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state));
+        model.e_conv_1_b = create_tensor(ASR_TENSOR_CONV1_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));
 
-            for (int i = 0; i < n_audio_layer; ++i) {
-                auto & layer = model.layers_encoder[i];
+        model.e_conv_2_w = create_tensor(ASR_TENSOR_CONV2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state));
+        model.e_conv_2_b = create_tensor(ASR_TENSOR_CONV2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));
 
-                layer.mlp_ln_w    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_audio_state);
-                layer.mlp_ln_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_audio_state);
+        model.e_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
+        model.e_ln_b = create_tensor(ASR_TENSOR_LN_POST_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
 
-                layer.mlp_0_w     = ggml_new_tensor_2d(ctx, wtype,           n_audio_state, 4*n_audio_state);
-                layer.mlp_0_b     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
+        for (int i = 0; i < n_audio_layer; ++i) {
+            auto & layer = model.layers_encoder[i];
 
-                layer.mlp_1_w     = ggml_new_tensor_2d(ctx, wtype,         4*n_audio_state, n_audio_state);
-                layer.mlp_1_b     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_audio_state);
+            layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+            layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_audio_state), i);
 
-                layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_audio_state);
-                layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_audio_state);
+            layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
+            layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state), i);
 
-                layer.attn_q_w    = ggml_new_tensor_2d(ctx, wtype,           n_audio_state, n_audio_state);
-                layer.attn_q_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_audio_state);
+            layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
+            layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_audio_state), i);
 
-                layer.attn_k_w    = ggml_new_tensor_2d(ctx, wtype,           n_audio_state, n_audio_state);
+            layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+            layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
 
-                layer.attn_v_w    = ggml_new_tensor_2d(ctx, wtype,           n_audio_state, n_audio_state);
-                layer.attn_v_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_audio_state);
+            layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+            layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
 
-                layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype,           n_audio_state, n_audio_state);
-                layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_audio_state);
+            layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
 
-                // map by name
-                model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"]     = layer.mlp_ln_w;
-                model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"]       = layer.mlp_ln_b;
+            layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+            layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
 
-                model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"]      = layer.mlp_0_w;
-                model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"]        = layer.mlp_0_b;
-
-                model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"]      = layer.mlp_1_w;
-                model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"]        = layer.mlp_1_b;
-
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"]    = layer.attn_ln_0_w;
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"]      = layer.attn_ln_0_b;
-
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"]   = layer.attn_q_b;
-
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"]   = layer.attn_k_w;
-
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"]   = layer.attn_v_b;
-
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"]   = layer.attn_ln_1_w;
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"]     = layer.attn_ln_1_b;
-            }
+            layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+            layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
         }
 
         // decoder
-        {
-            model.d_pe   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
-
-            model.d_te   = ggml_new_tensor_2d(ctx, wtype,         n_text_state, n_vocab);
-
-            model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-            model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-            // map by name
-            model.tensors["decoder.positional_embedding"]   = model.d_pe;
-
-            model.tensors["decoder.token_embedding.weight"] = model.d_te;
-
-            model.tensors["decoder.ln.weight"]              = model.d_ln_w;
-            model.tensors["decoder.ln.bias"]                = model.d_ln_b;
-
-            for (int i = 0; i < n_text_layer; ++i) {
-                auto & layer = model.layers_decoder[i];
-
-                layer.mlp_ln_w          = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_text_state);
-                layer.mlp_ln_b          = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_text_state);
-
-                layer.mlp_0_w           = ggml_new_tensor_2d(ctx, wtype,           n_text_state, 4*n_text_state);
-                layer.mlp_0_b           = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
-
-                layer.mlp_1_w           = ggml_new_tensor_2d(ctx, wtype,         4*n_text_state, n_text_state);
-                layer.mlp_1_b           = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_text_state);
-
-                layer.attn_ln_0_w       = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_text_state);
-                layer.attn_ln_0_b       = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_text_state);
-
-                layer.attn_q_w          = ggml_new_tensor_2d(ctx, wtype,           n_text_state, n_text_state);
-                layer.attn_q_b          = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_text_state);
-
-                layer.attn_k_w          = ggml_new_tensor_2d(ctx, wtype,           n_text_state, n_text_state);
-
-                layer.attn_v_w          = ggml_new_tensor_2d(ctx, wtype,           n_text_state, n_text_state);
-                layer.attn_v_b          = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_text_state);
+        model.d_pe = create_tensor(ASR_TENSOR_DEC_POS_EMBD, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx));
 
-                layer.attn_ln_1_w       = ggml_new_tensor_2d(ctx, wtype,           n_text_state, n_text_state);
-                layer.attn_ln_1_b       = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_text_state);
+        model.d_te = create_tensor(ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab));
 
-                layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_text_state);
-                layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_text_state);
+        model.d_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));
+        model.d_ln_b = create_tensor(ASR_TENSOR_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));
 
-                layer.cross_attn_q_w    = ggml_new_tensor_2d(ctx, wtype,           n_text_state, n_text_state);
-                layer.cross_attn_q_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_text_state);
+        for (int i = 0; i < n_text_layer; ++i) {
+            auto & layer = model.layers_decoder[i];
 
-                layer.cross_attn_k_w    = ggml_new_tensor_2d(ctx, wtype,           n_text_state, n_text_state);
+            layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+            layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-                layer.cross_attn_v_w    = ggml_new_tensor_2d(ctx, wtype,           n_text_state, n_text_state);
-                layer.cross_attn_v_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_text_state);
+            layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state), i);
+            layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state), i);
 
-                layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype,           n_text_state, n_text_state);
-                layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_text_state);
+            layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state), i);
+            layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-                // map by name
-                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"]           = layer.mlp_ln_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"]             = layer.mlp_ln_b;
+            layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+            layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"]            = layer.mlp_0_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"]              = layer.mlp_0_b;
+            layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+            layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"]            = layer.mlp_1_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"]              = layer.mlp_1_b;
+            layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
 
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"]          = layer.attn_ln_0_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"]            = layer.attn_ln_0_b;
+            layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+            layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"]       = layer.attn_q_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"]         = layer.attn_q_b;
+            layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+            layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"]         = layer.attn_k_w;
+            layer.cross_attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+            layer.cross_attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"]       = layer.attn_v_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"]         = layer.attn_v_b;
+            layer.cross_attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+            layer.cross_attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"]         = layer.attn_ln_1_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"]           = layer.attn_ln_1_b;
+            layer.cross_attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
 
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"]    = layer.cross_attn_ln_0_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"]      = layer.cross_attn_ln_0_b;
+            layer.cross_attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+            layer.cross_attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
 
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"]   = layer.cross_attn_q_b;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"]   = layer.cross_attn_k_w;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"]   = layer.cross_attn_v_b;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"]   = layer.cross_attn_ln_1_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"]     = layer.cross_attn_ln_1_b;
-            }
+            layer.cross_attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+            layer.cross_attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
         }
+
+        ggml_free(ctx);
     }
 
     // allocate tensors in the backend buffers
-    model.buffer = ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
-    if (!model.buffer) {
-        WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
-        return false;
-    }
+    for (auto & p : ctx_map) {
+        ggml_backend_buffer_type_t buft = p.first;
+        ggml_context * ctx = p.second;
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+        if (buf) {
+            model.buffers.emplace_back(buf);
 
-    size_t size_main = ggml_backend_buffer_get_size(model.buffer);
-    WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(model.buffer), size_main / 1e6);
+            size_t size_main = ggml_backend_buffer_get_size(buf);
+            WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
+        }
+    }
 
     // load weights
     {
@@ -1883,11 +1934,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 return false;
             }
 
-            //ggml_backend_t backend = wctx.backend;
-
-            //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
-
-            if (ggml_backend_buffer_is_host(model.buffer)) {
+            if (ggml_backend_buffer_is_host(tensor->buffer)) {
                 // for the CPU and Metal backend, we can read directly into the tensor
                 loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
                 BYTESWAP_TENSOR(tensor);
@@ -1900,7 +1947,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
             }
 
-            //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1e6);
             total_size += ggml_nbytes(tensor);
             model.n_loaded++;
         }
@@ -1915,7 +1961,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    ggml_backend_buffer_set_usage(model.buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+    for (auto & buf : model.buffers) {
+        ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+    }
 
     wctx.t_load_us = ggml_time_us() - t_start_us;
 
@@ -3806,9 +3854,13 @@ void whisper_free_state(struct whisper_state * state) {
 
 void whisper_free(struct whisper_context * ctx) {
     if (ctx) {
-        ggml_free(ctx->model.ctx);
+        for (ggml_context * context : ctx->model.ctxs) {
+            ggml_free(context);
+        }
 
-        ggml_backend_buffer_free(ctx->model.buffer);
+        for (ggml_backend_buffer_t buf : ctx->model.buffers) {
+            ggml_backend_buffer_free(buf);
+        }
 
         whisper_free_state(ctx->state);