]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add pipeline parallelism support (#6017)
authorslaren <redacted>
Wed, 13 Mar 2024 17:54:21 +0000 (18:54 +0100)
committerGitHub <redacted>
Wed, 13 Mar 2024 17:54:21 +0000 (18:54 +0100)
* llama : add pipeline parallelism support for batch processing with multiple CUDA GPUs

ggml-ci

* server : add -ub, --ubatch-size parameter

* fix server embedding test

* llama : fix Mamba inference for pipeline parallelism

Tested to work correctly with both `main` and `parallel` examples.

* llama : limit max batch size to n_batch

* add LLAMA_SCHED_MAX_COPIES to configure the number of input copies for pipeline parallelism
default increase to 4 (from 2)

changing this value may improve performance for some systems, but increases memory usage

* fix hip build

* fix sycl build (disable cpy_tensor_async)

* fix hip build

* llama : limit n_batch and n_ubatch to n_ctx during context creation

* llama : fix norm backend

* batched-bench : sync after decode

* swiftui : sync after decode

* ggml : allow ggml_get_rows to use multiple threads if they are available

* check n_ubatch >= n_tokens with non-casual attention

* llama : do not limit n_batch to n_ctx with non-casual attn

* server : construct batch with size of llama_n_batch

* ggml_backend_cpu_graph_compute : fix return value when alloc fails

* llama : better n_batch and n_ubatch comment

* fix merge

* small fix

* reduce default n_batch to 2048

---------

Co-authored-by: Francis Couture-Harpin <redacted>
Co-authored-by: Georgi Gerganov <redacted>
25 files changed:
CMakeLists.txt
Makefile
common/common.cpp
common/common.h
examples/batched-bench/batched-bench.cpp
examples/embedding/embedding.cpp
examples/llama-bench/llama-bench.cpp
examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
examples/perplexity/perplexity.cpp
examples/server/server.cpp
examples/server/tests/features/embeddings.feature
examples/server/tests/features/steps/steps.py
ggml-alloc.c
ggml-alloc.h
ggml-backend-impl.h
ggml-backend.c
ggml-backend.h
ggml-cuda.cu
ggml-kompute.cpp
ggml-metal.m
ggml-sycl.cpp
ggml-vulkan.cpp
ggml.c
llama.cpp
llama.h

index 7ab13cbd56e8a6ff361b089364b50d3516039e28..a8abf4088a01c673f9d362855750957e864cf35e 100644 (file)
@@ -118,6 +118,7 @@ option(LLAMA_SYCL                            "llama: use SYCL"
 option(LLAMA_SYCL_F16                        "llama: use 16 bit floats for sycl calculations"   OFF)
 set(LLAMA_SYCL_TARGET   "INTEL" CACHE STRING "llama: sycl target device")
 option(LLAMA_CPU_HBM                         "llama: use memkind for CPU HBM"                   OFF)
+set(LLAMA_SCHED_MAX_COPIES  "4" CACHE STRING "llama: max input copies for pipeline parallelism")
 
 option(LLAMA_BUILD_TESTS                     "llama: build tests"    ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES                  "llama: build examples" ${LLAMA_STANDALONE})
@@ -147,6 +148,8 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)
 find_package(Threads REQUIRED)
 include(CheckCXXCompilerFlag)
 
+add_compile_definitions(GGML_SCHED_MAX_COPIES=${LLAMA_SCHED_MAX_COPIES})
+
 # enable libstdc++ assertions for debug builds
 if (CMAKE_SYSTEM_NAME MATCHES "Linux")
     add_compile_definitions($<$<CONFIG:Debug>:_GLIBCXX_ASSERTIONS>)
index c8fd3f5c5026a49deb80c8a591d274fa5e60b474..db9968efbddd8aa687e29deb1ebe024dc719568b 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -167,6 +167,10 @@ ifeq ($(UNAME_S),OpenBSD)
        MK_CPPFLAGS += -D_BSD_SOURCE
 endif
 
+ifdef LLAMA_SCHED_MAX_COPIES
+       MK_CPPFLAGS += -DGGML_SCHED_MAX_COPIES=$(LLAMA_SCHED_MAX_COPIES)
+endif
+
 ifdef LLAMA_DEBUG
        MK_CFLAGS   += -O0 -g
        MK_CXXFLAGS += -O0 -g
index 2f38ac632b45a91c9444bb56ae3030c7a6261a80..73b1b61ba1b74f28ecb7ed2d5b244baa7243db9c 100644 (file)
@@ -483,6 +483,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_batch = std::stoi(argv[i]);
+        } else if (arg == "-ub" || arg == "--ubatch-size") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_ubatch = std::stoi(argv[i]);
         } else if (arg == "--keep") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -977,7 +983,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        binary file containing multiple choice tasks.\n");
     printf("  -n N, --n-predict N   number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
     printf("  -c N, --ctx-size N    size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
-    printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    printf("  -b N, --batch-size N  logical maximum batch size (default: %d)\n", params.n_batch);
+    printf("  -ub N, --ubatch-size N\n");
+    printf("                        physical maximum batch size (default: %d)\n", params.n_ubatch);
     printf("  --samplers            samplers that will be used for generation in the order, separated by \';\'\n");
     printf("                        (default: %s)\n", sampler_type_names.c_str());
     printf("  --sampling-seq        simplified sequence for samplers that will be used (default: %s)\n", sampler_type_chars.c_str());
@@ -1287,8 +1295,9 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     auto cparams = llama_context_default_params();
 
     cparams.n_ctx             = params.n_ctx;
-    cparams.n_batch           = params.n_batch;
     cparams.n_seq_max         = params.n_parallel;
+    cparams.n_batch           = params.n_batch;
+    cparams.n_ubatch          = params.n_ubatch;
     cparams.n_threads         = params.n_threads;
     cparams.n_threads_batch   = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.seed              = params.seed;
@@ -1379,6 +1388,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_clear(lctx);
+        llama_synchronize(lctx);
         llama_reset_timings(lctx);
     }
 
index f8d82b8713c8715329e90333c61040f4adb327bb..0f178b9eb1de37bb4ff6a104839a6767c43b9d39 100644 (file)
@@ -51,7 +51,8 @@ struct gpt_params {
     int32_t n_threads_batch_draft = -1;
     int32_t n_predict             = -1;    // new tokens to predict
     int32_t n_ctx                 = 512;   // context size
-    int32_t n_batch               = 512;   // batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_batch               = 2048;  // logical batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_ubatch              = 512;   // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep                = 0;     // number of tokens to keep from initial prompt
     int32_t n_draft               = 5;     // number of tokens to draft during speculative decoding
     int32_t n_chunks              = -1;    // max number of chunks to process (-1 = unlimited)
index 22bc93bca817cb5321eb1a24929e98c56a1b5059..19674dfd3670839756a5fca859c30af9a5712c05 100644 (file)
@@ -138,6 +138,8 @@ int main(int argc, char ** argv) {
                 LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
                 return false;
             }
+
+            llama_synchronize(ctx);
         }
 
         return true;
index a553ae1c3f35d39314784084aeb25be67abfa838..49302a199977ed35880de6e4fd97b8db0cfdffa8 100644 (file)
@@ -107,7 +107,7 @@ int main(int argc, char ** argv) {
 
     // max batch size
     const uint64_t n_batch = params.n_batch;
-    GGML_ASSERT(params.n_batch == params.n_ctx);
+    GGML_ASSERT(params.n_batch >= params.n_ctx);
 
     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;
index 2ff86ef6f11468e6ce441930e7d4d864ddadc965..bf94e7e7a60a4879492c0ca719efb430fa5fe2db 100644 (file)
@@ -164,6 +164,7 @@ struct cmd_params {
     std::vector<int> n_prompt;
     std::vector<int> n_gen;
     std::vector<int> n_batch;
+    std::vector<int> n_ubatch;
     std::vector<ggml_type> type_k;
     std::vector<ggml_type> type_v;
     std::vector<int> n_threads;
@@ -183,7 +184,8 @@ static const cmd_params cmd_params_defaults = {
     /* model         */ {"models/7B/ggml-model-q4_0.gguf"},
     /* n_prompt      */ {512},
     /* n_gen         */ {128},
-    /* n_batch       */ {512},
+    /* n_batch       */ {2048},
+    /* n_ubatch      */ {512},
     /* type_k        */ {GGML_TYPE_F16},
     /* type_v        */ {GGML_TYPE_F16},
     /* n_threads     */ {get_num_physical_cores()},
@@ -208,6 +210,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -p, --n-prompt <n>                  (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
     printf("  -n, --n-gen <n>                     (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
     printf("  -b, --batch-size <n>                (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
+    printf("  -ub N, --ubatch-size <n>            (default: %s)\n", join(cmd_params_defaults.n_ubatch, ",").c_str());
     printf("  -ctk <t>, --cache-type-k <t>        (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
     printf("  -ctv <t>, --cache-type-v <t>        (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
     printf("  -t, --threads <n>                   (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
@@ -217,7 +220,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -nkvo, --no-kv-offload <0|1>        (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
     printf("  -mmp, --mmap <0|1>                  (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
     printf("  -embd, --embeddings <0|1>           (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
-    printf("  -ts, --tensor_split <ts0/ts1/..>    (default: 0)\n");
+    printf("  -ts, --tensor-split <ts0/ts1/..>    (default: 0)\n");
     printf("  -r, --repetitions <n>               (default: %d)\n", cmd_params_defaults.reps);
     printf("  -o, --output <csv|json|md|sql>      (default: %s)\n", output_format_str(cmd_params_defaults.output_format));
     printf("  -v, --verbose                       (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
@@ -297,6 +300,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<int>(argv[i], split_delim);
             params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
+        } else if (arg == "-ub" || arg == "--ubatch-size") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = split<int>(argv[i], split_delim);
+            params.n_ubatch.insert(params.n_ubatch.end(), p.begin(), p.end());
         } else if (arg == "-ctk" || arg == "--cache-type-k") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -455,6 +465,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.n_prompt.empty())     { params.n_prompt = cmd_params_defaults.n_prompt; }
     if (params.n_gen.empty())        { params.n_gen = cmd_params_defaults.n_gen; }
     if (params.n_batch.empty())      { params.n_batch = cmd_params_defaults.n_batch; }
+    if (params.n_ubatch.empty())     { params.n_ubatch = cmd_params_defaults.n_ubatch; }
     if (params.type_k.empty())       { params.type_k = cmd_params_defaults.type_k; }
     if (params.type_v.empty())       { params.type_v = cmd_params_defaults.type_v; }
     if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
@@ -474,6 +485,7 @@ struct cmd_params_instance {
     int n_prompt;
     int n_gen;
     int n_batch;
+    int n_ubatch;
     ggml_type type_k;
     ggml_type type_v;
     int n_threads;
@@ -511,6 +523,7 @@ struct cmd_params_instance {
 
         cparams.n_ctx = n_prompt + n_gen;
         cparams.n_batch = n_batch;
+        cparams.n_ubatch = n_ubatch;
         cparams.type_k = type_k;
         cparams.type_v = type_v;
         cparams.offload_kqv = !no_kv_offload;
@@ -532,6 +545,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & mmp : params.use_mmap)
     for (const auto & embd : params.embeddings)
     for (const auto & nb : params.n_batch)
+    for (const auto & nub : params.n_ubatch)
     for (const auto & tk : params.type_k)
     for (const auto & tv : params.type_v)
     for (const auto & nkvo : params.no_kv_offload)
@@ -545,6 +559,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .n_prompt     = */ n_prompt,
                 /* .n_gen        = */ 0,
                 /* .n_batch      = */ nb,
+                /* .n_ubatch     = */ nub,
                 /* .type_k       = */ tk,
                 /* .type_v       = */ tv,
                 /* .n_threads    = */ nt,
@@ -568,6 +583,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .n_prompt     = */ 0,
                 /* .n_gen        = */ n_gen,
                 /* .n_batch      = */ nb,
+                /* .n_ubatch     = */ nub,
                 /* .type_k       = */ tk,
                 /* .type_v       = */ tv,
                 /* .n_threads    = */ nt,
@@ -604,6 +620,7 @@ struct test {
     uint64_t model_size;
     uint64_t model_n_params;
     int n_batch;
+    int n_ubatch;
     int n_threads;
     ggml_type type_k;
     ggml_type type_v;
@@ -627,6 +644,7 @@ struct test {
         model_size = llama_model_size(lmodel);
         model_n_params = llama_model_n_params(lmodel);
         n_batch = inst.n_batch;
+        n_ubatch = inst.n_ubatch;
         n_threads = inst.n_threads;
         type_k = inst.type_k;
         type_v = inst.type_v;
@@ -705,7 +723,8 @@ struct test {
             "cuda", "opencl", "vulkan", "kompute", "metal", "sycl", "gpu_blas", "blas",
             "cpu_info", "gpu_info",
             "model_filename", "model_type", "model_size", "model_n_params",
-            "n_batch", "n_threads", "type_k", "type_v",
+            "n_batch", "n_ubatch",
+            "n_threads", "type_k", "type_v",
             "n_gpu_layers", "split_mode",
             "main_gpu", "no_kv_offload",
             "tensor_split", "use_mmap", "embeddings",
@@ -719,7 +738,8 @@ struct test {
     enum field_type {STRING, BOOL, INT, FLOAT};
 
     static field_type get_field_type(const std::string & field) {
-        if (field == "build_number" || field == "n_batch" || field == "n_threads" ||
+        if (field == "build_number" || field == "n_batch" || field == "n_ubatch" ||
+            field == "n_threads" ||
             field == "model_size" || field == "model_n_params" ||
             field == "n_gpu_layers" || field == "main_gpu" ||
             field == "n_prompt" || field == "n_gen" ||
@@ -759,7 +779,8 @@ struct test {
             std::to_string(metal), std::to_string(sycl), std::to_string(gpu_blas), std::to_string(blas),
             cpu_info, gpu_info,
             model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
-            std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
+            std::to_string(n_batch), std::to_string(n_ubatch),
+            std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
             std::to_string(n_gpu_layers), split_mode_str(split_mode),
             std::to_string(main_gpu), std::to_string(no_kv_offload),
             tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
@@ -957,6 +978,9 @@ struct markdown_printer : public printer {
         if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) {
             fields.emplace_back("n_batch");
         }
+        if (params.n_ubatch.size() > 1 || params.n_ubatch != cmd_params_defaults.n_ubatch) {
+            fields.emplace_back("n_ubatch");
+        }
         if (params.type_k.size() > 1 || params.type_k != cmd_params_defaults.type_k) {
             fields.emplace_back("type_k");
         }
@@ -1096,25 +1120,32 @@ struct sql_printer : public printer {
 };
 
 static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
+    llama_set_n_threads(ctx, n_threads, n_threads);
+
+    //std::vector<llama_token> tokens(n_prompt, llama_token_bos(llama_get_model(ctx)));
+    //llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt, n_past, 0));
+    //GGML_UNUSED(n_batch);
+
     std::vector<llama_token> tokens(n_batch, llama_token_bos(llama_get_model(ctx)));
     int n_processed = 0;
 
-    llama_set_n_threads(ctx, n_threads, n_threads);
-
     while (n_processed < n_prompt) {
         int n_tokens = std::min(n_prompt - n_processed, n_batch);
         llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
         n_processed += n_tokens;
     }
+
+    llama_synchronize(ctx);
 }
 
 static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
-    llama_token token = llama_token_bos(llama_get_model(ctx));
-
     llama_set_n_threads(ctx, n_threads, n_threads);
 
+    llama_token token = llama_token_bos(llama_get_model(ctx));
+
     for (int i = 0; i < n_gen; i++) {
         llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
+        llama_synchronize(ctx);
     }
 }
 
@@ -1203,7 +1234,8 @@ int main(int argc, char ** argv) {
 
         // warmup run
         if (t.n_prompt > 0) {
-            test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
+            //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
+            test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
         }
         if (t.n_gen > 0) {
             test_gen(ctx, 1, 0, t.n_threads);
@@ -1219,6 +1251,7 @@ int main(int argc, char ** argv) {
             if (t.n_gen > 0) {
                 test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
             }
+
             uint64_t t_ns = get_time_ns() - t_start;
             t.samples_ns.push_back(t_ns);
         }
index 58fcf40c6fb6954ff96c2bdaa264fc6f8a557b55..c249291aea1107183889fba42b5fd31e9b8ff1a1 100644 (file)
@@ -221,6 +221,7 @@ actor LlamaContext {
             if llama_decode(context, batch) != 0 {
                 print("llama_decode() failed during prompt")
             }
+            llama_synchronize(context)
 
             let t_pp_end = ggml_time_us()
 
@@ -240,6 +241,7 @@ actor LlamaContext {
                 if llama_decode(context, batch) != 0 {
                     print("llama_decode() failed during text generation")
                 }
+                llama_synchronize(context)
             }
 
             let t_tg_end = ggml_time_us()
index fdfc8f5dcd3e5e66d3aa50062e19edd061e95d7a..d766aef6ac1b1a69d998691cdded4b3ac88a73e3 100644 (file)
@@ -589,9 +589,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             }
         }
 
-        const auto t_end = std::chrono::high_resolution_clock::now();
 
         if (i == 0) {
+            llama_synchronize(ctx);
+            const auto t_end = std::chrono::high_resolution_clock::now();
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
             int total_seconds = (int)(t_total*n_chunk/n_seq);
index 3172d96dd03539e38019d5b7898671ede4e6f28b..895d608fdcc06b4cc7fad28bc8f0c41da5520aa3 100644 (file)
@@ -147,7 +147,7 @@ struct server_slot {
     int32_t n_decoded   = 0;
     int32_t n_remaining = -1;
     int32_t i_batch     = -1;
-    int32_t n_predict   = -1;
+    int32_t n_predict   = -1; // TODO: disambiguate from params.n_predict
 
     int32_t n_prompt_tokens           = 0;
     int32_t n_prompt_tokens_processed = 0;
@@ -739,7 +739,13 @@ struct server_context {
         default_generation_settings_for_props = get_formated_generation(slots.front());
         default_generation_settings_for_props["seed"] = -1;
 
-        batch = llama_batch_init(n_ctx, 0, params.n_parallel);
+        // the update_slots() logic will always submit a maximum of n_batch tokens
+        // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
+        {
+            const int32_t n_batch = llama_n_batch(ctx);
+
+            batch = llama_batch_init(n_batch, 0, params.n_parallel);
+        }
 
         metrics.init();
     }
@@ -1036,8 +1042,10 @@ struct server_context {
                 llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
             }
 
-            for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch) {
-                const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i));
+            const int32_t n_batch = llama_n_batch(ctx);
+
+            for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
+                const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
                 llama_batch batch_view = {
                     n_tokens,
                     batch.token    + i,
@@ -1226,7 +1234,7 @@ struct server_context {
             {"mirostat_eta",              slot.sparams.mirostat_eta},
             {"penalize_nl",               slot.sparams.penalize_nl},
             {"stop",                      slot.params.antiprompt},
-            {"n_predict",                 slot.params.n_predict},
+            {"n_predict",                 slot.params.n_predict}, // TODO: fix duplicate key n_predict
             {"n_keep",                    params.n_keep},
             {"ignore_eos",                ignore_eos},
             {"stream",                    slot.params.stream},
@@ -1738,7 +1746,8 @@ struct server_context {
         }
 
         // process in chunks of params.n_batch
-        int32_t n_batch = params.n_batch;
+        int32_t n_batch = llama_n_batch(ctx);
+        int32_t n_ubatch = llama_n_ubatch(ctx);
 
         // next, batch any pending prompts without exceeding n_batch
         if (params.cont_batching || batch.n_tokens == 0) {
@@ -1811,7 +1820,7 @@ struct server_context {
 
                         if (slot.embedding) {
                             // this prompt is too large to process - discard it
-                            if (slot.n_prompt_tokens > n_batch) {
+                            if (slot.n_prompt_tokens > n_ubatch) {
                                 slot.state = SLOT_STATE_PROCESSING;
                                 slot.command = SLOT_COMMAND_NONE;
                                 slot.release();
@@ -2157,7 +2166,8 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     printf("  --pooling {none,mean,cls} pooling type for embeddings, use model default if unspecified\n");
     printf("  -dt N, --defrag-thold N\n");
     printf("                            KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
-    printf("  -b N, --batch-size N      batch size for prompt processing (default: %d)\n", params.n_batch);
+    printf("  -b N, --batch-size N      logical maximum batch size (default: %d)\n", params.n_batch);
+    printf("  -ub N, --ubatch-size N    physical maximum batch size (default: %d)\n", params.n_ubatch);
     printf("  --memory-f32              use f32 instead of f16 for memory key+value (default: disabled)\n");
     printf("                            not recommended: doubles context memory required and no measurable increase in quality\n");
     if (llama_supports_mlock()) {
@@ -2424,6 +2434,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                 break;
             }
             params.n_batch = std::stoi(argv[i]);
+        } else if (arg == "-ub" || arg == "--ubatch-size") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_ubatch = std::stoi(argv[i]);
         } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
             if (++i >= argc) {
                 invalid_param = true;
index b47661e943ca63ade6a887ea8a87fbcb4b11c242..57359b267a668b3c6089dd3d62c2a8dc4c61e847 100644 (file)
@@ -9,6 +9,7 @@ Feature: llama.cpp server
     And   42 as server seed
     And   2 slots
     And   1024 as batch size
+    And   1024 as ubatch size
     And   2048 KV cache size
     And   embeddings extraction
     Then  the server is starting
index 98c2b61743cbffca52e64f5907d357d96bf7d2ec..cfa9f96ec5306ef58a0811c2361c39d192bae436 100644 (file)
@@ -33,6 +33,7 @@ def step_server_config(context, server_fqdn, server_port):
 
     context.model_alias = None
     context.n_batch = None
+    context.n_ubatch = None
     context.n_ctx = None
     context.n_ga = None
     context.n_ga_w = None
@@ -278,6 +279,11 @@ def step_n_batch(context, n_batch):
     context.n_batch = n_batch
 
 
+@step('{n_ubatch:d} as ubatch size')
+def step_n_ubatch(context, n_ubatch):
+    context.n_ubatch = n_ubatch
+
+
 @step('{seed:d} as seed')
 def step_seed(context, seed):
     context.seed = seed
@@ -1029,6 +1035,8 @@ def start_server_background(context):
     ]
     if context.n_batch:
         server_args.extend(['--batch-size', context.n_batch])
+    if context.n_ubatch:
+        server_args.extend(['--ubatch-size', context.n_ubatch])
     if context.n_gpu_layer:
         server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
     if context.server_continuous_batching:
index e675306c8c3f1dccf0e455f683cd57bd2e1dc09c..8ac1d3e51470c9cdcc3d4107a9e731c0dd8ea6d6 100644 (file)
@@ -61,7 +61,6 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
     }
 }
 
-// TODO: GGML_PAD ?
 static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
     assert(alignment && !(alignment & (alignment - 1))); // power of 2
     size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment;
@@ -69,25 +68,14 @@ static size_t aligned_offset(const void * buffer, size_t offset, size_t alignmen
 }
 
 // tallocr
-struct ggml_tallocr {
-    ggml_backend_buffer_t buffer;
-    void * base;
-    size_t alignment;
-    size_t offset;
-};
-
-ggml_tallocr_t ggml_tallocr_new(ggml_backend_buffer_t buffer) {
-    ggml_tallocr_t talloc = malloc(sizeof(struct ggml_tallocr));
-    if (talloc == NULL) {
-        return NULL;
-    }
 
+struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer) {
     void * base = ggml_backend_buffer_get_base(buffer);
     size_t align = ggml_backend_buffer_get_alignment(buffer);
 
     assert(align && !(align & (align - 1))); // power of 2
 
-    *talloc = (struct ggml_tallocr) {
+    struct ggml_tallocr talloc = (struct ggml_tallocr) {
         /*.buffer    = */ buffer,
         /*.base      = */ base,
         /*.alignment = */ align,
@@ -96,11 +84,7 @@ ggml_tallocr_t ggml_tallocr_new(ggml_backend_buffer_t buffer) {
     return talloc;
 }
 
-void ggml_tallocr_free(ggml_tallocr_t talloc) {
-    free(talloc);
-}
-
-void ggml_tallocr_alloc(ggml_tallocr_t talloc, struct ggml_tensor * tensor) {
+void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor) {
     size_t size = ggml_backend_buffer_get_alloc_size(talloc->buffer, tensor);
     size = GGML_PAD(size, talloc->alignment);
 
@@ -354,12 +338,16 @@ struct hash_node {
     bool allocated;
 };
 
-//
 struct tensor_alloc {
     size_t offset;
     size_t size_max; // 0 = pre-allocated, unused, or view
 };
 
+struct leaf_alloc {
+    int buffer_id;
+    struct tensor_alloc leaf;
+};
+
 struct node_alloc {
     int buffer_id;
     struct tensor_alloc dst;
@@ -378,7 +366,7 @@ struct ggml_gallocr {
     struct node_alloc * node_allocs; // [n_nodes]
     int n_nodes;
 
-    struct tensor_alloc * leaf_allocs; // [n_leafs]
+    struct leaf_alloc * leaf_allocs; // [n_leafs]
     int n_leafs;
 };
 
@@ -543,13 +531,20 @@ static int get_node_buffer_id(const int * node_buffer_ids, int i) {
     return node_buffer_ids ? node_buffer_ids[i] : 0;
 }
 
-static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids) {
+static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {
     // clear hash tables
     memset(galloc->hash_set.keys, 0, galloc->hash_set.size * sizeof(struct ggml_tensor *));
     memset(galloc->hash_values,   0, galloc->hash_set.size * sizeof(struct hash_node));
 
+    // allocate leafs
+    // these may be tensors that the application is not using in the graph, but may still want to allocate for other purposes
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        ggml_gallocr_allocate_node(galloc, leaf, get_node_buffer_id(leaf_buffer_ids, i));
+    }
+
     // count number of children and views
-    // allocate all graph inputs and leafs first to avoid overwriting them
+    // allocate other graph inputs and leafs first to avoid overwriting them
     for (int i = 0; i < graph->n_nodes; i++) {
         struct ggml_tensor * node = graph->nodes[i];
 
@@ -577,19 +572,6 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
         }
     }
 
-    // allocate the remaining leafs that are unused on the graph
-    // these are effectively static tensors that the application is not using in the graph, but may still want to allocate for other purposes
-    for (int i = 0; i < graph->n_leafs; i++) {
-        struct ggml_tensor * leaf = graph->leafs[i];
-        struct hash_node * hn = ggml_gallocr_hash_get(galloc, leaf);
-
-        if (hn->n_children == 0) {
-            assert(!hn->allocated);
-            // since buffer ids are only given for nodes, these leafs are always allocated in the first buffer
-            ggml_gallocr_allocate_node(galloc, leaf, 0);
-        }
-    }
-
     // allocate tensors
     for (int i = 0; i < graph->n_nodes; i++) {
         struct ggml_tensor * node = graph->nodes[i];
@@ -652,7 +634,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
     }
 }
 
-bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids) {
+bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {
     size_t hash_size = graph->visited_hash_table.size;
 
     // initialize hash table
@@ -676,7 +658,7 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
     }
 
     // allocate in hash table
-    ggml_gallocr_alloc_graph_impl(galloc, graph, node_buffer_ids);
+    ggml_gallocr_alloc_graph_impl(galloc, graph, node_buffer_ids, leaf_buffer_ids);
 
     // set the node_allocs from the hash table
     if (galloc->n_nodes < graph->n_nodes) {
@@ -711,15 +693,16 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
     }
     if (galloc->n_leafs < graph->n_leafs) {
         free(galloc->leaf_allocs);
-        galloc->leaf_allocs = calloc(sizeof(struct tensor_alloc), graph->n_leafs);
+        galloc->leaf_allocs = calloc(sizeof(galloc->leaf_allocs[0]), graph->n_leafs);
         GGML_ASSERT(galloc->leaf_allocs != NULL);
     }
     galloc->n_leafs = graph->n_leafs;
     for (int i = 0; i < graph->n_leafs; i++) {
         struct ggml_tensor * leaf = graph->leafs[i];
         struct hash_node * hn = ggml_gallocr_hash_get(galloc, leaf);
-        galloc->leaf_allocs[i].offset = hn->offset;
-        galloc->leaf_allocs[i].size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], leaf);
+        galloc->leaf_allocs[i].buffer_id = hn->buffer_id;
+        galloc->leaf_allocs[i].leaf.offset = hn->offset;
+        galloc->leaf_allocs[i].leaf.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], leaf);
     }
 
     // reallocate buffers if needed
@@ -727,7 +710,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
         size_t cur_size = galloc->buffers[i] ? ggml_backend_buffer_get_size(galloc->buffers[i]) : 0;
         size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]);
 
-        if (new_size > cur_size) {
+        // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views
+        if (new_size > cur_size || galloc->buffers[i] == NULL) {
 #ifndef NDEBUG
             fprintf(stderr, "%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
 #endif
@@ -744,30 +728,30 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
 }
 
 bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {
-    return ggml_gallocr_reserve_n(galloc, graph, NULL);
+    return ggml_gallocr_reserve_n(galloc, graph, NULL, NULL);
 }
 
-static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id, struct tensor_alloc * tensor_alloc) {
-    assert(node->data || node->view_src || ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], node) <= tensor_alloc->size_max);
+static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * tensor, int buffer_id, struct tensor_alloc * tensor_alloc) {
+    assert(tensor->data || tensor->view_src || ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max);
 
-    if (node->view_src != NULL) {
-        if (node->buffer == NULL) {
+    if (tensor->view_src != NULL) {
+        if (tensor->buffer == NULL) {
             assert(tensor_alloc->offset == SIZE_MAX);
-            if (node->view_src->buffer == NULL) {
+            if (tensor->view_src->buffer == NULL) {
                 // this tensor was allocated without ggml-backend
                 return;
             }
-            ggml_backend_view_init(galloc->buffers[buffer_id], node);
+            ggml_backend_view_init(galloc->buffers[buffer_id], tensor);
         }
     } else {
-        if (node->data == NULL) {
+        if (tensor->data == NULL) {
             assert(tensor_alloc->offset != SIZE_MAX);
-            assert(ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], node) <= tensor_alloc->size_max);
+            assert(ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max);
             void * base = ggml_backend_buffer_get_base(galloc->buffers[buffer_id]);
             void * addr = (char *)base + tensor_alloc->offset;
-            ggml_backend_tensor_alloc(galloc->buffers[buffer_id], node, addr);
+            ggml_backend_tensor_alloc(galloc->buffers[buffer_id], tensor, addr);
         } else {
-            if (node->buffer == NULL) {
+            if (tensor->buffer == NULL) {
                 // this tensor was allocated without ggml-backend
                 return;
             }
@@ -843,13 +827,18 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph)
 
     // reset buffers
     for (int i = 0; i < galloc->n_buffers; i++) {
-        // zero size buffers are not allocated
         if (galloc->buffers[i] != NULL) {
             ggml_backend_buffer_reset(galloc->buffers[i]);
         }
     }
 
     // allocate the graph tensors from the previous assignments
+    // leafs
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        struct leaf_alloc * leaf_alloc = &galloc->leaf_allocs[i];
+        ggml_gallocr_init_tensor(galloc, leaf, leaf_alloc->buffer_id, &leaf_alloc->leaf);
+    }
     // nodes
     for (int i = 0; i < graph->n_nodes; i++) {
         struct ggml_tensor * node = graph->nodes[i];
@@ -863,12 +852,6 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph)
         }
         ggml_gallocr_init_tensor(galloc, node, node_alloc->buffer_id, &node_alloc->dst);
     }
-    // leafs
-    for (int i = 0; i < graph->n_leafs; i++) {
-        struct ggml_tensor * leaf = graph->leafs[i];
-        struct tensor_alloc * leaf_alloc = &galloc->leaf_allocs[i];
-        ggml_gallocr_init_tensor(galloc, leaf, 0, leaf_alloc);
-    }
 
     return true;
 }
@@ -900,12 +883,12 @@ static bool alloc_tensor_range(struct ggml_context * ctx,
         return false;
     }
 
-    struct ggml_tallocr tallocr = ggml_tallocr_new(buffer);
+    struct ggml_tallocr tallocr = ggml_tallocr_new(buffer);
 
     for (struct ggml_tensor * t = first; t != last; t = ggml_get_next_tensor(ctx, t)) {
         if (t->data == NULL) {
             if (t->view_src == NULL) {
-                ggml_tallocr_alloc(tallocr, t);
+                ggml_tallocr_alloc(&tallocr, t);
             } else if (t->buffer == NULL) {
                 ggml_backend_view_init(buffer, t);
             }
@@ -917,8 +900,6 @@ static bool alloc_tensor_range(struct ggml_context * ctx,
         }
     }
 
-    ggml_tallocr_free(tallocr);
-
     *buffers = realloc(*buffers, sizeof(ggml_backend_buffer_t) * (*n_buffers + 1));
     (*buffers)[(*n_buffers)++] = buffer;
 
index 1d9085d15f7937cfbe3d91b15ff55eca7e64ef81..434c13b34a929c43e959565acae84b97ae4fd066 100644 (file)
@@ -11,11 +11,15 @@ typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
 typedef struct ggml_backend * ggml_backend_t;
 
 // Tensor allocator
-typedef struct ggml_tallocr * ggml_tallocr_t;
+struct ggml_tallocr {
+    ggml_backend_buffer_t buffer;
+    void * base;
+    size_t alignment;
+    size_t offset;
+};
 
-GGML_API ggml_tallocr_t ggml_tallocr_new(ggml_backend_buffer_t buffer);
-GGML_API void           ggml_tallocr_free(ggml_tallocr_t talloc);
-GGML_API void           ggml_tallocr_alloc(ggml_tallocr_t talloc, struct ggml_tensor * tensor);
+GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer);
+GGML_API void                ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);
 
 // Graph allocator
 /*
@@ -50,7 +54,11 @@ GGML_API void           ggml_gallocr_free(ggml_gallocr_t galloc);
 // not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
 // returns false if the buffer allocation failed
 GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
-GGML_API bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids);
+GGML_API bool ggml_gallocr_reserve_n(
+    ggml_gallocr_t galloc,
+    struct ggml_cgraph * graph,
+    const int * node_buffer_ids,
+    const int * leaf_buffer_ids);
 
 // automatic reallocation if the topology changes when using a single buffer
 // returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)
index 2e9ba58a931cc4696aa53920026cd3e8147fc0ff..e475e20e5f46a55d6ca5f3b23ffb2d3549ce9c2f 100644 (file)
@@ -86,12 +86,12 @@ extern "C" {
         // (optional) asynchronous tensor data access
         void (*GGML_CALL set_tensor_async)(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
         void (*GGML_CALL get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
-        bool (*GGML_CALL cpy_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * src, struct ggml_tensor * dst);
+        bool (*GGML_CALL cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
 
         // (optional) complete all pending operations
         void (*GGML_CALL synchronize)(ggml_backend_t backend);
 
-        // create a plan for ggml_cgraph and free it
+        // compute graph with a plan (not used currently)
         ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
         void                      (*GGML_CALL graph_plan_free)   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
 
@@ -102,16 +102,27 @@ extern "C" {
 
         // check if the backend supports an operation
         bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
+
+        // (optional) event synchronization
+        ggml_backend_event_t (*GGML_CALL event_new)         (ggml_backend_t backend);
+        void                 (*GGML_CALL event_free)        (ggml_backend_event_t event);
+        void                 (*GGML_CALL event_record)      (ggml_backend_event_t event);
+        void                 (*GGML_CALL event_wait)        (ggml_backend_t backend, ggml_backend_event_t event);
+        void                 (*GGML_CALL event_synchronize) (ggml_backend_event_t event);
     };
 
     struct ggml_backend {
         ggml_guid_t guid;
 
         struct ggml_backend_i iface;
-
         ggml_backend_context_t context;
     };
 
+    struct ggml_backend_event {
+        ggml_backend_t backend;
+        void * context;
+    };
+
     //
     // Backend registry
     //
index d60d9841432495a65dfd36fbc85d549c547b9a54..31f8d5a6dd30befb6a4bea8676c457e8259c982f 100644 (file)
@@ -221,29 +221,29 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten
 GGML_CALL void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
-    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(buf != NULL && "tensor buffer not set");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
 
     if (!size) {
         return;
     }
 
-    tensor->buffer->iface.set_tensor(buf, tensor, data, offset, size);
+    buf->iface.set_tensor(buf, tensor, data, offset, size);
 }
 
 GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
+    GGML_ASSERT(buf != NULL && "tensor buffer not set");
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-    GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
 
     if (!size) {
         return;
     }
 
-    tensor->buffer->iface.get_tensor(buf, tensor, data, offset, size);
+    buf->iface.get_tensor(buf, tensor, data, offset, size);
 }
 
 void ggml_backend_synchronize(ggml_backend_t backend) {
@@ -255,18 +255,30 @@ void ggml_backend_synchronize(ggml_backend_t backend) {
 }
 
 ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    GGML_ASSERT(backend->iface.graph_plan_create != NULL);
+
     return backend->iface.graph_plan_create(backend, cgraph);
 }
 
 void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    GGML_ASSERT(backend->iface.graph_plan_free != NULL);
+
     backend->iface.graph_plan_free(backend, plan);
 }
 
 enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    GGML_ASSERT(backend->iface.graph_plan_compute != NULL);
+
     return backend->iface.graph_plan_compute(backend, plan);
 }
 
 enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph);
+    ggml_backend_synchronize(backend);
+    return err;
+}
+
+bool ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     return backend->iface.graph_compute(backend, cgraph);
 }
 
@@ -314,34 +326,68 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
     }
 }
 
-void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
+void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) {
     GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
 
     if (src == dst) {
         return;
     }
 
-    if (ggml_backend_buft_supports_backend(src->buffer->buft, backend) && ggml_backend_buft_supports_backend(dst->buffer->buft, backend)) {
-        if (backend->iface.cpy_tensor_async != NULL) {
-            if (backend->iface.cpy_tensor_async(backend, src, dst)) {
-                return;
-            }
+    if (backend_dst->iface.cpy_tensor_async != NULL) {
+        if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) {
+            return;
         }
     }
 
-    size_t nbytes = ggml_nbytes(src);
+    // an async copy would normally happen after all the queued operations on both backends are completed
+    // sync src, set_async dst
     if (ggml_backend_buffer_is_host(src->buffer)) {
-        ggml_backend_tensor_set_async(backend, dst, src->data, 0, nbytes);
-    }
-    else {
+        ggml_backend_synchronize(backend_src);
+        ggml_backend_tensor_set_async(backend_dst, dst, src->data, 0, ggml_nbytes(src));
+    } else {
+        ggml_backend_synchronize(backend_src);
         ggml_backend_tensor_copy(src, dst);
+        ggml_backend_synchronize(backend_dst);
+    }
+}
+
+// events
+
+ggml_backend_event_t ggml_backend_event_new(ggml_backend_t backend) {
+    if (backend->iface.event_new == NULL) {
+        return NULL;
+    }
+    return backend->iface.event_new(backend);
+}
+
+void ggml_backend_event_free(ggml_backend_event_t event) {
+    if (event == NULL) {
+        return;
     }
+    event->backend->iface.event_free(event);
+}
+
+void ggml_backend_event_record(ggml_backend_event_t event) {
+    GGML_ASSERT(event->backend->iface.event_record != NULL);
+
+    event->backend->iface.event_record(event);
+}
+
+void ggml_backend_event_synchronize(ggml_backend_event_t event) {
+    GGML_ASSERT(event->backend->iface.event_synchronize != NULL);
+
+    event->backend->iface.event_synchronize(event);
 }
 
+void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+    GGML_ASSERT(backend->iface.event_wait != NULL);
+
+    backend->iface.event_wait(backend, event);
+}
 
 // backend registry
 
-#define GGML_MAX_BACKENDS_REG 16
+#define GGML_REG_MAX_BACKENDS 16
 
 struct ggml_backend_reg {
     char name[128];
@@ -350,7 +396,7 @@ struct ggml_backend_reg {
     void * user_data;
 };
 
-static struct ggml_backend_reg ggml_backend_registry[GGML_MAX_BACKENDS_REG];
+static struct ggml_backend_reg ggml_backend_registry[GGML_REG_MAX_BACKENDS];
 static size_t ggml_backend_registry_count = 0;
 
 GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data);
@@ -395,7 +441,7 @@ GGML_CALL static void ggml_backend_registry_init(void) {
 }
 
 GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
-    GGML_ASSERT(ggml_backend_registry_count < GGML_MAX_BACKENDS_REG);
+    GGML_ASSERT(ggml_backend_registry_count < GGML_REG_MAX_BACKENDS);
 
     size_t id = ggml_backend_registry_count;
 
@@ -746,8 +792,12 @@ GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t
     struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
 
     if (cpu_ctx->work_size < cplan.work_size) {
-        // TODO: may be faster to free and use malloc to avoid the copy
-        cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size);
+        free(cpu_ctx->work_data);
+        cpu_ctx->work_data = malloc(cplan.work_size);
+        if (cpu_ctx->work_data == NULL) {
+            cpu_ctx->work_size = 0;
+            return GGML_STATUS_ALLOC_FAILED;
+        }
         cpu_ctx->work_size = cplan.work_size;
     }
     cplan.work_data = cpu_ctx->work_data;
@@ -784,6 +834,11 @@ static struct ggml_backend_i cpu_backend_i = {
     /* .graph_plan_compute      = */ ggml_backend_cpu_graph_plan_compute,
     /* .graph_compute           = */ ggml_backend_cpu_graph_compute,
     /* .supports_op             = */ ggml_backend_cpu_supports_op,
+    /* .event_new               = */ NULL,
+    /* .event_free              = */ NULL,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+    /* .event_synchronize       = */ NULL,
 };
 
 static ggml_guid_t ggml_backend_cpu_guid(void) {
@@ -939,15 +994,27 @@ static bool ggml_is_view_op(enum ggml_op op) {
 
 // scheduler
 
-#define GGML_MAX_BACKENDS 16
-#define GGML_MAX_SPLITS 256
-#define GGML_MAX_SPLIT_INPUTS 16
+#ifndef GGML_SCHED_MAX_BACKENDS
+#define GGML_SCHED_MAX_BACKENDS 16
+#endif
+
+#ifndef GGML_SCHED_MAX_SPLITS
+#define GGML_SCHED_MAX_SPLITS 256
+#endif
+
+#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
+#define GGML_SCHED_MAX_SPLIT_INPUTS 16
+#endif
+
+#ifndef GGML_SCHED_MAX_COPIES
+#define GGML_SCHED_MAX_COPIES 4
+#endif
 
 struct ggml_backend_sched_split {
     int backend_id;
     int i_start;
     int i_end;
-    struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS];
+    struct ggml_tensor * inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
     int n_inputs;
     // graph view of this split
     struct ggml_cgraph graph;
@@ -955,45 +1022,53 @@ struct ggml_backend_sched_split {
 
 struct ggml_backend_sched {
     bool is_reset; // true if the scheduler has been reset since the last graph split
+    bool is_alloc;
 
     int n_backends;
-    ggml_backend_t backends[GGML_MAX_BACKENDS];
-    ggml_backend_buffer_type_t bufts[GGML_MAX_BACKENDS];
 
+    ggml_backend_t backends[GGML_SCHED_MAX_BACKENDS];
+    ggml_backend_buffer_type_t bufts[GGML_SCHED_MAX_BACKENDS];
     ggml_gallocr_t galloc;
 
     // hash keys of the nodes in the graph
     struct ggml_hash_set    hash_set;
     // hash values
     int * tensor_backend_id;
-    struct ggml_tensor * (* tensor_copies)[GGML_MAX_BACKENDS];
+    struct ggml_tensor * (* tensor_copies)[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES];
 
-    int * node_backend_ids; // [n_nodes]
-    int n_nodes;
+    int * node_backend_ids; // [graph_size]
+    int * leaf_backend_ids; // [graph_size]
 
     // copy of the graph with modified inputs
     struct ggml_cgraph * graph;
 
-    struct ggml_backend_sched_split splits[GGML_MAX_SPLITS];
+    // graph splits
+    struct ggml_backend_sched_split splits[GGML_SCHED_MAX_SPLITS];
     int n_splits;
 
+    // pipeline parallelism support
+    int n_copies;
+    int cur_copy;
+    ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES];
+    struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
+    int n_graph_inputs;
+
     struct ggml_context * ctx;
 
     ggml_backend_sched_eval_callback callback_eval;
     void * callback_eval_user_data;
 
     // align context_buffer to GGML_MEM_ALIGN
-    #ifdef _MSC_VER
+#ifdef _MSC_VER
     __declspec(align(GGML_MEM_ALIGN))
-    #else
+#else
     __attribute__((aligned(GGML_MEM_ALIGN)))
-    #endif
-    char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
+#endif
+    char context_buffer[GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
 };
 
-#define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
-#define tensor_backend_id(node) sched->tensor_backend_id[hash_id(node)]
-#define tensor_backend(node) (tensor_backend_id(node) == -1 ? NULL : sched->backends[tensor_backend_id(node)])
+#define hash_id(tensor) ggml_hash_find_or_insert(sched->hash_set, tensor)
+#define tensor_backend_id(tensor) sched->tensor_backend_id[hash_id(tensor)]
 
 // returns the priority of the backend, lower id is higher priority
 static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backend_t backend) {
@@ -1005,7 +1080,8 @@ static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backen
     return -1;
 }
 
-static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, ggml_backend_buffer_t buffer) {
+static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor) {
+    ggml_backend_buffer_t buffer = tensor->buffer;
     if (buffer == NULL) {
         return -1;
     }
@@ -1016,12 +1092,16 @@ static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, gg
             return i;
         }
     }
-    GGML_ASSERT(false && "tensor buffer type not supported by any backend");
-    return -1; // silence warning
+
+    fprintf(stderr, "%s: error: no backend supports buffer type %s used in tensor %s\n",
+        __func__, ggml_backend_buffer_name(buffer), tensor->name);
+    GGML_ASSERT(false);
+
+    return -1;
 }
 
 #if 0
-static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug only
+static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only
 #define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
 #define GET_CAUSE(node) causes[hash_id(node)]
 #else
@@ -1035,19 +1115,28 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
 
     // assign pre-allocated nodes to their backend
     // dst
-    int cur_backend = ggml_backend_sched_backend_from_buffer(sched, tensor->buffer);
+    int cur_backend = ggml_backend_sched_backend_from_buffer(sched, tensor);
     if (cur_backend != -1) {
-        SET_CAUSE(node, "1.dst");
+        SET_CAUSE(tensor, "1.dst");
         return cur_backend;
     }
+
     // view_src
     if (tensor->view_src != NULL) {
-        cur_backend = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src->buffer);
+        cur_backend = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src);
         if (cur_backend != -1) {
-            SET_CAUSE(node, "1.vsrc");
+            SET_CAUSE(tensor, "1.vsrc");
             return cur_backend;
         }
     }
+
+    // input
+    if (tensor->flags & GGML_TENSOR_FLAG_INPUT) {
+        cur_backend = sched->n_backends - 1; // last backend (assumed CPU)
+        SET_CAUSE(tensor, "1.inp");
+        return cur_backend;
+    }
+
     // assign nodes that use weights to the backend of the weights
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         const struct ggml_tensor * src = tensor->src[i];
@@ -1055,9 +1144,9 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
             continue;
         }
         if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
-            int src_backend = ggml_backend_sched_backend_from_buffer(sched, src->buffer);
+            int src_backend = ggml_backend_sched_backend_from_buffer(sched, src);
             // operations with weights are always run on the same backend as the weights
-            SET_CAUSE(node, "1.wgt%d", i);
+            SET_CAUSE(tensor, "1.wgt%d", i);
             return src_backend;
         }
     }
@@ -1093,7 +1182,7 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
         if (ggml_is_view_op(node->op)) {
             continue;
         }
-        ggml_backend_t tensor_backend = tensor_backend(node);
+        ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
         fprintf(stderr, "node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
             fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
         for (int j = 0; j < GGML_MAX_SRC; j++) {
@@ -1101,7 +1190,7 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
             if (src == NULL) {
                 continue;
             }
-            ggml_backend_t src_backend = tensor_backend(src);
+            ggml_backend_t src_backend = ggml_backend_sched_get_tensor_backend(sched, src);
             fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name,
                 fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
         }
@@ -1118,6 +1207,7 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
 static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     // reset splits
     sched->n_splits = 0;
+    sched->n_graph_inputs = 0;
     sched->is_reset = false;
 
     struct ggml_init_params params = {
@@ -1163,7 +1253,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
         }
     }
 #ifdef DEBUG_PASS1
-    fprintf(stderr, "PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
+    fprintf(stderr, "PASS 1 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph);
 #endif
 
     // pass 2: expand current backend assignments
@@ -1171,10 +1261,11 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
     // expand gpu backends (i.e. non last prio) up and down, ignoring cpu (the lowest priority backend)
     // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops
 
-    // pass 2.1 expand gpu up
+
+    // pass 2.2 expand gpu down
     {
         int cur_backend_id = -1;
-        for (int i = graph->n_nodes - 1; i >= 0; i--) {
+        for (int i = 0; i < graph->n_nodes; i++) {
             struct ggml_tensor * node = graph->nodes[i];
             if (ggml_is_view_op(node->op)) {
                 continue;
@@ -1189,15 +1280,15 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 }
             } else {
                 tensor_backend_id(node) = cur_backend_id;
-                SET_CAUSE(node, "2.1");
+                SET_CAUSE(node, "2.2");
             }
         }
     }
 
-    // pass 2.2 expand gpu down
+    // pass 2.1 expand gpu up
     {
         int cur_backend_id = -1;
-        for (int i = 0; i < graph->n_nodes; i++) {
+        for (int i = graph->n_nodes - 1; i >= 0; i--) {
             struct ggml_tensor * node = graph->nodes[i];
             if (ggml_is_view_op(node->op)) {
                 continue;
@@ -1212,15 +1303,16 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 }
             } else {
                 tensor_backend_id(node) = cur_backend_id;
-                SET_CAUSE(node, "2.2");
+                SET_CAUSE(node, "2.1");
             }
         }
     }
 
-    // pass 2.3 expand rest up
+
+    // pass 2.4 expand rest down
     {
         int cur_backend_id = -1;
-        for (int i = graph->n_nodes - 1; i >= 0; i--) {
+        for (int i = 0; i < graph->n_nodes; i++) {
             struct ggml_tensor * node = graph->nodes[i];
             if (ggml_is_view_op(node->op)) {
                 continue;
@@ -1230,15 +1322,14 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 cur_backend_id = tensor_backend_id;
             } else {
                 tensor_backend_id(node) = cur_backend_id;
-                SET_CAUSE(node, "2.3");
+                SET_CAUSE(node, "2.4");
             }
         }
     }
-
-    // pass 2.4 expand rest down
+        // pass 2.3 expand rest up
     {
         int cur_backend_id = -1;
-        for (int i = 0; i < graph->n_nodes; i++) {
+        for (int i = graph->n_nodes - 1; i >= 0; i--) {
             struct ggml_tensor * node = graph->nodes[i];
             if (ggml_is_view_op(node->op)) {
                 continue;
@@ -1248,12 +1339,13 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 cur_backend_id = tensor_backend_id;
             } else {
                 tensor_backend_id(node) = cur_backend_id;
-                SET_CAUSE(node, "2.4");
+                SET_CAUSE(node, "2.3");
             }
         }
     }
+
 #ifdef DEBUG_PASS2
-    fprintf(stderr, "PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
+    fprintf(stderr, "PASS 2 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph);
 #endif
 
     // pass 3: assign backends to remaining src from dst and view_src
@@ -1283,7 +1375,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
         }
     }
 #ifdef DEBUG_PASS3
-    fprintf(stderr, "PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
+    fprintf(stderr, "PASS 3 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph);
 #endif
 
     // pass 4: split graph, find tensors that need to be copied
@@ -1315,7 +1407,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             if (tensor_backend_id != cur_backend_id) {
                 sched->splits[cur_split].i_end = i;
                 cur_split++;
-                GGML_ASSERT(cur_split < GGML_MAX_SPLITS);
+                GGML_ASSERT(cur_split < GGML_SCHED_MAX_SPLITS);
                 sched->splits[cur_split].backend_id = tensor_backend_id;
                 sched->splits[cur_split].i_start = i;
                 sched->splits[cur_split].n_inputs = 0;
@@ -1328,25 +1420,57 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 if (src == NULL) {
                     continue;
                 }
+
                 int src_backend_id = tensor_backend_id(src);
                 assert(src_backend_id != -1); // all inputs should be assigned by now
+
+                if (src->flags & GGML_TENSOR_FLAG_INPUT)  {
+                    size_t id = hash_id(src);
+                    if (sched->tensor_copies[id][src_backend_id][0] == NULL) {
+                        ggml_backend_t backend = sched->backends[src_backend_id];
+                        for (int c = 0; c < sched->n_copies; c++) {
+                            struct ggml_tensor * tensor_copy;
+                            if (c == sched->cur_copy) {
+                                tensor_copy = src; // use the original tensor as the current copy
+                            } else {
+                                tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
+                                ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c);
+                            }
+                            if (sched->n_copies > 1) {
+                                ggml_set_input(tensor_copy);
+                                ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor
+                            }
+                            sched->tensor_copies[id][src_backend_id][c] = tensor_copy;
+                            tensor_backend_id(tensor_copy) = src_backend_id;
+                            SET_CAUSE(tensor_copy, "4.cpy");
+                        }
+                        int n_graph_inputs = sched->n_graph_inputs++;
+                        GGML_ASSERT(n_graph_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
+                        sched->graph_inputs[n_graph_inputs] = src;
+                    }
+                }
+
                 if (src_backend_id != tensor_backend_id) {
                     // create a copy of the input in the split's backend
                     size_t id = hash_id(src);
-                    if (sched->tensor_copies[id][cur_backend_id] == NULL) {
+                    if (sched->tensor_copies[id][cur_backend_id][0] == NULL) {
                         ggml_backend_t backend = sched->backends[cur_backend_id];
-                        struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
-                        ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
-
-                        sched->tensor_copies[id][cur_backend_id] = tensor_copy;
-                        tensor_backend_id(tensor_copy) = cur_backend_id;
-                        SET_CAUSE(tensor_copy, "4.cpy");
-
+                        for (int c = 0; c < sched->n_copies; c++) {
+                            struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
+                            ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c);
+                            if (sched->n_copies > 1) {
+                                ggml_set_input(tensor_copy);
+                                ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor
+                            }
+                            sched->tensor_copies[id][cur_backend_id][c] = tensor_copy;
+                            tensor_backend_id(tensor_copy) = cur_backend_id;
+                            SET_CAUSE(tensor_copy, "4.cpy");
+                        }
                         int n_inputs = sched->splits[cur_split].n_inputs++;
-                        GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
+                        GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
                         sched->splits[cur_split].inputs[n_inputs] = src;
                     }
-                    node->src[j] = sched->tensor_copies[id][cur_backend_id];
+                    node->src[j] = sched->tensor_copies[id][cur_backend_id][sched->cur_copy];
                 }
             }
         }
@@ -1354,37 +1478,39 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
         sched->n_splits = cur_split + 1;
     }
 #ifdef DEBUG_PASS4
-    fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
+    fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph);
 #endif
 
 #ifndef NDEBUG
     // sanity check: all sources should have the same backend as the node
     for (int i = 0; i < graph->n_nodes; i++) {
         struct ggml_tensor * node = graph->nodes[i];
-        ggml_backend_t tensor_backend = tensor_backend(node);
+        ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
         if (tensor_backend == NULL) {
             fprintf(stderr, "!!!!!!! %s has no backend\n", node->name);
         }
-        if (node->view_src != NULL && tensor_backend != tensor_backend(node->view_src)) {
+        if (node->view_src != NULL && tensor_backend != ggml_backend_sched_get_tensor_backend(sched, node->view_src)) {
             fprintf(stderr, "!!!!!!! %s has backend %s, view_src %s has backend %s\n",
                 node->name, tensor_backend ? ggml_backend_name(tensor_backend) : "NULL",
-                node->view_src->name, tensor_backend(node->view_src) ? ggml_backend_name(tensor_backend(node->view_src)) : "NULL");
+                node->view_src->name, ggml_backend_sched_get_tensor_backend(sched, node->view_src) ?
+                    ggml_backend_name(ggml_backend_sched_get_tensor_backend(sched, node->view_src)) : "NULL");
         }
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             struct ggml_tensor * src = node->src[j];
             if (src == NULL) {
                 continue;
             }
-            ggml_backend_t src_backend = tensor_backend(src);
+            ggml_backend_t src_backend = ggml_backend_sched_get_tensor_backend(sched, src);
             if (src_backend != tensor_backend /* && src_backend != NULL */) {
                 fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
                     node->name, tensor_backend ? ggml_backend_name(tensor_backend) : "NULL",
                     j, src->name, src_backend ? ggml_backend_name(src_backend) : "NULL");
             }
-            if (src->view_src != NULL && src_backend != tensor_backend(src->view_src)) {
+            if (src->view_src != NULL && src_backend != ggml_backend_sched_get_tensor_backend(sched, src->view_src)) {
                 fprintf(stderr, "!!!!!!! [src] %s has backend %s, view_src %s has backend %s\n",
                     src->name, src_backend ? ggml_backend_name(src_backend) : "NULL",
-                    src->view_src->name, tensor_backend(src->view_src) ? ggml_backend_name(tensor_backend(src->view_src)) : "NULL");
+                    src->view_src->name, ggml_backend_sched_get_tensor_backend(sched, src->view_src) ?
+                        ggml_backend_name(ggml_backend_sched_get_tensor_backend(sched, src->view_src)) : "NULL");
             }
         }
     }
@@ -1392,18 +1518,20 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 #endif
 
     // create copies of the graph for each split
-    // FIXME: avoid this copy, pass split inputs to ggml_gallocr_alloc_graph_n in some other way
-    struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false);
+    // TODO: avoid this copy
+    struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS, false);
     for (int i = 0; i < sched->n_splits; i++) {
         struct ggml_backend_sched_split * split = &sched->splits[i];
         split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
 
+        // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
         for (int j = 0; j < split->n_inputs; j++) {
             struct ggml_tensor * input = split->inputs[j];
-            struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split->backend_id];
+            struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split->backend_id][sched->cur_copy];
 
             // add a dependency to the input source so that it is not freed before the copy is done
             struct ggml_tensor * input_dep = ggml_view_tensor(sched->ctx, input);
+            input_dep->src[0] = input;
             sched->node_backend_ids[graph_copy->n_nodes] = tensor_backend_id(input);
             graph_copy->nodes[graph_copy->n_nodes++] = input_dep;
 
@@ -1417,18 +1545,56 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
         }
     }
+
+    if (sched->n_copies > 1) {
+        // add input copies as leafs so that they are allocated first
+        for (int i = 0; i < sched->n_graph_inputs; i++) {
+            struct ggml_tensor * input = sched->graph_inputs[i];
+            size_t id = hash_id(input);
+            int backend_id = tensor_backend_id(input);
+            for (int c = 0; c < sched->n_copies; c++) {
+                struct ggml_tensor * input_cpy = sched->tensor_copies[id][backend_id][c];
+                sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
+                graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
+            }
+        }
+
+        for (int i = 0; i < sched->n_splits; i++) {
+            struct ggml_backend_sched_split * split = &sched->splits[i];
+            int backend_id = split->backend_id;
+            for (int j = 0; j < split->n_inputs; j++) {
+                struct ggml_tensor * input = split->inputs[j];
+                size_t id = hash_id(input);
+                for (int c = 0; c < sched->n_copies; c++) {
+                    struct ggml_tensor * input_cpy = sched->tensor_copies[id][backend_id][c];
+                    sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
+                    graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
+                }
+            }
+        }
+    }
+
+    // add leafs from the original graph
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        sched->leaf_backend_ids[graph_copy->n_leafs] = tensor_backend_id(leaf);
+        graph_copy->leafs[graph_copy->n_leafs++] = leaf;
+    }
+
     sched->graph = graph_copy;
 }
 
 static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
-    // ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids);
+    // allocate graph
     if (!ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) {
+        // the re-allocation may cause the split inputs to be moved to a different address
+        ggml_backend_sched_synchronize(sched);
 #ifndef NDEBUG
-        fprintf(stderr, "ggml_backend_sched: failed to allocate graph, reserving\n");
+        fprintf(stderr, "%s: failed to allocate graph, reserving\n", __func__);
 #endif
-        ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids);
+        ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
         if (!ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) {
-            fprintf(stderr, "ggml_backend_sched: failed to allocate graph\n");
+            fprintf(stderr, "%s: failed to allocate graph\n", __func__);
             return false;
         }
     }
@@ -1437,9 +1603,6 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
 }
 
 static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
-    uint64_t copy_us[GGML_MAX_BACKENDS] = {0};
-    uint64_t compute_us[GGML_MAX_BACKENDS] = {0};
-
     struct ggml_backend_sched_split * splits = sched->splits;
 
     for (int i = 0; i < sched->n_splits; i++) {
@@ -1448,34 +1611,36 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
         ggml_backend_t split_backend = sched->backends[split_backend_id];
 
         // copy the input tensors to the split backend
-        uint64_t copy_start_us = ggml_time_us();
         for (int j = 0; j < split->n_inputs; j++) {
+            ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
             struct ggml_tensor * input = split->inputs[j];
-            struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split_backend_id];
+            struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split_backend_id][sched->cur_copy];
 
-            GGML_ASSERT(input->buffer != NULL);
-            GGML_ASSERT(input_cpy->buffer != NULL);
+            if (input->flags & GGML_TENSOR_FLAG_INPUT) {
+                // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
+                if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+                    ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
+                } else {
+                    ggml_backend_synchronize(split_backend);
+                }
+                ggml_backend_tensor_copy(input, input_cpy);
+            } else {
+                if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+                    ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
+                } else {
+                    ggml_backend_synchronize(split_backend);
+                    ggml_backend_synchronize(input_backend);
+                }
 
-            ggml_backend_tensor_copy_async(split_backend, input, input_cpy);
+                ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy);
+            }
         }
-        //ggml_backend_synchronize(split_backend); // necessary to measure copy time
-        int64_t copy_end_us = ggml_time_us();
-        copy_us[split_backend_id] += copy_end_us - copy_start_us;
 
-#if 0
-        char split_filename[GGML_MAX_NAME];
-        snprintf(split_filename, GGML_MAX_NAME, "split_%i_%s.dot", i, ggml_backend_name(split_backend));
-        ggml_graph_dump_dot(split->graph, NULL, split_filename);
-#endif
-
-
-        uint64_t compute_start_us = ggml_time_us();
         if (!sched->callback_eval) {
-            enum ggml_status ec = ggml_backend_graph_compute(split_backend, &split->graph);
+            enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
             if (ec != GGML_STATUS_SUCCESS) {
                 return ec;
             }
-            //ggml_backend_synchronize(split_backend); // necessary to measure compute time
         } else {
             // similar to ggml_backend_compare_graph_backend
             for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
@@ -1494,11 +1659,14 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
 
                 struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
 
-                enum ggml_status ec = ggml_backend_graph_compute(split_backend, &gv);
+                enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv);
                 if (ec != GGML_STATUS_SUCCESS) {
                     return ec;
                 }
 
+                // TODO: pass backend to the callback, then the user can decide if they want to synchronize
+                ggml_backend_synchronize(split_backend);
+
                 if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
                     break;
                 }
@@ -1506,39 +1674,54 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
                 j0 = j1;
             }
         }
-        uint64_t compute_end_us = ggml_time_us();
-        compute_us[split_backend_id] += compute_end_us - compute_start_us;
-    }
 
-#if 0
-    // per-backend timings
-    fprintf(stderr, "sched_compute_splits times (%d splits):\n", sched->n_splits);
-    for (int i = 0; i < sched->n_backends; i++) {
-        if (copy_us[i] > 0 || compute_us[i] > 0) {
-            fprintf(stderr, "\t%5.5s: %lu us copy, %lu us compute\n", ggml_backend_name(sched->backends[i]), copy_us[i], compute_us[i]);
+        // record the event of this copy
+        if (split->n_inputs > 0) {
+            if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+                ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy]);
+            }
         }
     }
-#endif
+
+    sched->cur_copy = (sched->cur_copy + 1) % sched->n_copies;
 
     return GGML_STATUS_SUCCESS;
 }
 
-ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size) {
+ggml_backend_sched_t ggml_backend_sched_new(
+        ggml_backend_t * backends,
+        ggml_backend_buffer_type_t * bufts,
+        int n_backends,
+        size_t graph_size,
+        bool parallel) {
     GGML_ASSERT(n_backends > 0);
-    GGML_ASSERT(n_backends <= GGML_MAX_BACKENDS);
+    GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
+    GGML_ASSERT(ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU
 
     struct ggml_backend_sched * sched = calloc(sizeof(struct ggml_backend_sched), 1);
 
     // initialize hash table
-    sched->hash_set          = ggml_hash_set_new(graph_size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
+    sched->hash_set          = ggml_hash_set_new(graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS);
     sched->tensor_backend_id = calloc(sizeof(sched->tensor_backend_id[0]), sched->hash_set.size);
     sched->tensor_copies     = calloc(sizeof(sched->tensor_copies[0]), sched->hash_set.size);
     sched->node_backend_ids  = calloc(sizeof(sched->node_backend_ids[0]), graph_size);
+    sched->leaf_backend_ids  = calloc(sizeof(sched->leaf_backend_ids[0]), graph_size);
 
     sched->n_backends = n_backends;
-    for (int i = 0; i < n_backends; i++) {
-        sched->backends[i] = backends[i];
-        sched->bufts[i] = bufts ? bufts[i] : ggml_backend_get_default_buffer_type(backends[i]);
+
+    sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1;
+
+    GGML_ASSERT(sched->n_copies <= GGML_SCHED_MAX_COPIES);
+
+    for (int b = 0; b < n_backends; b++) {
+        sched->backends[b] = backends[b];
+        sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);
+        GGML_ASSERT(ggml_backend_buft_supports_backend(sched->bufts[b], backends[b]));
+        if (sched->n_copies > 1) {
+            for (int c = 0; c < sched->n_copies; c++) {
+                sched->events[b][c] = ggml_backend_event_new(backends[b]);
+            }
+        }
     }
 
     sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
@@ -1552,12 +1735,18 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
     if (sched == NULL) {
         return;
     }
+    for (int b = 0; b < sched->n_backends; b++) {
+        for (int c = 0; c < sched->n_copies; c++) {
+            ggml_backend_event_free(sched->events[b][c]);
+        }
+    }
     ggml_gallocr_free(sched->galloc);
     ggml_free(sched->ctx);
     free(sched->hash_set.keys);
     free(sched->tensor_backend_id);
     free(sched->tensor_copies);
     free(sched->node_backend_ids);
+    free(sched->leaf_backend_ids);
     free(sched);
 }
 
@@ -1569,34 +1758,63 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
     memset(sched->tensor_copies,      0, sizeof(sched->tensor_copies[0])     * hash_size);
 
     sched->is_reset = true;
+    sched->is_alloc = false;
 }
 
 bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
     ggml_backend_sched_split_graph(sched, measure_graph);
 
-    if (!ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids)) {
+    // TODO: extract this to a separate function
+    if (!ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) {
         return false;
     }
 
     ggml_backend_sched_reset(sched);
+    ggml_backend_sched_synchronize(sched);
+
+    return true;
+}
+
+bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS);
+
+    ggml_backend_sched_split_graph(sched, graph);
+
+    if (!ggml_backend_sched_alloc_splits(sched)) {
+        return false;
+    }
+
+    sched->is_alloc = true;
+
     return true;
 }
 
 enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
-    GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
+    enum ggml_status err = ggml_backend_sched_graph_compute_async(sched, graph);
+    ggml_backend_sched_synchronize(sched);
+    return err;
+}
 
-    if (!sched->is_reset) {
+enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    if (!sched->is_reset && !sched->is_alloc) {
         ggml_backend_sched_reset(sched);
     }
 
-    ggml_backend_sched_split_graph(sched, graph);
-    if (!ggml_backend_sched_alloc_splits(sched)) {
-        return GGML_STATUS_ALLOC_FAILED;
+    if (!sched->is_alloc) {
+        if (!ggml_backend_sched_alloc_graph(sched, graph)) {
+            return GGML_STATUS_ALLOC_FAILED;
+        }
     }
 
     return ggml_backend_sched_compute_splits(sched);
 }
 
+void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
+    for (int i = 0; i < sched->n_backends; i++) {
+        ggml_backend_synchronize(sched->backends[i]);
+    }
+}
+
 void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
     sched->callback_eval = callback;
     sched->callback_eval_user_data = user_data;
@@ -1606,19 +1824,24 @@ int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
     return sched->n_splits;
 }
 
+int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {
+    return sched->n_copies;
+}
+
 size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
     int backend_index = ggml_backend_sched_backend_id(sched, backend);
     GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
+
     return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
 }
 
-void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
+void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
     int backend_index = ggml_backend_sched_backend_id(sched, backend);
     GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
     tensor_backend_id(node) = backend_index;
 }
 
-ggml_backend_t ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {
+ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {
     int backend_index = tensor_backend_id(node);
     if (backend_index == -1) {
         return NULL;
index 8bed22578a9076545a7e2fa54ab68f8f77e2962f..099d9c258794ed0b3ad460549ca446a10c875130 100644 (file)
@@ -9,6 +9,7 @@ extern "C" {
 
     typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
     typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+    typedef struct ggml_backend_event * ggml_backend_event_t;
     typedef struct ggml_backend * ggml_backend_t;
     typedef void * ggml_backend_graph_plan_t;
 
@@ -72,11 +73,24 @@ extern "C" {
     GGML_API enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
     GGML_API enum ggml_status ggml_backend_graph_compute     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
 
+    GGML_API bool ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
     GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
 
     // tensor copy between different backends
     GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
-    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); // automatic fallback to sync copy
+
+    // asynchronous copy
+    // the copy is performed after all the currently queued operations in backend_src
+    // backend_dst will wait for the copy to complete before performing other operations
+    // automatic fallback to sync copy if async is not supported
+    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+    // events
+    GGML_API ggml_backend_event_t   ggml_backend_event_new        (ggml_backend_t backend);
+    GGML_API void                   ggml_backend_event_free       (ggml_backend_event_t event);
+    GGML_API void                   ggml_backend_event_record     (ggml_backend_event_t event);
+    GGML_API void                   ggml_backend_event_synchronize(ggml_backend_event_t event);
+    GGML_API void                   ggml_backend_event_wait       (ggml_backend_t backend, ggml_backend_event_t event); // wait async on event
 
     //
     // CPU backend
@@ -123,27 +137,31 @@ extern "C" {
     /*
       Example usage:
 
-        sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, num_backends);
-        // sched is initialized with measure allocators and cannot be used until allocated with a measure graph
+        // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be asigned
+        // preferrably to run on the same backend as the buffer
+        ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
 
-        // initialize buffers from a measure graph
-        measure_graph = build_graph(sched); // use the allocr to allocate inputs as needed
+        sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false);
 
-        // in build_graph:
-        build_graph(...) {
-            // manually assign nodes to a backend (optional, should not be needed in most cases)
-            struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
-            ggml_backend_sched_set_node_backend(sched, node, backend_gpu);
-        }
+        // initialize buffers from a max size graph (optional)
+        reserve_graph = build_graph(sched, max_batch_size);
 
-        // allocate backend buffers from measure graph
-        ggml_backend_sched_init_measure(sched, measure_graph);
+        // manually assign nodes to a backend (optional, should not be needed in most cases)
+        struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
+        ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu);
 
-        // the scheduler is now ready to compute graphs
+        ggml_backend_sched_reserve(sched, reserve_graph);
 
         // compute
         graph = build_graph(sched);
         ggml_backend_sched_graph_compute(sched, graph);
+
+        // if there are graph inputs:
+        ggml_backend_sched_reset(sched);
+        ggml_backend_sched_alloc_graph(sched, graph);
+        ggml_backend_tensor_set(input_tensor, ...);
+        ggml_backend_sched_graph_compute(sched, graph);
+    }
     */
 
     struct ggml_backend_sched;
@@ -158,20 +176,26 @@ extern "C" {
     typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
 
     // Initialize a backend scheduler
-    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
+    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
     GGML_API void                 ggml_backend_sched_free(ggml_backend_sched_t sched);
+
     // Initialize backend buffers from a measure graph
     GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+
     // Get the number of splits of the last graph
     GGML_API int                  ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
+    GGML_API int                  ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);
 
     GGML_API size_t               ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
 
-    GGML_API void                 ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
-    GGML_API ggml_backend_t       ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
+    GGML_API void                 ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
+    GGML_API ggml_backend_t       ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
 
     // Allocate and compute graph on the backend scheduler
+    GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
     GGML_API enum ggml_status     ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API enum ggml_status     ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API void                 ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
 
     // Reset all assignments and allocators - must be called before changing the node backends
     GGML_API void                 ggml_backend_sched_reset(ggml_backend_sched_t sched);
index b8834ed05eed0230c70ea3d4659795862c17c683..d1b5e52ba901113d3b829604924f545427364009 100644 (file)
@@ -72,6 +72,7 @@
 #define cudaEventCreateWithFlags hipEventCreateWithFlags
 #define cudaEventDisableTiming hipEventDisableTiming
 #define cudaEventRecord hipEventRecord
+#define cudaEventSynchronize hipEventSynchronize
 #define cudaEvent_t hipEvent_t
 #define cudaEventDestroy hipEventDestroy
 #define cudaFree hipFree
@@ -81,6 +82,7 @@
 #define cudaGetDeviceProperties hipGetDeviceProperties
 #define cudaGetErrorString hipGetErrorString
 #define cudaGetLastError hipGetLastError
+#define cudaLaunchHostFunc hipLaunchHostFunc
 #ifdef GGML_HIP_UMA
 #define cudaMalloc hipMallocManaged
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
 #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
 #define cudaStreamFireAndForget hipStreamFireAndForget
 #define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamPerThread hipStreamPerThread
 #define cudaStreamSynchronize hipStreamSynchronize
 #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
 #define cudaStream_t hipStream_t
@@ -10641,8 +10644,20 @@ GGML_CALL void ggml_cuda_get_device_description(int device, char * description,
 #define UNUSED GGML_UNUSED
 
 struct ggml_backend_cuda_context {
+    explicit ggml_backend_cuda_context(int device) :
+        device(device),
+        name(GGML_CUDA_NAME + std::to_string(device)) {
+    }
+
+    ~ggml_backend_cuda_context() {
+        if (copy_event != nullptr) {
+            CUDA_CHECK(cudaEventDestroy(copy_event));
+        }
+    }
+
     int device;
     std::string name;
+    cudaEvent_t copy_event = nullptr;
 };
 
 // cuda buffer
@@ -10732,9 +10747,8 @@ GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
 
     ggml_cuda_set_device(ctx->device);
-    CUDA_CHECK(cudaDeviceSynchronize());
-    CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice));
-    CUDA_CHECK(cudaDeviceSynchronize());
+    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
 }
 
 GGML_CALL static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -10743,26 +10757,25 @@ GGML_CALL static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
 
     ggml_cuda_set_device(ctx->device);
-    CUDA_CHECK(cudaDeviceSynchronize());
-    CUDA_CHECK(cudaMemcpy(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost));
-    CUDA_CHECK(cudaDeviceSynchronize());
+    CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
 }
 
 GGML_CALL static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
     if (ggml_backend_buffer_is_cuda(src->buffer)) {
         ggml_backend_cuda_buffer_context * src_ctx = (ggml_backend_cuda_buffer_context *)src->buffer->context;
-        ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
-
-        ggml_cuda_set_device(src_ctx->device);
-        CUDA_CHECK(cudaDeviceSynchronize());
-        ggml_cuda_set_device(dst_ctx->device);
-        CUDA_CHECK(cudaDeviceSynchronize());
-        CUDA_CHECK(cudaMemcpy((char *)dst->data, (const char *)src->data, ggml_nbytes(src), cudaMemcpyDeviceToDevice));
-        CUDA_CHECK(cudaDeviceSynchronize());
-
+        ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *)dst->buffer->context;
+        if (src_ctx->device == dst_ctx->device) {
+            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src), cudaMemcpyDeviceToDevice, cudaStreamPerThread));
+        } else {
+            CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, ggml_nbytes(src), cudaStreamPerThread));
+        }
+        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
         return true;
     }
     return false;
+
+    UNUSED(buffer);
 }
 
 GGML_CALL static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@@ -11007,7 +11020,11 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buf
         }
 
         const char * buf_host = (const char *)data + offset_split;
-        CUDA_CHECK(cudaMemcpy(extra->data_device[id], buf_host, original_size, cudaMemcpyHostToDevice));
+        CUDA_CHECK(cudaMemcpyAsync(extra->data_device[id], buf_host, original_size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+    }
+
+    for (int id = 0; id < g_device_count; ++id) {
+        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
     }
 }
 
@@ -11041,7 +11058,11 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buf
         }
 
         char * buf_host = (char *)data + offset_split;
-        CUDA_CHECK(cudaMemcpy(buf_host, extra->data_device[id], original_size, cudaMemcpyDeviceToHost));
+        CUDA_CHECK(cudaMemcpyAsync(buf_host, extra->data_device[id], original_size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
+    }
+
+    for (int id = 0; id < g_device_count; ++id) {
+        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
     }
 }
 
@@ -11220,6 +11241,10 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
     return &ggml_backend_cuda_buffer_type_host;
 }
 
+//static bool ggml_backend_buffer_is_cuda_host(ggml_backend_buffer_t buffer) {
+//    return buffer->buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
+//}
+
 // backend
 
 GGML_CALL static const char * ggml_backend_cuda_name(ggml_backend_t backend) {
@@ -11243,8 +11268,9 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_cuda_get_default_buffer
 
 GGML_CALL static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
-    GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
+    GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
     GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
 
     CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[cuda_ctx->device][0]));
@@ -11252,22 +11278,61 @@ GGML_CALL static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend,
 
 GGML_CALL static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
-    GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
+    GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
     GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
 
     CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[cuda_ctx->device][0]));
 }
 
-GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_backend_is_cuda(backend_src) || ggml_backend_is_cuda(backend_dst));
 
-    if (dst->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && ggml_backend_buffer_is_cuda(src->buffer)) {
-        CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, g_cudaStreams[cuda_ctx->device][0]));
-        return true;
+    ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
+    ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
+
+    if (!ggml_backend_buffer_is_cuda(src->buffer)) {
+        return false;
     }
 
-    return false;
+    if (!ggml_backend_buffer_is_cuda(dst->buffer)) {
+        return false;
+    }
+
+    // device -> device
+    ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
+    ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
+
+    if (backend_src != backend_dst) {
+        ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
+        ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
+
+        GGML_ASSERT(cuda_ctx_src->device == buf_ctx_src->device);
+        GGML_ASSERT(cuda_ctx_dst->device == buf_ctx_dst->device);
+
+        if (!cuda_ctx_src->copy_event) {
+            ggml_cuda_set_device(cuda_ctx_src->device);
+            CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
+        }
+
+        // copy on src stream
+        if (cuda_ctx_src->device == cuda_ctx_dst->device) {
+            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, g_cudaStreams[cuda_ctx_dst->device][0]));
+        } else {
+            CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), g_cudaStreams[cuda_ctx_src->device][0]));
+        }
+
+        // record event on src stream
+        CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, g_cudaStreams[cuda_ctx_src->device][0]));
+
+        // wait on dst stream for the copy to complete
+        CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[cuda_ctx_dst->device][0], cuda_ctx_src->copy_event, 0));
+    } else {
+        // src and dst are on the same backend
+        CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, g_cudaStreams[cuda_ctx_dst->device][0]));
+    }
+    return true;
 }
 
 GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
@@ -11444,6 +11509,52 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
     UNUSED(backend);
 }
 
+static ggml_backend_event_t ggml_backend_cuda_event_new(ggml_backend_t backend) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    ggml_cuda_set_device(cuda_ctx->device);
+
+    cudaEvent_t event;
+    CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
+
+    return new ggml_backend_event {
+        /* .backend = */ backend,
+        /* .context = */ event,
+    };
+}
+
+static void ggml_backend_cuda_event_free(ggml_backend_event_t event) {
+    CUDA_CHECK(cudaEventDestroy((cudaEvent_t)event->context));
+
+    delete event;
+}
+
+static void ggml_backend_cuda_event_record(ggml_backend_event_t event) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)event->backend->context;
+
+    CUDA_CHECK(cudaEventRecord((cudaEvent_t)event->context, g_cudaStreams[cuda_ctx->device][0]));
+}
+
+static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    if (ggml_backend_is_cuda(event->backend)) {
+        CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[cuda_ctx->device][0], (cudaEvent_t)event->context, 0));
+    } else {
+        // untested
+        auto wait_fn = [](void * user_data) {
+            ggml_backend_event_t event = (ggml_backend_event_t)user_data;
+            ggml_backend_event_synchronize(event);
+        };
+
+        CUDA_CHECK(cudaLaunchHostFunc(g_cudaStreams[cuda_ctx->device][0], wait_fn, event));
+    }
+}
+
+static void ggml_backend_cuda_event_synchronize(ggml_backend_event_t event) {
+    CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
+}
+
 static ggml_backend_i ggml_backend_cuda_interface = {
     /* .get_name                = */ ggml_backend_cuda_name,
     /* .free                    = */ ggml_backend_cuda_free,
@@ -11457,6 +11568,11 @@ static ggml_backend_i ggml_backend_cuda_interface = {
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_cuda_graph_compute,
     /* .supports_op             = */ ggml_backend_cuda_supports_op,
+    /* .event_new               = */ ggml_backend_cuda_event_new,
+    /* .event_free              = */ ggml_backend_cuda_event_free,
+    /* .event_record            = */ ggml_backend_cuda_event_record,
+    /* .event_wait              = */ ggml_backend_cuda_event_wait,
+    /* .event_synchronize       = */ ggml_backend_cuda_event_synchronize,
 };
 
 static ggml_guid_t ggml_backend_cuda_guid() {
@@ -11475,10 +11591,11 @@ GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
     // not strictly necessary, but it may reduce the overhead of the first graph_compute
     ggml_cuda_set_main_device(device);
 
-    ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context {
-        /* .device = */ device,
-        /* .name   = */ GGML_CUDA_NAME + std::to_string(device),
-    };
+    ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device);
+    if (ctx == nullptr) {
+        fprintf(stderr, "%s: error: failed to allocate context\n", __func__);
+        return nullptr;
+    }
 
     ggml_backend_t cuda_backend = new ggml_backend {
         /* .guid      = */ ggml_backend_cuda_guid(),
index 83a7822fdbe9d7eb08bfb0de621356e3bea38ec9..4caf2c9e78b0281ffe19e0caa687f31c3c5f91c5 100644 (file)
@@ -1951,6 +1951,11 @@ static struct ggml_backend_i kompute_backend_i = {
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_kompute_graph_compute,
     /* .supports_op             = */ ggml_backend_kompute_supports_op,
+    /* .event_new               = */ NULL,
+    /* .event_free              = */ NULL,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+    /* .event_synchronize       = */ NULL,
 };
 
 static ggml_guid_t ggml_backend_kompute_guid() {
index 1825d3320ee7a035bfcc2f3db9a4ce2cff69775e..3a5476c52f1a52954eb623658feffdc03d8ca47c 100644 (file)
@@ -2820,6 +2820,11 @@ static struct ggml_backend_i ggml_backend_metal_i = {
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_metal_graph_compute,
     /* .supports_op             = */ ggml_backend_metal_supports_op,
+    /* .event_new               = */ NULL,
+    /* .event_free              = */ NULL,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+    /* .event_synchronize       = */ NULL,
 };
 
 void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
index c2ab13034ba855a11dacadf2e88a24ba1550db12..9f6506383cc0dde78de71e72b455307386eff042 100644 (file)
@@ -17249,13 +17249,18 @@ static ggml_backend_i ggml_backend_sycl_interface = {
     /* .get_default_buffer_type = */ ggml_backend_sycl_get_default_buffer_type,
     /* .set_tensor_async        = */ ggml_backend_sycl_set_tensor_async,
     /* .get_tensor_async        = */ ggml_backend_sycl_get_tensor_async,
-    /* .cpy_tensor_async        = */ ggml_backend_sycl_cpy_tensor_async,
+    /* .cpy_tensor_async        = */ NULL, //ggml_backend_sycl_cpy_tensor_async, // TODO: update for the new interface
     /* .synchronize             = */ ggml_backend_sycl_synchronize,
     /* .graph_plan_create       = */ NULL,
     /* .graph_plan_free         = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_sycl_graph_compute,
     /* .supports_op             = */ ggml_backend_sycl_supports_op,
+    /* .event_new               = */ NULL,
+    /* .event_free              = */ NULL,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+    /* .event_synchronize       = */ NULL,
 };
 
 static ggml_guid_t ggml_backend_sycl_guid() {
index d41aa7d22f096dc874a36aab9273b742ac8955cf..7cce616ba714fd6ca5ceb591365fcce32fe1f1f7 100644 (file)
@@ -5693,6 +5693,11 @@ static ggml_backend_i ggml_backend_vk_interface = {
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_vk_graph_compute,
     /* .supports_op             = */ ggml_backend_vk_supports_op,
+    /* .event_new               = */ NULL,
+    /* .event_free              = */ NULL,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+    /* .event_synchronize       = */ NULL,
 };
 
 static ggml_guid_t ggml_backend_vk_guid() {
diff --git a/ggml.c b/ggml.c
index 9a7bd1d8c527b19168e4769220eb9b847a45e165..fbc66f65b105214748f3d5e9020f52ffa3cd782f 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -11560,8 +11560,6 @@ static void ggml_compute_forward_get_rows_q(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    assert(params->ith == 0);
-
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
     }
@@ -11569,7 +11567,7 @@ static void ggml_compute_forward_get_rows_q(
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t nc = ne00;
-    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
+    const int64_t nr = ggml_nelements(src1);
 
     const enum ggml_type type = src0->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
@@ -11579,17 +11577,25 @@ static void ggml_compute_forward_get_rows_q(
     assert(nb00 == ggml_type_size(type));
     assert(ggml_nrows(dst) == nr);
 
-    // TODO: multi-thread
-    for (int64_t i12 = 0; i12 < ne12; ++i12) {
-        for (int64_t i11 = 0; i11 < ne11; ++i11) {
-            for (int64_t i10 = 0; i10 < ne10; ++i10) {
-                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+    const int ith = params->ith;
+    const int nth = params->nth;
 
-                dequantize_row_q(
-                        (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
-                             (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
-            }
-        }
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        dequantize_row_q(
+                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
     }
 }
 
@@ -11600,8 +11606,6 @@ static void ggml_compute_forward_get_rows_f16(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    assert(params->ith == 0);
-
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
     }
@@ -11609,24 +11613,32 @@ static void ggml_compute_forward_get_rows_f16(
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t nc = ne00;
-    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
+    const int64_t nr = ggml_nelements(src1);
 
     assert(ne0  == nc);
     assert(ne02 == ne11);
     assert(nb00 == sizeof(ggml_fp16_t));
     assert(ggml_nrows(dst) == nr);
 
-    // TODO: multi-thread
-    for (int64_t i12 = 0; i12 < ne12; ++i12) {
-        for (int64_t i11 = 0; i11 < ne11; ++i11) {
-            for (int64_t i10 = 0; i10 < ne10; ++i10) {
-                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
 
-                ggml_fp16_to_fp32_row(
-                        (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
-                             (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
-            }
-        }
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        ggml_fp16_to_fp32_row(
+                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
     }
 }
 
@@ -11637,8 +11649,6 @@ static void ggml_compute_forward_get_rows_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    assert(params->ith == 0);
-
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
     }
@@ -11646,24 +11656,32 @@ static void ggml_compute_forward_get_rows_f32(
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t nc = ne00;
-    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
+    const int64_t nr = ggml_nelements(src1);
 
     assert(ne0  == nc);
     assert(ne02 == ne11);
     assert(nb00 == sizeof(float));
     assert(ggml_nrows(dst) == nr);
 
-    // TODO: multi-thread
-    for (int64_t i12 = 0; i12 < ne12; ++i12) {
-        for (int64_t i11 = 0; i11 < ne11; ++i11) {
-            for (int64_t i10 = 0; i10 < ne10; ++i10) {
-                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+    const int ith = params->ith;
+    const int nth = params->nth;
 
-                ggml_vec_cpy_f32(nc,
-                        (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
-                        (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
-            }
-        }
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        ggml_vec_cpy_f32(nc,
+                (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
+                (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
     }
 }
 
@@ -17796,7 +17814,7 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
     node->perf_time_us += time_us_cur;
 }
 
-static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
+static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_threads) {
     int n_tasks = 0;
 
     switch (node->op) {
@@ -17877,6 +17895,12 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = n_threads;
             } break;
+        case GGML_OP_GET_ROWS:
+            {
+                // FIXME: the cost of launching additional threads decreases performance with GPU offloading
+                //n_tasks = MIN(n_threads, ggml_nelements(node->src[1]));
+                n_tasks = MIN(n_cur_threads, ggml_nelements(node->src[1]));
+            } break;
         case GGML_OP_SCALE:
         case GGML_OP_SET:
         case GGML_OP_CONT:
@@ -17884,7 +17908,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
-        case GGML_OP_GET_ROWS:
         case GGML_OP_GET_ROWS_BACK:
         case GGML_OP_DIAG:
             {
@@ -18102,7 +18125,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 /* FINALIZE */
                 struct ggml_tensor * node = cgraph->nodes[node_n];
                 if (GGML_OP_HAS_FINALIZE[node->op]) {
-                    params.nth = ggml_get_n_tasks(node, n_threads);
+                    params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
                     ggml_compute_forward(&params, node);
                 }
                 ggml_graph_compute_perf_stats_node(node, state->shared);
@@ -18112,7 +18135,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             while (++node_n < cgraph->n_nodes) {
                 GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
                 struct ggml_tensor * node = cgraph->nodes[node_n];
-                const int n_tasks = ggml_get_n_tasks(node, n_threads);
+                const int n_tasks = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
 
                 state->shared->perf_node_start_cycles  = ggml_perf_cycles();
                 state->shared->perf_node_start_time_us = ggml_perf_time_us();
@@ -18160,7 +18183,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
         /* INIT & COMPUTE */
         struct ggml_tensor * node = cgraph->nodes[node_n];
-        const int n_tasks = ggml_get_n_tasks(node, n_threads);
+        const int n_tasks = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
 
         struct ggml_compute_params params = {
             /*.type  =*/ GGML_TASK_TYPE_INIT,
@@ -18225,7 +18248,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
 
-        const int n_tasks = ggml_get_n_tasks(node, n_threads);
+        const int n_tasks = ggml_get_n_tasks(node, n_threads, 1);
 
         max_tasks = MAX(max_tasks, n_tasks);
 
index ad7b7b7d4bcf24ece975ed6f87cb212ff422f12a..38e7036a72720dada20b54a48719f530462c2182 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -978,21 +978,6 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
     }
 }
 
-//
-// ggml helpers
-//
-
-static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
-    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
-
-    if (plan.work_size > 0) {
-        buf.resize(plan.work_size);
-        plan.work_data = buf.data();
-    }
-
-    ggml_graph_compute(graph, &plan);
-}
-
 //
 // llama helpers
 //
@@ -1728,6 +1713,7 @@ struct llama_hparams {
 struct llama_cparams {
     uint32_t n_ctx;           // context size used during inference
     uint32_t n_batch;
+    uint32_t n_ubatch;
     uint32_t n_threads;       // number of threads to use for generation
     uint32_t n_threads_batch; // number of threads to use for batch processing
 
@@ -2024,8 +2010,7 @@ struct llama_context {
         ggml_vk_free_cpu_assist();
 #endif
 
-        ggml_backend_buffer_free(buf_input);
-        ggml_free(ctx_input);
+        ggml_backend_buffer_free(buf_output);
     }
 
     llama_cparams cparams;
@@ -2051,12 +2036,20 @@ struct llama_context {
     int64_t t_p_eval_us = 0;
     int64_t t_eval_us   = 0;
 
+    int64_t t_compute_start_us = 0;
+    int64_t n_queued_tokens = 0;
+
     int32_t n_sample = 0; // number of tokens sampled
     int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
     int32_t n_eval   = 0; // number of eval calls
 
-    // logits output (2-dimensional array: [n_tokens][n_vocab])
-    std::vector<float> logits;
+    // host buffer for the model output (logits and embeddings)
+    ggml_backend_buffer_t buf_output = nullptr;
+
+    // decode output (2-dimensional array: [n_tokens][n_vocab])
+    size_t logits_size = 0;
+    float * logits = nullptr;
+
 #ifndef NDEBUG
     // guard against access to unset logits
     std::vector<bool>  logits_valid;
@@ -2065,7 +2058,8 @@ struct llama_context {
 
     // embeddings output (2-dimensional array: [n_tokens][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
-    std::vector<float> embd;
+    size_t embd_size = 0;
+    float * embd = nullptr;
 
     // sequence embeddings output (map of [n_embd] vectors)
     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
@@ -2079,8 +2073,6 @@ struct llama_context {
     void *              abort_callback_data = nullptr;
 
     // input tensors
-    ggml_backend_buffer_t buf_input = nullptr;
-    ggml_context * ctx_input = nullptr;
     struct ggml_tensor * inp_tokens;    // I32 [n_batch]
     struct ggml_tensor * inp_embd;      // F32 [n_embd, n_batch]
     struct ggml_tensor * inp_pos;       // I32 [n_batch]
@@ -2090,7 +2082,7 @@ struct llama_context {
     struct ggml_tensor * inp_mean;      // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls;       // I32 [n_batch]
     struct ggml_tensor * inp_s_copy;    // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask;    // F32 [kv_size]
+    struct ggml_tensor * inp_s_mask;    // F32 [1, kv_size]
     struct ggml_tensor * inp_s_seq;     // I32 [kv_size, n_batch]
 
 #ifdef GGML_USE_MPI
@@ -4005,6 +3997,7 @@ static bool llm_load_tensors(
 
     // there is very little benefit to offloading the input layer, so always keep it on the CPU
     model.buft_input = llama_default_buffer_type_cpu(true);
+    //model.buft_input = llama_default_buffer_type_offload(main_gpu);
 
     model.buft_layer.resize(n_layer);
 
@@ -5094,29 +5087,32 @@ enum llm_norm_type {
 
 static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
+       struct llama_context & lctx,
         const llama_hparams & hparams,
           const llama_batch & batch,
          struct ggml_tensor * tok_embd,
-         struct ggml_tensor * inp_tokens,
-         struct ggml_tensor * inp_embd,
          const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
 
     struct ggml_tensor * inpL;
 
     if (batch.token) {
-        struct ggml_tensor * inp_tokens_v = ggml_view_1d(ctx, inp_tokens, batch.n_tokens, 0);
-        cb(inp_tokens, "inp_tokens", -1);
+        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+        cb(lctx.inp_tokens, "inp_tokens", -1);
+        ggml_set_input(lctx.inp_tokens);
 
-        inpL = ggml_get_rows(ctx, tok_embd, inp_tokens_v);
+        inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
     } else {
 #ifdef GGML_USE_MPI
         GGML_ASSERT(false && "not implemented");
 #endif
-
-        inpL = ggml_view_2d(ctx, inp_embd, n_embd, batch.n_tokens, inp_embd->nb[1], 0);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        inpL = lctx.inp_embd;
+        ggml_set_input(lctx.inp_embd);
     }
 
+    cb(inpL, "inp_embd", -1);
+
     return inpL;
 }
 
@@ -5420,7 +5416,7 @@ static struct ggml_tensor * llm_build_kv(
 
 struct llm_build_context {
     const llama_model    & model;
-    const llama_context  & lctx;
+          llama_context  & lctx;
     const llama_hparams  & hparams;
     const llama_cparams  & cparams;
     const llama_batch    & batch;
@@ -5513,6 +5509,18 @@ struct llm_build_context {
         };
 
         ctx0 = ggml_init(params);
+
+        lctx.inp_tokens = nullptr;
+        lctx.inp_embd = nullptr;
+        lctx.inp_pos = nullptr;
+        lctx.inp_KQ_mask = nullptr;
+        lctx.inp_KQ_pos = nullptr;
+        lctx.inp_K_shift = nullptr;
+        lctx.inp_mean = nullptr;
+        lctx.inp_cls = nullptr;
+        lctx.inp_s_copy = nullptr;
+        lctx.inp_s_mask = nullptr;
+        lctx.inp_s_seq = nullptr;
     }
 
     void free() {
@@ -5527,6 +5535,10 @@ struct llm_build_context {
 
         GGML_ASSERT(kv_self.size == n_ctx);
 
+        lctx.inp_K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
+        cb(lctx.inp_K_shift, "K_shift", -1);
+        ggml_set_input(lctx.inp_K_shift);
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * tmp =
                 // we rotate only the first n_rot dimensions
@@ -5550,12 +5562,14 @@ struct llm_build_context {
 
         GGML_ASSERT(kv_self.recurrent);
 
+        struct ggml_tensor * state_copy = build_inp_s_copy();
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
             struct ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
 
-            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
-            ssm_states  = ggml_get_rows(ctx0,  ssm_states, lctx.inp_s_copy);
+            conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
+            ssm_states  = ggml_get_rows(ctx0,  ssm_states, state_copy);
 
             // TODO: name the intermediate tensors with cb()
 
@@ -5615,6 +5629,66 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_tensor * build_inp_pos() {
+        lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        cb(lctx.inp_pos, "inp_pos", -1);
+        ggml_set_input(lctx.inp_pos);
+        return lctx.inp_pos;
+    }
+
+    struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
+        if (causal) {
+            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens);
+        } else {
+            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+        }
+        cb(lctx.inp_KQ_mask, "KQ_mask", -1);
+        ggml_set_input(lctx.inp_KQ_mask);
+        return lctx.inp_KQ_mask;
+    }
+
+    struct ggml_tensor * build_inp_KQ_pos() {
+        lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
+        cb(lctx.inp_KQ_pos, "KQ_pos", -1);
+        ggml_set_input(lctx.inp_KQ_pos);
+        return lctx.inp_KQ_pos;
+    }
+
+    struct ggml_tensor * build_inp_mean() {
+        lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+        cb(lctx.inp_mean, "inp_mean", -1);
+        ggml_set_input(lctx.inp_mean);
+        return lctx.inp_mean;
+    }
+
+    struct ggml_tensor * build_inp_cls() {
+        lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        cb(lctx.inp_cls, "inp_cls", -1);
+        ggml_set_input(lctx.inp_cls);
+        return lctx.inp_cls;
+    }
+
+    struct ggml_tensor * build_inp_s_copy() {
+        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        cb(lctx.inp_s_copy, "inp_s_copy", -1);
+        ggml_set_input(lctx.inp_s_copy);
+        return lctx.inp_s_copy;
+    }
+
+    struct ggml_tensor * build_inp_s_mask() {
+        lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
+        cb(lctx.inp_s_mask, "inp_s_mask", -1);
+        ggml_set_input(lctx.inp_s_mask);
+        return lctx.inp_s_mask;
+    }
+
+    struct ggml_tensor * build_inp_s_seq() {
+        lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
+        cb(lctx.inp_s_seq, "inp_s_seq", -1);
+        ggml_set_input(lctx.inp_s_seq);
+        return lctx.inp_s_seq;
+    }
+
     struct ggml_cgraph * build_llama() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -5625,16 +5699,13 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -5686,7 +5757,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -5804,20 +5874,16 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         // positions of the tokens in the KV cache
-        struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
-        cb(KQ_pos, "KQ_pos", -1);
+        struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -5865,7 +5931,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -5921,16 +5986,13 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * attn_norm;
@@ -5984,7 +6046,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             struct ggml_tensor * ffn_inp = cur;
@@ -6035,21 +6096,17 @@ struct llm_build_context {
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
         struct ggml_tensor * cur;
-        struct ggml_tensor * pos;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
-        pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
+        struct ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
         cb(pos, "pos_embd", -1);
 
         inpL = ggml_add(ctx0, inpL, pos);
@@ -6083,7 +6140,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             // add the input
@@ -6135,16 +6191,13 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * residual = inpL;
@@ -6284,7 +6337,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
@@ -6338,16 +6390,13 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         // positions of the tokens in the KV cache
-        struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
-        cb(KQ_pos, "KQ_pos", -1);
+        struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -6377,7 +6426,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -6433,15 +6481,12 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        // get input vectors with right size
-        const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type);
-
-        struct ggml_tensor * inp_pos  = ggml_view_1d(ctx0, lctx.inp_pos,  n_tokens, 0);
-        struct ggml_tensor * inp_mean = ggml_view_2d(ctx0, lctx.inp_mean, n_tokens, n_tokens, stride1, 0);
-        struct ggml_tensor * inp_cls  = ggml_view_1d(ctx0, lctx.inp_cls,  n_tokens, 0);
+        struct ggml_tensor * inp_pos  = build_inp_pos();
+        struct ggml_tensor * inp_mean = build_inp_mean();
+        struct ggml_tensor * inp_cls  = build_inp_cls();
 
         // construct input embeddings (token, type, position)
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // token types are hardcoded to zero ("Sentence A")
         struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
@@ -6456,8 +6501,7 @@ struct llm_build_context {
         cb(inpL, "inp_norm", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_cont(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_tokens, n_tokens, n_tokens*ggml_type_size(lctx.inp_KQ_mask->type), 0));
-        cb(KQ_mask, "KQ_mask", -1); // [n_tokens, n_tokens]
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask(false);
 
         // iterate layers
         for (int il = 0; il < n_layer; ++il) {
@@ -6619,16 +6663,13 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         // positions of the tokens in the KV cache
-        struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
-        cb(KQ_pos, "KQ_pos", -1);
+        struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
 
         inpL = llm_build_norm(ctx0, inpL, hparams,
                 model.tok_norm,
@@ -6664,7 +6705,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             // Add the input
@@ -6716,16 +6756,13 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         // positions of the tokens in the KV cache
-        struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
-        cb(KQ_pos, "KQ_pos", -1);
+        struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * attn_norm;
@@ -6766,7 +6803,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             // Add the input
@@ -6821,16 +6857,13 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -6883,7 +6916,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -6939,16 +6971,13 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -6993,7 +7022,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -7048,16 +7076,13 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -7109,7 +7134,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -7164,16 +7188,13 @@ struct llm_build_context {
         struct ggml_tensor * ffn_output;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
@@ -7231,7 +7252,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             // FF
@@ -7281,16 +7301,13 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
 
@@ -7329,7 +7346,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
             struct ggml_tensor * sa_out = cur;
 
@@ -7383,16 +7399,13 @@ struct llm_build_context {
         struct ggml_tensor * pos;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
         cb(pos, "pos_embd", -1);
@@ -7428,7 +7441,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             // add the input
@@ -7481,16 +7493,13 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -7532,7 +7541,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             // add the input
@@ -7584,16 +7592,13 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -7645,7 +7650,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -7698,16 +7702,13 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -7759,7 +7760,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -7821,20 +7821,17 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // scale the input embeddings
         inpL = ggml_scale(ctx0, inpL, scale_embd);
         cb(inpL, "inp_scaled", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -7886,7 +7883,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             // scale_res - scale the hidden states for residual connection
@@ -7953,22 +7949,18 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
         cb(inpL, "inp_scaled", -1);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
-
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
@@ -8005,7 +7997,6 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
-                cb(cur, "kqv_out", il);
             }
 
             struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
@@ -8060,16 +8051,13 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        cb(inp_pos, "inp_pos", -1);
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -8178,11 +8166,10 @@ struct llm_build_context {
         struct ggml_tensor * inpL;
 
         // {n_embd, n_tokens}
-        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        cb(inpL, "inp_embd", -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
-        struct ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
-        struct ggml_tensor * state_seq  = ggml_view_2d(ctx0, lctx.inp_s_seq, n_kv, n_tokens, n_kv*ggml_element_size(lctx.inp_s_seq), 0);
+        struct ggml_tensor * state_mask = build_inp_s_mask();
+        struct ggml_tensor * state_seq  = build_inp_s_seq();
 
         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the KV cache to store the states
@@ -8234,7 +8221,7 @@ struct llm_build_context {
                 ggml_build_forward_expand(gf,
                     ggml_cpy(ctx0,
                         ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
-                        ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
+                        ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
 
                 // extract x from x_conv
                 x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0);
@@ -8268,7 +8255,7 @@ struct llm_build_context {
                 ggml_build_forward_expand(gf,
                     ggml_cpy(ctx0,
                         ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
-                        ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_states))));
+                        ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states))));
 
                 struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
 
@@ -8372,7 +8359,18 @@ static struct ggml_cgraph * llama_build_graph(
         if (!lctx.cparams.offload_kqv) {
             if (strcmp(name, "kqv_merged_cont") == 0) {
                 // all nodes between the KV store and the attention output are run on the CPU
-                ggml_backend_sched_set_node_backend(lctx.sched, cur, lctx.backend_cpu);
+                ggml_backend_sched_set_tensor_backend(lctx.sched, cur, lctx.backend_cpu);
+            }
+        }
+
+        // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
+        // to fix this, we assign the norm layer manually to the backend of its layer
+        if (il != -1 && strcmp(name, "norm") == 0) {
+            for (auto * backend : lctx.backends) {
+                if (ggml_backend_buft_supports_backend(lctx.model.buft_layer[il].buft, backend)) {
+                    ggml_backend_sched_set_tensor_backend(lctx.sched, cur, backend);
+                    break;
+                }
             }
         }
     };
@@ -8528,7 +8526,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
     }
 
-    if (batch.pos) {
+    if (batch.pos && lctx.inp_pos) {
         const int64_t n_tokens = batch.n_tokens;
 
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
@@ -8539,61 +8537,63 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         "non-causal attention with generative models is not supported"
     );
 
-    // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-    if (cparams.causal_attn) {
-        const int64_t n_kv     = kv_self.n;
-        const int64_t n_tokens = batch.n_tokens;
+    if (lctx.inp_KQ_mask) {
+        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
+        if (cparams.causal_attn) {
+            const int64_t n_kv     = kv_self.n;
+            const int64_t n_tokens = batch.n_tokens;
 
-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
-        float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data = (float *) lctx.inp_KQ_mask->data;
 
-        // For causal attention, use only the previous KV cells
-        // of the correct sequence for each token of the batch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                const llama_pos    pos    = batch.pos[j];
-                const llama_seq_id seq_id = batch.seq_id[j][0];
+            // For causal attention, use only the previous KV cells
+            // of the correct sequence for each token of the batch.
+            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    const llama_pos    pos    = batch.pos[j];
+                    const llama_seq_id seq_id = batch.seq_id[j][0];
 
-                for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                        f = -INFINITY;
-                    } else {
-                        f = 0.0f;
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                            f = -INFINITY;
+                        } else {
+                            f = 0.0f;
+                        }
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                     }
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                 }
             }
-        }
-    } else {
-        // when using kv cache, the mask needs to match the kv cache size
-        const int64_t n_tokens = batch.n_tokens;
-        const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
+        } else {
+            // when using kv cache, the mask needs to match the kv cache size
+            const int64_t n_tokens = batch.n_tokens;
+            const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
 
-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
-        float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data = (float *) lctx.inp_KQ_mask->data;
 
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[j][0];
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    const llama_seq_id seq_id = batch.seq_id[j][0];
 
-                for (int i = 0; i < n_tokens; ++i) {
-                    float f = -INFINITY;
-                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                        if (batch.seq_id[i][s] == seq_id) {
-                            f = 0.0f;
-                            break;
+                    for (int i = 0; i < n_tokens; ++i) {
+                        float f = -INFINITY;
+                        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                            if (batch.seq_id[i][s] == seq_id) {
+                                f = 0.0f;
+                                break;
+                            }
                         }
-                    }
 
-                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
-                }
+                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
+                    }
 
-                for (int i = n_tokens; i < n_stride; ++i) {
-                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                    for (int i = n_tokens; i < n_stride; ++i) {
+                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                    }
                 }
             }
         }
@@ -8602,7 +8602,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (hparams.need_kq_pos) {
         const int64_t n_kv = kv_self.n;
 
-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
+        GGML_ASSERT(lctx.inp_KQ_pos);
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
 
         float * data = (float *) lctx.inp_KQ_pos->data;
 
@@ -8614,6 +8615,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
+        GGML_ASSERT(lctx.inp_mean);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
 
         float * data = (float *) lctx.inp_mean->data;
@@ -8645,6 +8647,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;
 
+        GGML_ASSERT(lctx.inp_cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
 
         uint32_t * data = (uint32_t *) lctx.inp_cls->data;
@@ -8665,7 +8668,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (kv_self.recurrent) {
         const int64_t n_kv = kv_self.n;
 
-        {
+        if (lctx.inp_s_mask) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
             float * data = (float *) lctx.inp_s_mask->data;
 
@@ -8687,7 +8690,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         // update the correct state(s)/sequence(s) for each token of the batch.
         // Like with the KQ_mask, if a token in the batch has multiple sequences,
         // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
-        {
+        if (lctx.inp_s_seq) {
             const int64_t n_tokens = batch.n_tokens;
 
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
@@ -8730,7 +8733,7 @@ static void llama_graph_compute(
         ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
     }
 
-    ggml_backend_sched_graph_compute(lctx.sched, gf);
+    ggml_backend_sched_graph_compute_async(lctx.sched, gf);
 
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
 
@@ -8750,10 +8753,11 @@ static void llama_graph_compute(
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch) {
-    const uint32_t n_tokens = batch.n_tokens;
+           llama_batch   batch_all) { // TODO: rename back to batch
+
+    const uint32_t n_tokens_all = batch_all.n_tokens;
 
-    if (n_tokens == 0) {
+    if (n_tokens_all == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
         return -1;
     }
@@ -8762,14 +8766,16 @@ static int llama_decode_internal(
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    const auto n_batch = cparams.n_batch;
+    GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
 
-    GGML_ASSERT(n_tokens <= n_batch);
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
-    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
 
-    const int64_t t_start_us = ggml_time_us();
+    if (lctx.t_compute_start_us == 0) {
+        lctx.t_compute_start_us = ggml_time_us();
+    }
+    lctx.n_queued_tokens += n_tokens_all;
 
 #ifdef GGML_USE_MPI
     // TODO: needs fix after #3228
@@ -8777,272 +8783,274 @@ static int llama_decode_internal(
     //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
 #endif
 
-    GGML_ASSERT(n_threads > 0);
-
     auto & kv_self = lctx.kv_self;
 
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
 
-    // helpers for smoother batch API transition
-    // after deprecating the llama_eval calls, these will be removed
-    std::vector<llama_pos> pos;
 
-    std::vector<int32_t>                   n_seq_id;
-    std::vector<llama_seq_id *>            seq_id_arr;
-    std::vector<std::vector<llama_seq_id>> seq_id;
+    auto * logits_out = lctx.logits;
 
-    if (batch.pos == nullptr) {
-        pos.resize(n_tokens);
-        for (uint32_t i = 0; i < n_tokens; i++) {
-            pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
-        }
+#ifndef NDEBUG
+    auto & logits_valid = lctx.logits_valid;
+    logits_valid.clear();
+    logits_valid.resize(n_tokens_all);
 
-        batch.pos = pos.data();
-    }
+    memset(logits_out, 0, lctx.logits_size*sizeof(float));
+#endif
 
-    if (batch.seq_id == nullptr) {
-        n_seq_id.resize(n_tokens);
-        seq_id.resize(n_tokens);
-        seq_id_arr.resize(n_tokens);
-        for (uint32_t i = 0; i < n_tokens; i++) {
-            n_seq_id[i] = 1;
-            seq_id[i].resize(1);
-            seq_id[i][0] = batch.all_seq_id;
-            seq_id_arr[i] = seq_id[i].data();
-        }
+    const auto n_ubatch = cparams.n_ubatch;
 
-        batch.n_seq_id = n_seq_id.data();
-        batch.seq_id = seq_id_arr.data();
-    }
+    std::vector<llama_pos> pos;
+    std::vector<int32_t>                   n_seq_id;
+    std::vector<llama_seq_id *>            seq_id_arr;
+    std::vector<std::vector<llama_seq_id>> seq_id;
 
-    // non-causal masks do not use the KV cache
-    if (hparams.causal_attn) {
-        llama_kv_cache_update(&lctx);
+    for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
+        const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
+        llama_batch u_batch = {
+            /* .n_tokens   = */ (int32_t) n_tokens,
+            /* .token      = */ batch_all.token     ? batch_all.token    + cur_token        : nullptr,
+            /* .embd       = */ batch_all.embd      ? batch_all.embd     + cur_token*n_embd : nullptr,
+            /* .pos        = */ batch_all.pos       ? batch_all.pos      + cur_token        : nullptr,
+            /* .n_seq_id   = */ batch_all.n_seq_id  ? batch_all.n_seq_id + cur_token        : nullptr,
+            /* .seq_id     = */ batch_all.seq_id    ? batch_all.seq_id   + cur_token        : nullptr,
+            /* .logits     = */ batch_all.logits    ? batch_all.logits   + cur_token        : nullptr,
+            /* .all_pos_0  = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
+            /* .all_pos_1  = */ batch_all.all_pos_1,
+            /* .all_seq_id = */ batch_all.all_seq_id,
+        };
 
-        // if we have enough unused cells before the current head ->
-        //   better to start searching from the beginning of the cache, hoping to fill it
-        if (kv_self.head > kv_self.used + 2*n_tokens) {
-            kv_self.head = 0;
-        }
+        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        GGML_ASSERT(n_threads > 0);
 
-        if (!llama_kv_cache_find_slot(kv_self, batch)) {
-            return 1;
-        }
+        // helpers for smoother batch API transition
+        // after deprecating the llama_eval calls, these will be removed
+        if (u_batch.pos == nullptr) {
+            pos.resize(n_tokens);
+            for (uint32_t i = 0; i < n_tokens; i++) {
+                pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1;
+            }
 
-        if (!kv_self.recurrent) {
-            // a heuristic, to avoid attending the full cache if it is not yet utilized
-            // after enough generations, the benefit from this heuristic disappears
-            // if we start defragmenting the cache, the benefit from this will be more important
-            kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
-            //kv_self.n = llama_kv_cache_cell_max(kv_self);
+            u_batch.pos = pos.data();
         }
-    }
-
-    //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
-    ggml_backend_sched_reset(lctx.sched);
-    ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
+        if (u_batch.seq_id == nullptr) {
+            n_seq_id.resize(n_tokens);
+            seq_id.resize(n_tokens);
+            seq_id_arr.resize(n_tokens);
+            for (uint32_t i = 0; i < n_tokens; i++) {
+                n_seq_id[i] = 1;
+                seq_id[i].resize(1);
+                seq_id[i][0] = u_batch.all_seq_id;
+                seq_id_arr[i] = seq_id[i].data();
+            }
 
-    ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
+            u_batch.n_seq_id = n_seq_id.data();
+            u_batch.seq_id = seq_id_arr.data();
+        }
 
-    // the output is always the last tensor in the graph
-    struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
-    struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+        // non-causal masks do not use the KV cache
+        if (hparams.causal_attn) {
+            llama_kv_cache_update(&lctx);
 
-    if (!hparams.causal_attn) {
-        res = nullptr; // do not extract logits for embedding models such as BERT
+            // if we have enough unused cells before the current head ->
+            //   better to start searching from the beginning of the cache, hoping to fill it
+            if (kv_self.head > kv_self.used + 2*n_tokens) {
+                kv_self.head = 0;
+            }
 
-        // token or sequence embeddings
-        embd = gf->nodes[gf->n_nodes - 1];
+            if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
+                return 1;
+            }
 
-        GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
-    } else {
-        if (strcmp(res->name, "result_output") == 0) {
-            // the token embeddings could be the second to last tensor, or the third to last tensor
-            if (strcmp(embd->name, "result_norm") != 0) {
-                embd = gf->nodes[gf->n_nodes - 3];
-                GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
+            if (!kv_self.recurrent) {
+                // a heuristic, to avoid attending the full cache if it is not yet utilized
+                // after enough generations, the benefit from this heuristic disappears
+                // if we start defragmenting the cache, the benefit from this will be more important
+                kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+                //kv_self.n = llama_kv_cache_cell_max(kv_self);
             }
-        } else {
-            GGML_ASSERT(false && "missing result_output tensor");
         }
-    }
 
-    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
-    // for big prompts, if BLAS is enabled, it is better to use only one thread
-    // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
-    // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well
-    //       we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering
-    //       with the BLAS calls. need a better solution
-    // MoE Special Case: This logic applies when hparams.n_expert == 0, i.e. the model is NOT an MoE model. When an MoE is
-    //                   being processed then Accelerate/BLAS will not be involved, so capping would limit performance.
-    if (n_tokens >= 32 && hparams.n_expert == 0 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
-        n_threads = std::min(4, n_threads);
-    }
+        ggml_backend_sched_reset(lctx.sched);
+        ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-    llama_set_inputs(lctx, batch);
+        ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
 
-    llama_graph_compute(lctx, gf, n_threads);
+        // the output is always the last tensor in the graph
+        struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
+        struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
 
-    // update the kv ring buffer
-    {
-        kv_self.head += n_tokens;
-
-        // Ensure kv cache head points to a valid index.
-        if (kv_self.head >= kv_self.size) {
-            kv_self.head = 0;
-        }
-    }
+        if (!hparams.causal_attn) {
+            res = nullptr; // do not extract logits for embedding models such as BERT
 
-    // decide if we need to defrag the kv cache
-    if (cparams.defrag_thold >= 0.0f) {
-        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens)/float(kv_self.n) : 0.0f;
+            // token or sequence embeddings
+            embd = gf->nodes[gf->n_nodes - 1];
 
-        // queue defragmentation for next llama_kv_cache_update
-        if (fragmentation > cparams.defrag_thold) {
-            //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
+            GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
+        } else {
+            if (strcmp(res->name, "result_output") == 0) {
+                // the token embeddings could be the second to last tensor, or the third to last tensor
+                if (strcmp(embd->name, "result_norm") != 0) {
+                    embd = gf->nodes[gf->n_nodes - 3];
+                    GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
+                }
+            } else {
+                GGML_ASSERT(false && "missing result_output tensor");
+            }
+        }
+        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
-            llama_kv_cache_defrag(kv_self);
+        // for big prompts, if BLAS is enabled, it is better to use only one thread
+        // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
+        // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well
+        //       we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering
+        //       with the BLAS calls. need a better solution
+        // MoE Special Case: This logic applies when hparams.n_expert == 0, i.e. the model is NOT an MoE model. When an MoE is
+        //                   being processed then Accelerate/BLAS will not be involved, so capping would limit performance.
+        if (n_tokens >= 32 && hparams.n_expert == 0 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
+            n_threads = std::min(4, n_threads);
         }
-    }
 
-#ifdef GGML_PERF
-    // print timing information per ggml operation (for debugging purposes)
-    // requires GGML_PERF to be defined
-    ggml_graph_print(gf);
-#endif
+        ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
-    // plot the computation graph in dot format (for debugging purposes)
-    //if (n_past%100 == 0) {
-    //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
-    //}
+        llama_set_inputs(lctx, u_batch);
 
-    // extract logits
-    // TODO: do not compute and extract logits if only embeddings are needed
-    //       need to update the graphs to skip "result_output"
-    if (res) {
-        auto & logits_out = lctx.logits;
+        llama_graph_compute(lctx, gf, n_threads);
 
-#ifndef NDEBUG
-        auto & logits_valid = lctx.logits_valid;
-        logits_valid.clear();
-        logits_valid.resize(n_tokens);
+        // update the kv ring buffer
+        {
+            kv_self.head += n_tokens;
 
-        logits_out.clear();
-#endif
+            // Ensure kv cache head points to a valid index.
+            if (kv_self.head >= kv_self.size) {
+                kv_self.head = 0;
+            }
+        }
 
-        ggml_backend_t backend_res = ggml_backend_sched_get_node_backend(lctx.sched, res);
-        GGML_ASSERT(backend_res != nullptr);
+#ifdef GGML_PERF
+        // print timing information per ggml operation (for debugging purposes)
+        // requires GGML_PERF to be defined
+        ggml_graph_print(gf);
+#endif
 
-        if (batch.logits) {
-            logits_out.resize(n_vocab * n_tokens);
-            int32_t i_first = -1;
-            for (uint32_t i = 0; i < n_tokens; i++) {
-                if (batch.logits[i] && i_first == -1) {
-                    i_first = (int32_t) i;
-                }
-                if (batch.logits[i] == 0 || i == n_tokens - 1) {
-                    if (i_first != -1) {
-                        int i_last = batch.logits[i] == 0 ? i : i + 1;
-                        // extract logits for the range [i_first, i_last)
-                        // group the requests to minimize the number of calls to the backend
-                        ggml_backend_tensor_get_async(backend_res, res,
-                            logits_out.data() + (n_vocab*i_first),
-                            (n_vocab*i_first)*sizeof(float),
-                            (i_last - i_first)*n_vocab*sizeof(float));
-                        i_first = -1;
+        // plot the computation graph in dot format (for debugging purposes)
+        //if (n_past%100 == 0) {
+        //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
+        //}
+
+        // extract logits
+        // TODO: do not compute and extract logits if only embeddings are needed
+        //       update the graphs to skip "result_output" if logits are not needed
+        if (res) {
+            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
+            GGML_ASSERT(backend_res != nullptr);
+            if (u_batch.logits) {
+                int32_t i_first = -1;
+                for (uint32_t i = 0; i < n_tokens; i++) {
+                    if (u_batch.logits[i] && i_first == -1) {
+                        i_first = (int32_t) i;
+                    }
+                    if (u_batch.logits[i] == 0 || i == n_tokens - 1) {
+                        if (i_first != -1) {
+                            int i_last = u_batch.logits[i] == 0 ? i : i + 1;
+                            // extract logits for the range [i_first, i_last)
+                            // group the requests to minimize the number of calls to the backend
+                            ggml_backend_tensor_get_async(backend_res, res,
+                                logits_out + n_vocab*(cur_token + i_first),
+                                i_first*n_vocab*sizeof(float),
+                                (i_last - i_first)*n_vocab*sizeof(float));
+                            i_first = -1;
+                        }
                     }
-                }
 #ifndef NDEBUG
-                logits_valid[i] = batch.logits[i] != 0;
+                    logits_valid[cur_token + i] = u_batch.logits[i] != 0;;
 #endif
-            }
-        } else if (lctx.logits_all) {
-            logits_out.resize(n_vocab*n_tokens);
-            ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
+                }
+            } else if (lctx.logits_all) {
+                ggml_backend_tensor_get_async(backend_res, res, logits_out + n_vocab*cur_token, 0, n_vocab*n_tokens*sizeof(float));
 #ifndef NDEBUG
-            std::fill(logits_valid.begin(), logits_valid.end(), true);
+                std::fill(logits_valid.begin() + cur_token, logits_valid.begin() + cur_token + n_tokens, true);
 #endif
-        } else {
-            logits_out.resize(n_vocab);
-            ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float));
+            } else {
+                if (cur_token + n_tokens >= n_tokens_all) {
+                    ggml_backend_tensor_get_async(backend_res, res, logits_out, n_vocab*(n_tokens - 1)*sizeof(float), n_vocab*sizeof(float));
 #ifndef NDEBUG
-            logits_valid[0] = true;
+                    logits_valid[0] = true;
 #endif
+                }
+            }
         }
-        ggml_backend_synchronize(backend_res);
-    }
 
-    // extract embeddings
-    if (cparams.embeddings && embd) {
-        ggml_backend_t backend_embd = ggml_backend_sched_get_node_backend(lctx.sched, embd);
-        GGML_ASSERT(backend_embd != nullptr);
-
-        switch (cparams.pooling_type) {
-            case LLAMA_POOLING_TYPE_NONE:
-                {
-                    // extract token embeddings
-                    auto & embd_out = lctx.embd;
+        // extract embeddings
+        if (cparams.embeddings && embd) {
+            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+            GGML_ASSERT(backend_embd != nullptr);
 
-                    if (batch.logits) {
-                        embd_out.resize(n_embd * n_tokens);
-                        for (uint32_t i = 0; i < n_tokens; i++) {
-                            if (batch.logits[i] == 0) {
-                                continue;
+            switch (cparams.pooling_type) {
+                case LLAMA_POOLING_TYPE_NONE:
+                    {
+                        // extract token embeddings
+                        auto & embd_out = lctx.embd;
+
+                        if (u_batch.logits) {
+                            //embd_out.resize(n_embd * n_tokens);
+                            for (uint32_t i = 0; i < n_tokens; i++) {
+                                if (u_batch.logits[i] == 0) {
+                                    continue;
+                                }
+                                ggml_backend_tensor_get_async(backend_embd, embd, embd_out + n_embd*(i + cur_token), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
                             }
-
-                            ggml_backend_tensor_get_async(backend_embd, embd, embd_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
                         }
-                    }
-                } break;
-            case LLAMA_POOLING_TYPE_CLS:
-            case LLAMA_POOLING_TYPE_MEAN:
-                {
-                    GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
+                    } break;
+                case LLAMA_POOLING_TYPE_CLS:
+                case LLAMA_POOLING_TYPE_MEAN:
+                    {
+                        GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
 
-                    // extract sequence embeddings
-                    auto & embd_seq_out = lctx.embd_seq;
-                    embd_seq_out.clear();
+                        // extract sequence embeddings
+                        auto & embd_seq_out = lctx.embd_seq;
+                        embd_seq_out.clear();
 
-                    for (uint32_t i = 0; i < n_tokens; i++) {
-                        const llama_seq_id seq_id = batch.seq_id[i][0];
-                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                            continue;
+                        for (uint32_t i = 0; i < n_tokens; i++) {
+                            const llama_seq_id seq_id = u_batch.seq_id[i][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(n_embd);
+                            ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                         }
-                        embd_seq_out[seq_id].resize(n_embd);
-                        ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
-                    }
-                } break;
-            case LLAMA_POOLING_TYPE_UNSPECIFIED:
-                {
-                    GGML_ASSERT(false && "unknown pooling type");
-                } break;
+                    } break;
+                case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                    {
+                        GGML_ASSERT(false && "unknown pooling type");
+                    } break;
+            }
         }
-        ggml_backend_synchronize(backend_embd);
     }
 
-    // measure the performance only for the single-token evals
-    if (n_tokens == 1) {
-        lctx.t_eval_us += ggml_time_us() - t_start_us;
-        lctx.n_eval++;
-    }
-    else if (n_tokens > 1) {
-        lctx.t_p_eval_us += ggml_time_us() - t_start_us;
-        lctx.n_p_eval += n_tokens;
-    }
+    // wait for the computation to finish (automatically done when obtaining the model output)
+    //llama_synchronize(&lctx);
 
-    // get a more accurate load time, upon first eval
-    // TODO: fix this
-    if (!lctx.has_evaluated_once) {
-        lctx.t_load_us = ggml_time_us() - lctx.t_start_us;
-        lctx.has_evaluated_once = true;
+    // decide if we need to defrag the kv cache
+    if (cparams.defrag_thold >= 0.0f) {
+        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens_all)/float(kv_self.n) : 0.0f;
+
+        // queue defragmentation for next llama_kv_cache_update
+        if (fragmentation > cparams.defrag_thold) {
+            //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
+
+            llama_kv_cache_defrag(kv_self);
+        }
     }
 
     return 0;
 }
 
+
 // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
 static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     auto & kv_self = lctx.kv_self;
@@ -9242,6 +9250,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 #else
     // ggml_graph defrag
 
+    ggml_backend_sched_reset(lctx.sched);
+
     ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
 
     llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
@@ -9253,14 +9263,22 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 }
 
 static void llama_kv_cache_update_internal(struct llama_context & lctx) {
+    bool need_reserve = false;
+
     // apply K-shift if needed
     if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
-        llama_set_k_shift(lctx);
-
         {
+            ggml_backend_sched_reset(lctx.sched);
+
             ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
 
+            ggml_backend_sched_alloc_graph(lctx.sched, gf);
+
+            llama_set_k_shift(lctx);
+
             llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+
+            need_reserve = true;
         }
 
         {
@@ -9275,12 +9293,18 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     }
 
     if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
-        llama_set_s_copy(lctx);
-
         {
+            ggml_backend_sched_reset(lctx.sched);
+
             ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
 
+            ggml_backend_sched_alloc_graph(lctx.sched, gf);
+
+            llama_set_s_copy(lctx);
+
             llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+
+            need_reserve = true;
         }
 
         {
@@ -9298,8 +9322,26 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     if (lctx.kv_self.do_defrag) {
         llama_kv_cache_defrag_internal(lctx);
 
+        need_reserve = true;
+
         lctx.kv_self.do_defrag = false;
     }
+
+    // reserve a worst case graph again
+    if (need_reserve) {
+        // TODO: extract to a function
+        // build worst-case graph
+        int n_tokens = (int)std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
+        int n_past = lctx.cparams.n_ctx - n_tokens;
+        llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+        ggml_cgraph * gf = llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
+
+        // initialize scheduler with the worst-case graph
+        ggml_backend_sched_reset(lctx.sched);
+        if (!ggml_backend_sched_reserve(lctx.sched, gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+        }
+    }
 }
 
 //
@@ -12537,7 +12579,8 @@ struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
         /*.seed                        =*/ LLAMA_DEFAULT_SEED,
         /*.n_ctx                       =*/ 512,
-        /*.n_batch                     =*/ 512,
+        /*.n_batch                     =*/ 2048,
+        /*.n_ubatch                    =*/ 512,
         /*.n_seq_max                   =*/ 1,
         /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
@@ -12691,6 +12734,17 @@ struct llama_context * llama_new_context_with_model(
         struct llama_context_params   params) {
 
     if (!model) {
+        LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
+        return nullptr;
+    }
+
+    if (params.n_batch == 0 && params.n_ubatch == 0) {
+        LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
+        return nullptr;
+    }
+
+    if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
+        LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
         return nullptr;
     }
 
@@ -12699,7 +12753,6 @@ struct llama_context * llama_new_context_with_model(
     const auto & hparams = model->hparams;
     auto       & cparams = ctx->cparams;
 
-    cparams.n_batch          = params.n_batch;
     // TODO: maybe add n_seq_max here too
     cparams.n_threads        = params.n_threads;
     cparams.n_threads_batch  = params.n_threads_batch;
@@ -12716,6 +12769,11 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
     cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
+    // with causal attention, the batch size is limited by the context size
+    cparams.n_batch          = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
+    cparams.n_ubatch         = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+
+
     cparams.n_yarn_orig_ctx  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
                                hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
                                                               hparams.n_ctx_train;
@@ -12751,6 +12809,8 @@ struct llama_context * llama_new_context_with_model(
     }
 
     LLAMA_LOG_INFO("%s: n_ctx      = %u\n",     __func__, cparams.n_ctx);
+    LLAMA_LOG_INFO("%s: n_batch    = %u\n",     __func__, cparams.n_batch);
+    LLAMA_LOG_INFO("%s: n_ubatch   = %u\n",     __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: freq_base  = %.1f\n",   __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n",     __func__, cparams.rope_freq_scale);
 
@@ -12895,54 +12955,31 @@ struct llama_context * llama_new_context_with_model(
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }
 
-        // resized during inference, reserve maximum
-        ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
+        // graph outputs buffer
+        {
+            // resized during inference, reserve maximum
+            ctx->logits_size = hparams.n_vocab*cparams.n_batch;
+            ctx->embd_size = params.embeddings ? hparams.n_embd*cparams.n_batch : 0;
 
-        if (params.embeddings) {
-            ctx->embd.reserve(hparams.n_embd*cparams.n_batch);
-        }
+            const size_t buf_output_size = (ctx->logits_size + ctx->embd_size)*sizeof(float);
 
-        // graph inputs
-        {
-            ggml_init_params init_params = {
-                /* .mem_size   */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.recurrent)),
-                /* .mem_buffer */ nullptr,
-                /* .no_alloc   */ true,
-            };
-            ctx->ctx_input = ggml_init(init_params);
-
-            ctx->inp_tokens  = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
-            ctx->inp_embd    = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
-            ctx->inp_pos     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
-            ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, kv_size, cparams.n_batch);
-            ctx->inp_KQ_pos  = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
-            ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
-            ctx->inp_mean    = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
-            ctx->inp_cls     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
-            if (ctx->kv_self.recurrent) {
-                ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
-                ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
-                ctx->inp_s_seq  = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
-            }
-
-            ggml_set_name(ctx->inp_tokens,  "inp_tokens");
-            ggml_set_name(ctx->inp_embd,    "inp_embd");
-            ggml_set_name(ctx->inp_pos,     "inp_pos");
-            ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask");
-            ggml_set_name(ctx->inp_KQ_pos,  "inp_KQ_pos");
-            ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
-            ggml_set_name(ctx->inp_mean,    "inp_mean");
-            ggml_set_name(ctx->inp_cls,     "inp_cls");
-            if (ctx->kv_self.recurrent) {
-                ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
-                ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
-                ggml_set_name(ctx->inp_s_seq,  "inp_s_seq");
-            }
-
-            ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
-            LLAMA_LOG_INFO("%s: %10s input buffer size   = %8.2f MiB\n", __func__,
-                    ggml_backend_buffer_name(ctx->buf_input),
-                    ggml_backend_buffer_get_size(ctx->buf_input) / 1024.0 / 1024.0);
+            ctx->buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size);
+            if (ctx->buf_output == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to allocate logits buffer\n", __func__);
+                llama_free(ctx);
+                return nullptr;
+            }
+            ggml_backend_buffer_clear(ctx->buf_output, 0);
+
+
+            ctx->logits = (float *) ggml_backend_buffer_get_base(ctx->buf_output);
+            if (params.embeddings) {
+                ctx->embd = ctx->logits + ctx->logits_size;
+            }
+
+            LLAMA_LOG_INFO("%s: %10s  output buffer size = %8.2f MiB\n", __func__,
+                    ggml_backend_buffer_name(ctx->buf_output),
+                    ggml_backend_buffer_get_size(ctx->buf_output) / 1024.0 / 1024.0);
         }
 
         // scheduler and compute buffers
@@ -12961,10 +12998,21 @@ struct llama_context * llama_new_context_with_model(
             // buffer used to store the computation graph and the tensor meta data
             ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false));
 
-            ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);
+            // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
+            bool pipeline_parallel = llama_get_device_count() > 1 && model->n_gpu_layers > (int)model->hparams.n_layer && model->split_mode == LLAMA_SPLIT_MODE_LAYER;
+#ifndef GGML_USE_CUBLAS
+            // pipeline parallelism requires support for async compute and events
+            // currently this is only implemented in the CUDA backend
+            pipeline_parallel = false;
+#endif
+            ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES, pipeline_parallel);
+
+            if (pipeline_parallel) {
+                LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched));
+            }
 
             // build worst-case graph
-            int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
+            int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch);
             int n_past = cparams.n_ctx - n_tokens;
             llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
             ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
@@ -12987,7 +13035,7 @@ struct llama_context * llama_new_context_with_model(
 
             // note: the number of splits during measure is higher than during inference due to the kv shift
             int n_splits = ggml_backend_sched_get_n_splits(ctx->sched);
-            LLAMA_LOG_INFO("%s: graph splits (measure): %d\n", __func__, n_splits);
+            LLAMA_LOG_INFO("%s: graph splits: %d\n", __func__, n_splits);
         }
     }
 
@@ -13024,6 +13072,10 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
     return ctx->cparams.n_batch;
 }
 
+uint32_t llama_n_ubatch(const struct llama_context * ctx) {
+    return ctx->cparams.n_ubatch;
+}
+
 uint32_t llama_n_seq_max(const struct llama_context * ctx) {
     return ctx->kv_self.size;
 }
@@ -13347,9 +13399,9 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     const size_t s_rng             = LLAMA_MAX_RNG_STATE;
     const size_t s_logits_size     = sizeof(size_t);
     // assume worst case for logits although only currently set ones are serialized
-    const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
+    const size_t s_logits          = ctx->logits_size * sizeof(float);
     const size_t s_embedding_size  = sizeof(size_t);
-    const size_t s_embedding       = ctx->embd.capacity() * sizeof(float);
+    const size_t s_embedding       = ctx->embd_size * sizeof(float);
     const size_t s_kv_buf_size     = sizeof(size_t);
     const size_t s_kv_head         = sizeof(uint32_t);
     const size_t s_kv_size         = sizeof(uint32_t);
@@ -13447,23 +13499,23 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
 
     // copy logits
     {
-        const size_t logits_size = ctx->logits.size();
+        const size_t logits_size = ctx->logits_size;
 
         data_ctx->write(&logits_size, sizeof(logits_size));
 
         if (logits_size) {
-            data_ctx->write(ctx->logits.data(), logits_size * sizeof(float));
+            data_ctx->write(ctx->logits, logits_size * sizeof(float));
         }
     }
 
     // copy embeddings
     {
-        const size_t embeddings_size = ctx->embd.size();
+        const size_t embeddings_size = ctx->embd_size;
 
         data_ctx->write(&embeddings_size, sizeof(embeddings_size));
 
         if (embeddings_size) {
-            data_ctx->write(ctx->embd.data(), embeddings_size * sizeof(float));
+            data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
         }
     }
 
@@ -13566,12 +13618,10 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
         memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);
 
-        GGML_ASSERT(ctx->logits.capacity() >= logits_size);
+        GGML_ASSERT(ctx->logits_size >= logits_size);
 
         if (logits_size) {
-            ctx->logits.resize(logits_size);
-
-            memcpy(ctx->logits.data(), inp, logits_size * sizeof(float));
+            memcpy(ctx->logits, inp, logits_size * sizeof(float));
             inp += logits_size * sizeof(float);
         }
     }
@@ -13582,12 +13632,10 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
         memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
 
-        GGML_ASSERT(ctx->embd.capacity() == embeddings_size);
+        GGML_ASSERT(ctx->embd_size == embeddings_size);
 
         if (embeddings_size) {
-            ctx->embd.resize(embeddings_size);
-
-            memcpy(ctx->embd.data(), inp, embeddings_size * sizeof(float));
+            memcpy(ctx->embd, inp, embeddings_size * sizeof(float));
             inp += embeddings_size * sizeof(float);
         }
     }
@@ -13842,24 +13890,61 @@ int32_t llama_decode(
     return ret;
 }
 
+void llama_synchronize(struct llama_context * ctx) {
+    ggml_backend_sched_synchronize(ctx->sched);
+
+    // FIXME: if multiple single tokens are evaluated without a synchronization,
+    // the stats will be added to the prompt evaluation stats
+    // this should only happen when using batch size 1 to evaluate a batch
+
+    // add the evaluation to the stats
+    if (ctx->n_queued_tokens == 1) {
+        ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
+        ctx->n_eval++;
+    } else if (ctx->n_queued_tokens > 1) {
+        ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
+        ctx->n_p_eval += ctx->n_queued_tokens;
+    }
+
+    // get a more accurate load time, upon first eval
+    if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) {
+        ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
+        ctx->has_evaluated_once = true;
+    }
+
+    ctx->n_queued_tokens = 0;
+    ctx->t_compute_start_us = 0;
+}
+
 float * llama_get_logits(struct llama_context * ctx) {
-    return ctx->logits.data();
+    llama_synchronize(ctx);
+
+    return ctx->logits;
 }
 
 float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
     assert(ctx->logits_valid.at(i));
-    return ctx->logits.data() + i*ctx->model.hparams.n_vocab;
+
+    llama_synchronize(ctx);
+
+    return ctx->logits + i*ctx->model.hparams.n_vocab;
 }
 
 float * llama_get_embeddings(struct llama_context * ctx) {
-    return ctx->embd.data();
+    llama_synchronize(ctx);
+
+    return ctx->embd;
 }
 
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
-    return ctx->embd.data() + i*ctx->model.hparams.n_embd;
+    llama_synchronize(ctx);
+
+    return ctx->embd + i*ctx->model.hparams.n_embd;
 }
 
 float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_synchronize(ctx);
+
     auto it = ctx->embd_seq.find(seq_id);
     if (it == ctx->embd_seq.end()) {
         return nullptr;
diff --git a/llama.h b/llama.h
index 446899da6e38ea6299d54becd71d0db0686bc48b..2d16cc9b9fa2c2b9c6847986f386bb8aceaa025f 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -234,7 +234,8 @@ extern "C" {
     struct llama_context_params {
         uint32_t seed;              // RNG seed, -1 for random
         uint32_t n_ctx;             // text context, 0 = from model
-        uint32_t n_batch;           // prompt processing maximum batch size
+        uint32_t n_batch;           // logical maximum batch size that can be submitted to llama_decode
+        uint32_t n_ubatch;          // physical maximum batch size
         uint32_t n_seq_max;         // max number of sequences (i.e. distinct states for recurrent models)
         uint32_t n_threads;         // number of threads to use for generation
         uint32_t n_threads_batch;   // number of threads to use for batch processing
@@ -377,6 +378,7 @@ extern "C" {
 
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_ubatch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max  (const struct llama_context * ctx);
 
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
@@ -650,6 +652,11 @@ extern "C" {
     // Set abort callback
     LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
 
+    // Wait until all computations are finished
+    // This is automatically done when using one of the functions below to obtain the computation results
+    // and is not necessary to call it explicitly in most cases
+    LLAMA_API void llama_synchronize(struct llama_context * ctx);
+
     // Token logits obtained from the last call to llama_decode()
     // The logits for the last token are stored in the last row
     // Logits for which llama_batch.logits[i] == 0 are undefined