]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : use ggml_backend_sched (#2239)
authorGeorgi Gerganov <redacted>
Tue, 18 Jun 2024 06:37:20 +0000 (09:37 +0300)
committerGeorgi Gerganov <redacted>
Tue, 18 Jun 2024 06:39:40 +0000 (09:39 +0300)
* whisper : use ggml_backend_sched (wip)

* use sched in whisper_allocr

* whisper : single backend in whisper_context

* whisper : remove whisper_state->backends_used

* whisper : remove whisper_context->backend

* whisper : reset scheduler after init

* whisper : fix external encoder (e.g. CoreML)

* whisper : cleanup

* whisper : handle null GPU buffer types + fix sycl

---------

Co-authored-by: slaren <redacted>
ggml-backend.c
ggml-backend.h
whisper.cpp

index 2bec7bea38a85e10059551098a4779596c1fe0bb..174297949504776493bd229918cd21d30bed5699 100644 (file)
@@ -1706,14 +1706,16 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
     bool backend_ids_changed = false;
     for (int i = 0; i < sched->graph->n_nodes; i++) {
-        if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i]) {
+        if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i] &&
+            sched->bufts[sched->node_backend_ids[i]] != sched->bufts[sched->prev_node_backend_ids[i]]) {
             backend_ids_changed = true;
             break;
         }
     }
     if (!backend_ids_changed) {
         for (int i = 0; i < sched->graph->n_leafs; i++) {
-            if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i]) {
+            if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i] &&
+                sched->bufts[sched->leaf_backend_ids[i]] != sched->bufts[sched->prev_leaf_backend_ids[i]]) {
                 backend_ids_changed = true;
                 break;
             }
@@ -1977,6 +1979,15 @@ int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {
     return sched->n_copies;
 }
 
+int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) {
+    return sched->n_backends;
+}
+
+ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) {
+    GGML_ASSERT(i >= 0 && i < sched->n_backends);
+    return sched->backends[i];
+}
+
 size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
     int backend_index = ggml_backend_sched_backend_id(sched, backend);
     GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
index 47fd8147517954088e9952eefc9f2ecf0d21fcd3..4a38eeb5c23bde451dc4b79ad9e0e8e7eb194ae8 100644 (file)
@@ -182,6 +182,9 @@ extern "C" {
     // Initialize backend buffers from a measure graph
     GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
 
+    GGML_API int                  ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
+    GGML_API ggml_backend_t       ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
+
     // Get the number of splits of the last graph
     GGML_API int                  ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
     GGML_API int                  ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);
index 0a53a03d03d8dd5f28d0702b93dcbd833bdd6996..4b96e8bcb6615289de30b192811339aa6ee05827 100644 (file)
 #include "ggml-sycl.h"
 #endif
 
+#ifdef GGML_USE_BLAS
+#include "ggml-blas.h"
+#endif
+
 #ifdef WHISPER_USE_OPENVINO
 #include "openvino/whisper-openvino-encoder.h"
 #endif
@@ -179,18 +183,30 @@ static bool ggml_graph_compute_helper(
 }
 
 static bool ggml_graph_compute_helper(
-       struct ggml_backend * backend,
+      ggml_backend_sched_t   sched,
         struct ggml_cgraph * graph,
                        int   n_threads) {
-    if (ggml_backend_is_cpu(backend)) {
-        ggml_backend_cpu_set_n_threads(backend, n_threads);
-    }
+
+    for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
+        ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
+        if (ggml_backend_is_cpu(backend)) {
+            ggml_backend_cpu_set_n_threads(backend, n_threads);
+        }
+#ifdef GGML_USE_BLAS
+        if (ggml_backend_is_blas(backend)) {
+            ggml_backend_blas_set_n_threads(backend, n_threads);
+        }
+#endif
 #ifdef GGML_USE_METAL
-    if (ggml_backend_is_metal(backend)) {
-        ggml_backend_metal_set_n_cb(backend, n_threads);
-    }
+        if (ggml_backend_is_metal(backend)) {
+            ggml_backend_metal_set_n_cb(backend, n_threads);
+        }
 #endif
-    return ggml_backend_graph_compute(backend, graph) == GGML_STATUS_SUCCESS;
+    }
+
+    bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
+    ggml_backend_sched_reset(sched);
+    return t;
 }
 
 // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
@@ -490,33 +506,41 @@ struct whisper_pair {
     whisper_pair() : first(A()), second(B()) {}
 };
 
-// ggml_allocr wrapper for whisper usage
-struct whisper_allocr {
-    ggml_gallocr_t alloc = nullptr;
+// ggml_backend_sched wrapper for whisper usage
+struct whisper_sched {
+    ggml_backend_sched_t sched = nullptr;
 
     std::vector<uint8_t> meta;
 };
 
-static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-    return allocr.meta.size() + ggml_gallocr_get_buffer_size(allocr.alloc, 0);
+static size_t whisper_sched_size(struct whisper_sched & allocr) {
+    size_t size = allocr.meta.size();
+    for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
+        ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i);
+        size += ggml_backend_sched_get_buffer_size(allocr.sched, backend);
+    }
+    return size;
 }
 
 // measure the memory usage of a graph and prepare the allocr's internal data buffer
-static bool whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
-    auto & alloc = allocr.alloc;
+static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) {
+    auto & sched = allocr.sched;
     auto & meta  = allocr.meta;
 
-    alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
+    sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
 
     meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
 
     // since there are dependencies between the different graphs,
     // we need to allocate them instead of only reserving to get the correct compute buffer size
-    if (!ggml_gallocr_alloc_graph(alloc, get_graph())) {
+    if (!ggml_backend_sched_alloc_graph(sched, get_graph())) {
         // failed to allocate the compute buffer
         WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
         return false;
     }
+
+    ggml_backend_sched_reset(sched);
+
     return true;
 }
 
@@ -808,15 +832,13 @@ struct whisper_state {
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
-    ggml_backend_t backend = nullptr;
+    std::vector<ggml_backend_t> backends;
 
-    // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
-    // - stores the actual tensor data into the `data` buffers
-    whisper_allocr alloc_conv;
-    whisper_allocr alloc_encode;
-    whisper_allocr alloc_cross;
-    whisper_allocr alloc_decode;
+    whisper_sched sched_conv;
+    whisper_sched sched_encode;
+    whisper_sched sched_cross;
+    whisper_sched sched_decode;
 
     // result of the encoder
     struct ggml_tensor * embd_conv = nullptr;
@@ -874,8 +896,6 @@ struct whisper_context {
 
     whisper_state * state = nullptr;
 
-    ggml_backend_t backend = nullptr;
-
     std::string path_model; // populated by whisper_init_from_file_with_params()
 };
 
@@ -1061,20 +1081,16 @@ static void whisper_kv_cache_seq_cp(
 }
 
 static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
-    if (!wctx.params.flash_attn) {
+    if (!wctx.params.flash_attn || !wctx.params.use_gpu) {
         return 1u;
     }
 
 #ifdef GGML_USE_METAL
-    if (ggml_backend_is_metal(wctx.backend)) {
-        return 32u;
-    }
+    return 32u;
 #endif
 
 #ifdef GGML_USE_CUDA
-    if (ggml_backend_is_cuda(wctx.backend)) {
-        return 256u;
-    }
+    return 256u;
 #endif
 
     return 1u;
@@ -1211,15 +1227,14 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
     return size;
 }
 
-static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
-    ggml_backend_t backend_gpu = NULL;
+static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
+    ggml_backend_t result = NULL;
 
-    // initialize the backends
 #ifdef GGML_USE_CUDA
     if (params.use_gpu) {
         WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
-        backend_gpu = ggml_backend_cuda_init(params.gpu_device);
-        if (!backend_gpu) {
+        result = ggml_backend_cuda_init(params.gpu_device);
+        if (!result) {
             WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
         }
     }
@@ -1229,13 +1244,13 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
     if (params.use_gpu) {
         WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
         ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-        backend_gpu = ggml_backend_metal_init();
-        if (!backend_gpu) {
+        result = ggml_backend_metal_init();
+        if (!result) {
             WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
-        } else if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
+        } else if (!ggml_backend_metal_supports_family(result, 7)) {
             WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
-            ggml_backend_free(backend_gpu);
-            backend_gpu = NULL;
+            ggml_backend_free(result);
+            result = NULL;
         }
     }
 #endif
@@ -1243,20 +1258,64 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
 #ifdef GGML_USE_SYCL
     if (params.use_gpu) {
         WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__);
-        backend_gpu = ggml_backend_sycl_init(params.gpu_device);
-        if (!backend_gpu) {
+        result = ggml_backend_sycl_init(params.gpu_device);
+        if (!result) {
             WHISPER_LOG_ERROR("%s: ggml_backend_sycl_init() failed\n", __func__);
         }
     }
 #endif
 
-    GGML_UNUSED(params);
+    return result;
+}
+
+static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
+    std::vector<ggml_backend_t> result;
+
+    ggml_backend_t backend_gpu = whisper_backend_init_gpu(params);
 
     if (backend_gpu) {
-        return backend_gpu;
+        result.push_back(backend_gpu);
+    }
+
+#ifdef GGML_USE_BLAS
+    {
+        WHISPER_LOG_INFO("%s: using BLAS backend\n", __func__);
+        ggml_backend_t backend_blas = ggml_backend_blas_init();
+        if (!backend_blas) {
+            WHISPER_LOG_ERROR("%s: ggml_backend_blas_init() failed\n", __func__);
+        } else {
+            result.push_back(backend_blas);
+        }
     }
+#endif
+
+    GGML_UNUSED(params);
+
+    result.push_back(ggml_backend_cpu_init());
+
+    return result;
+}
+
+static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
+    ggml_backend_buffer_type_t result = nullptr;
 
-    return ggml_backend_cpu_init();
+    params.use_gpu || (result = ggml_backend_cpu_buffer_type());
+
+#ifdef GGML_USE_CUDA
+    result || (result = ggml_backend_cuda_buffer_type(params.gpu_device));
+#endif
+
+#ifdef GGML_USE_METAL
+    result || (result = ggml_backend_metal_buffer_type());
+#endif
+
+#ifdef GGML_USE_SYCL
+    result || (result = ggml_backend_sycl_buffer_type(params.gpu_device));
+#endif
+
+    result || (result = ggml_backend_cpu_buffer_type());
+
+    return result;
 }
 
 // load the model from a ggml file
@@ -1683,21 +1742,15 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    wctx.backend = whisper_backend_init(wctx.params);
-    if (!wctx.backend) {
-        WHISPER_LOG_ERROR("%s: failed to initialize the backend\n", __func__);
-        return false;
-    }
-
     // allocate tensors in the backend buffers
-    model.buffer = ggml_backend_alloc_ctx_tensors(model.ctx, wctx.backend);
+    model.buffer = ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
     if (!model.buffer) {
         WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
         return false;
     }
 
     size_t size_main = ggml_backend_buffer_get_size(model.buffer);
-    WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1e6);
+    WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(model.buffer), size_main / 1e6);
 
     // load weights
     {
@@ -1792,6 +1845,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
+    ggml_backend_buffer_set_usage(model.buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+
     wctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
@@ -1828,8 +1883,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
     const int n_mels = hparams.n_mels;
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.alloc_conv.meta.size(),
-        /*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
+        /*.mem_size   =*/ wstate.sched_conv.meta.size(),
+        /*.mem_buffer =*/ wstate.sched_conv.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -1837,9 +1892,13 @@ static struct ggml_cgraph * whisper_build_graph_conv(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
+    GGML_ASSERT(wstate.mel.tensor);
+
     ggml_tensor * mel_inp = wstate.mel.tensor;
+    ggml_set_input(mel_inp);
+
     ggml_tensor * mel;
-    if (mel_inp) {
+    {
         const int n_len = int(mel_inp->ne[0]);
         const int out_s = 2 * n_ctx;
         const int i0 = std::min(mel_offset, n_len);
@@ -1853,16 +1912,12 @@ static struct ggml_cgraph * whisper_build_graph_conv(
 
         if (mel_s < out_s) {
             mel = ggml_pad(ctx0, cur, out_s - mel_s, 0, 0, 0);
-        }
-        else {
+        } else {
             mel = ggml_cont(ctx0, cur);
         }
     }
-    else {
-        // just create some tensor so that the graph/buffer size estimation is correct
-        mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels);
-    }
-    ggml_set_name(mel, "mel"); // used with external encoding
+
+    ggml_set_name(mel, "mel");
 
     struct ggml_tensor * cur = nullptr;
 
@@ -1886,6 +1941,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         ggml_build_forward_expand(gf, mel);
 
         cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+        ggml_set_input(cur); // the external encoder will write into this tensor
 
         ggml_set_name(cur, "embd_enc");
         wstate.embd_enc = cur;
@@ -1920,8 +1976,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const int n_ctx_pad = GGML_PAD(n_ctx, 256);
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.alloc_encode.meta.size(),
-        /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
+        /*.mem_size   =*/ wstate.sched_encode.meta.size(),
+        /*.mem_buffer =*/ wstate.sched_encode.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -2160,8 +2216,8 @@ static struct ggml_cgraph * whisper_build_graph_cross(
     const int n_ctx_pad = GGML_PAD(n_ctx, 256);
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.alloc_cross.meta.size(),
-        /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
+        /*.mem_size   =*/ wstate.sched_cross.meta.size(),
+        /*.mem_buffer =*/ wstate.sched_cross.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -2242,16 +2298,16 @@ static bool whisper_encode_internal(
 
     // conv
     {
-        auto & alloc = wstate.alloc_conv.alloc;
+        auto & sched = wstate.sched_conv.sched;
 
         ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
 
-        if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
             return false;
         }
 
@@ -2269,32 +2325,32 @@ static bool whisper_encode_internal(
 
     // encoder
     if (!whisper_encode_external(wstate)) {
-        auto & alloc = wstate.alloc_encode.alloc;
+        auto & sched = wstate.sched_encode.sched;
 
         ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
 
-        if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
             return false;
         }
     }
 
     // cross
     {
-        auto & alloc = wstate.alloc_cross.alloc;
+        auto & sched = wstate.sched_cross.sched;
 
         ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
 
-        if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
             return false;
         }
     }
@@ -2336,8 +2392,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.alloc_decode.meta.size(),
-        /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
+        /*.mem_size   =*/ wstate.sched_decode.meta.size(),
+        /*.mem_buffer =*/ wstate.sched_decode.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -2736,11 +2792,11 @@ static bool whisper_decode_internal(
 
     // decoder
     {
-        auto & alloc = wstate.alloc_decode.alloc;
+        auto & sched = wstate.sched_decode.sched;
 
         ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
 
-        if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }
@@ -2795,7 +2851,7 @@ static bool whisper_decode_internal(
 
         logits = gf->nodes[gf->n_nodes - 1];
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
             return false;
         }
     }
@@ -3299,20 +3355,29 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
 struct whisper_state * whisper_init_state(whisper_context * ctx) {
     whisper_state * state = new whisper_state;
 
-    state->backend = whisper_backend_init(ctx->params);
-    if (!state->backend) {
+    state->backends = whisper_backend_init(ctx->params);
+    if (state->backends.empty()) {
         WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
 
-    state->mel_calc = whisper_mel_calc_create(state->backend, ctx->model.filters);
+    state->mel_calc = whisper_mel_calc_create(state->backends[0], ctx->model.filters);
+
+    // init 60s of random mel data
+    {
+        const int n_len = 2*100*WHISPER_CHUNK_SIZE;
+        const int n_mel = ctx->model.filters.n_mel;
+
+        whisper_mel_free(state->mel);
+        whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
+    }
 
     // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
     // in theory, there can be a case where this is not enough, but in practice it should always be enough
     const int factor = 3;
 
-    if (!whisper_kv_cache_init(state->kv_self, state->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
                 ctx->model.hparams.n_text_state,
                 ctx->model.hparams.n_text_layer,
                 GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
@@ -3326,7 +3391,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
-    if (!whisper_kv_cache_init(state->kv_cross, state->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype,
                 ctx->model.hparams.n_text_state,
                 ctx->model.hparams.n_text_layer,
                 GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
@@ -3340,7 +3405,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
-    if (!whisper_kv_cache_init(state->kv_pad, state->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype,
                 ctx->model.hparams.n_audio_state,
                 1,
                 GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
@@ -3356,7 +3421,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // [EXPERIMENTAL] Token-level timestamps with DTW
     if (ctx->params.dtw_token_timestamps) {
-        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backend)) {
+        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) {
             WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
             whisper_free_state(state);
             return nullptr;
@@ -3399,7 +3464,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // conv allocator
     {
-        bool ok = whisper_allocr_graph_init(state->alloc_conv, state->backend,
+        bool ok = whisper_sched_graph_init(state->sched_conv, state->backends,
                 [&]() {
                     return whisper_build_graph_conv(*ctx, *state, 0);
                 });
@@ -3410,12 +3475,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }
 
-        WHISPER_LOG_INFO("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6);
+        WHISPER_LOG_INFO("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6);
     }
 
     // encoder allocator
     if (!whisper_encode_external(*state)) {
-        bool ok = whisper_allocr_graph_init(state->alloc_encode, state->backend,
+        bool ok = whisper_sched_graph_init(state->sched_encode, state->backends,
                 [&]() {
                     return whisper_build_graph_encoder(*ctx, *state);
                 });
@@ -3426,12 +3491,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }
 
-        WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6);
+        WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6);
     }
 
     // cross allocator
     {
-        bool ok = whisper_allocr_graph_init(state->alloc_cross, state->backend,
+        bool ok = whisper_sched_graph_init(state->sched_cross, state->backends,
                 [&]() {
                     return whisper_build_graph_cross(*ctx, *state);
                 });
@@ -3442,12 +3507,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }
 
-        WHISPER_LOG_INFO("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6);
+        WHISPER_LOG_INFO("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_sched_size(state->sched_cross) / 1e6);
     }
 
     // decoder allocator
     {
-        bool ok = whisper_allocr_graph_init(state->alloc_decode, state->backend,
+        bool ok = whisper_sched_graph_init(state->sched_decode, state->backends,
                 [&]() {
                     const auto & hparams = ctx->model.hparams;
 
@@ -3466,7 +3531,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }
 
-        WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6);
+        WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_decode) / 1e6);
     }
 
     return state;
@@ -3746,12 +3811,14 @@ void whisper_free_state(struct whisper_state * state) {
 
         whisper_batch_free(state->batch);
 
-        ggml_gallocr_free(state->alloc_conv.alloc);
-        ggml_gallocr_free(state->alloc_encode.alloc);
-        ggml_gallocr_free(state->alloc_cross.alloc);
-        ggml_gallocr_free(state->alloc_decode.alloc);
+        ggml_backend_sched_free(state->sched_conv.sched);
+        ggml_backend_sched_free(state->sched_encode.sched);
+        ggml_backend_sched_free(state->sched_cross.sched);
+        ggml_backend_sched_free(state->sched_decode.sched);
 
-        ggml_backend_free(state->backend);
+        for (auto & backend : state->backends) {
+            ggml_backend_free(backend);
+        }
 
         // [EXPERIMENTAL] Token-level timestamps with DTW
         aheads_masks_free(state->aheads_masks);
@@ -3768,8 +3835,6 @@ void whisper_free(struct whisper_context * ctx) {
 
         whisper_free_state(ctx->state);
 
-        ggml_backend_free(ctx->backend);
-
         delete ctx;
     }
 }
@@ -3800,7 +3865,7 @@ int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_s
         // 2. the time to transcribe audios this long will be dominated by the decoding time, so the mel calculation
         //    taking longer is not a major concern
         if (!state->mel_calc_fallback) {
-            state->mel_calc_fallback = new mel_calc_cpu(state->backend, ctx->model.filters);
+            state->mel_calc_fallback = new mel_calc_cpu(state->backends[0], ctx->model.filters);
         }
         state->mel = state->mel_calc_fallback->calculate({samples, n_samples}, n_threads);
     }
@@ -3837,7 +3902,7 @@ int whisper_set_mel_with_state(
     }
 
     whisper_mel_free(state->mel);
-    whisper_mel_init(state->mel, ctx->backend, n_len, n_len, n_mel);
+    whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
 
     ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor));